diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 3b26a16e3e..2645f9c10e 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -4,14 +4,26 @@ Contributing to seaborn
 General support
 ---------------
-General support questions ("how do I do <X>?") are most at home on [StackOverflow](https://stackoverflow.com/), where they will be seen by more people and are more easily searchable. StackOverflow has a `[seaborn]` tag, which will bring the question to the attention of people who might be able to answer.
+General support questions ("how do I do X?") are most at home on either [StackOverflow](https://stackoverflow.com/) or [discourse](https://discourse.matplotlib.org/c/3rdparty/seaborn/21), which have a larger audience of people who will see your post and may be able to offer assistance. StackOverflow is better for specific issues, while discourse is better for more open-ended discussion. Your chance of getting a quick answer will be higher if you include runnable code, a precise statement of what you are hoping to achieve, and a clear explanation of the problems that you have encountered.
 
 Reporting bugs
 --------------
-If you have encountered a bug in seaborn, please report it on the [Github issue tracker](https://github.com/mwaskom/seaborn/issues/new). It is only really possible to address bug reports if they include a reproducible script using randomly-generated data or one of the example datasets (accessed through `seaborn.load_dataset()`). Please also specify your versions of seaborn and matplotlib, as well as which matplotlib backend you are using.
+If you think you've encountered a bug in seaborn, please report it on the [GitHub issue tracker](https://github.com/mwaskom/seaborn/issues/new). To be useful, bug reports must include the following information:
+
+- A reproducible code example that demonstrates the problem
+- The output that you are seeing (an image of a plot, or the error message)
+- A clear explanation of why you think something is wrong
+- The specific versions of seaborn and matplotlib that you are working with
+
+Bug reports are easiest to address if they can be demonstrated using one of the example datasets from the seaborn docs (i.e. with `seaborn.load_dataset`). Otherwise, it is preferable that your example generate synthetic data to reproduce the problem. If you can only demonstrate the issue with your actual dataset, you will need to share it, ideally as a CSV. Note that you can upload a CSV directly to a GitHub issue thread, but it must have a `.txt` suffix.
+
+If you've encountered an error, searching the specific text of the message before opening a new issue can often help you solve the problem quickly and avoid making a duplicate report.
+
+Because matplotlib handles the actual rendering, errors or incorrect outputs may be due to a problem in matplotlib rather than one in seaborn. It can save time if you try to reproduce the issue in an example that uses only matplotlib, so that you can report it in the right place. But it is fine to skip this step if it's not obvious how to do it.
+
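For illustration, a minimal report of the kind described above might look like the sketch below; the plotting call and the symptom in the comments are hypothetical placeholders, not a known bug:

```python
import matplotlib
import seaborn as sns

# Version information to paste into the report
print(sns.__version__, matplotlib.__version__)

# Demonstrate the problem with a bundled example dataset rather than private data
tips = sns.load_dataset("tips")
ax = sns.scatterplot(data=tips, x="total_bill", y="tip")

# State what you expected and what happened instead, e.g.:
# "expected one point per row; rows where tip == 0 are silently dropped"
ax.figure.savefig("bug.png")  # attach this image to the issue
```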
 New features
 ------------
-If you think there is a new feature that should be added to seaborn, you can open an issue to discuss it. However, seaborn's development has become increasingly conservative, and the answer to most feature requests or proposed additions is "no". Polite requests with an explanation of the proposed feature's virtues will usually get an explanation; feature requests that say "I would like feature X, you need to add it" typically won't.
+If you think there is a new feature that should be added to seaborn, you can open an issue to discuss it. But please be aware that current development efforts are mostly focused on standardizing the API and internals, and there may be relatively low enthusiasm for novel features that do not fit well into short- and medium-term development plans.
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
new file mode 100644
index 0000000000..e431cf18dc
--- /dev/null
+++ b/.github/workflows/ci.yaml
@@ -0,0 +1,99 @@
+name: CI
+
+on:
+  push:
+    branches: master
+  pull_request:
+    branches: master
+
+env:
+  NB_KERNEL: python
+  MPLBACKEND: Agg
+
+jobs:
+
+  build-docs:
+    runs-on: ubuntu-latest
+    steps:
+
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Setup Python
+        uses: actions/setup-python@v2
+
+      - name: Install seaborn
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[all] -r ci/utils.txt
+
+      - name: Install doc tools
+        run: |
+          pip install -r doc/requirements.txt
+          sudo apt-get install pandoc
+
+      - name: Build docs
+        run: |
+          make -C doc -j `nproc` notebooks
+          make -C doc html
+
+  run-tests:
+    runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+
+        python: [3.7.x, 3.8.x, 3.9.x]
+        target: [test]
+        install: [all]
+        deps: [latest]
+        backend: [agg]
+
+        include:
+          - python: 3.7.x
+            target: unittests
+            install: all
+            deps: pinned
+            backend: agg
+          - python: 3.9.x
+            target: unittests
+            install: light
+            deps: latest
+            backend: agg
+          - python: 3.9.x
+            target: test
+            install: all
+            deps: latest
+            backend: tkagg
+
+    steps:
+
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Install seaborn
+        run: |
+          python -m pip install --upgrade pip
+          if [[ ${{ matrix.install }} == 'all' ]]; then EXTRAS='[all]'; fi
+          if [[ ${{ matrix.deps }} == 'pinned' ]]; then DEPS='-r ci/deps_pinned.txt'; fi
+          pip install .$EXTRAS $DEPS -r ci/utils.txt
+
+      - name: Cache datasets
+        run: python ci/cache_test_datasets.py
+
+      - name: Run tests
+        env:
+          MPLBACKEND: ${{ matrix.backend }}
+        run: |
+          if [[ ${{ matrix.backend }} == 'tkagg' ]]; then PREFIX='xvfb-run -a'; fi
+          $PREFIX make ${{ matrix.target }}
+
+      - name: Upload coverage
+        uses: codecov/codecov-action@v1
+        if: ${{ success() }}
diff --git a/.gitignore b/.gitignore
index 2ac96fb5b7..013093cdc4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,5 +7,8 @@ seaborn.egg-info/
 .cache/
 .coverage
 cover/
-.idea
+htmlcov/
+.idea/
+.vscode/
 .pytest_cache/
+notes/
diff --git a/.mailmap b/.mailmap
deleted file mode 100644
index 236389d106..0000000000
--- a/.mailmap
+++ /dev/null
@@ -1,3 +0,0 @@
-Michael Waskom mwaskom
-Tal Yarkoni
-Daniel B.
Allan diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 30fec39552..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,63 +0,0 @@ -language: python - - -dist: xenial -services: - - xvfb - - -env: - - PYTHON=2.7 DEPS=latest BACKEND=agg DOCTESTS=true - - PYTHON=2.7 DEPS=pinned BACKEND=agg DOCTESTS=false - - PYTHON=2.7 DEPS=latest BACKEND=qtagg DOCTESTS=true - - PYTHON=3.5 DEPS=latest BACKEND=agg DOCTESTS=true - - PYTHON=3.6 DEPS=latest BACKEND=agg DOCTESTS=true - - PYTHON=3.7 DEPS=latest BACKEND=agg DOCTESTS=true - - PYTHON=3.7 DEPS=latest BACKEND=qtagg DOCTESTS=true - - PYTHON=3.7 DEPS=minimal BACKEND=agg DOCTESTS=false - - -before_install: - - sudo apt-get update -yq - - sudo sh testing/getmsfonts.sh - - wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh - - bash miniconda.sh -b -p $HOME/miniconda - - export PATH="$HOME/miniconda/bin:$PATH" - - hash -r - - conda config --set always_yes yes --set changeps1 no - - conda update -q conda - - conda info -a - - -install: - - conda create -n testenv pip python=$PYTHON - - source activate testenv - - cat testing/deps_${DEPS}.txt testing/utils.txt > deps.txt - - conda install --file deps.txt - - pip install . - - -before_script: - - cp testing/matplotlibrc_${BACKEND} matplotlibrc - - if [ $BACKEND == "qtagg" ]; then - export DISPLAY=:99.0; - sh -e /etc/init.d/xvfb start; - sleep 3; - fi - # https://www.python.org/dev/peps/pep-0493/ - - if [ $PYTHON == "2.7" ]; then - export PYTHONHTTPSVERIFY=0; - fi - - -script: - - make lint - - if [ $DOCTESTS == 'true' ]; - then make coverage; - else make unittests; - fi - - -after_success: - - pip install codecov - - codecov diff --git a/LICENSE b/LICENSE index 68c8c17cc4..b5ebba6263 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2012-2019, Michael L. Waskom +Copyright (c) 2012-2021, Michael L. Waskom All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/Makefile b/Makefile index c22a8ec490..1e15125f36 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,10 @@ export SHELL := /bin/bash test: - pytest --doctest-modules seaborn + pytest -n auto --doctest-modules --cov=seaborn --cov-config=.coveragerc seaborn unittests: - pytest seaborn - -coverage: - pytest --doctest-modules --cov=seaborn --cov-config=.coveragerc seaborn + pytest -n auto --cov=seaborn --cov-config=.coveragerc seaborn lint: - flake8 --ignore E121,E123,E126,E226,E24,E704,E741,W503,W504 --exclude seaborn/__init__.py,seaborn/colors/__init__.py,seaborn/cm.py,seaborn/tests,seaborn/external seaborn + flake8 seaborn diff --git a/README.md b/README.md index 0248403aa0..1cf0ed91d7 100644 --- a/README.md +++ b/README.md @@ -1,40 +1,14 @@ -seaborn: statistical data visualization -======================================= - -
-[deleted HTML block (README header); markup lost in extraction]
+[added HTML block (README header); markup lost in extraction]
-------------------------------------- +seaborn: statistical data visualization +======================================= + [![PyPI Version](https://img.shields.io/pypi/v/seaborn.svg)](https://pypi.org/project/seaborn/) [![License](https://img.shields.io/pypi/l/seaborn.svg)](https://github.com/mwaskom/seaborn/blob/master/LICENSE) -[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1313201.svg)](https://doi.org/10.5281/zenodo.1313201) -[![Build Status](https://travis-ci.org/mwaskom/seaborn.svg?branch=master)](https://travis-ci.org/mwaskom/seaborn) +[![DOI](https://joss.theoj.org/papers/10.21105/joss.03021/status.svg)](https://doi.org/10.21105/joss.03021) +![Tests](https://github.com/mwaskom/seaborn/workflows/CI/badge.svg) [![Code Coverage](https://codecov.io/gh/mwaskom/seaborn/branch/master/graph/badge.svg)](https://codecov.io/gh/mwaskom/seaborn) Seaborn is a Python visualization library based on matplotlib. It provides a high-level interface for drawing attractive statistical graphics. @@ -47,39 +21,58 @@ Online documentation is available at [seaborn.pydata.org](https://seaborn.pydata The docs include a [tutorial](https://seaborn.pydata.org/tutorial.html), [example gallery](https://seaborn.pydata.org/examples/index.html), [API reference](https://seaborn.pydata.org/api.html), and other useful information. +To build the documentation locally, please refer to [`doc/README.md`](doc/README.md). Dependencies ------------ -Seaborn supports Python 2.7 and 3.5+. +Seaborn supports Python 3.7+ and no longer supports Python 2. -Installation requires [numpy](http://www.numpy.org/), [scipy](https://www.scipy.org/), [pandas](https://pandas.pydata.org/), and [matplotlib](https://matplotlib.org/). Some functions will optionally use [statsmodels](https://www.statsmodels.org/) if it is installed. +Installation requires [numpy](https://numpy.org/), [pandas](https://pandas.pydata.org/), and [matplotlib](https://matplotlib.org/). Some functions will optionally use [scipy](https://www.scipy.org/) and/or [statsmodels](https://www.statsmodels.org/) if they are available. Installation ------------ -The latest stable release (and older versions) can be installed from PyPI: +The latest stable release (and required dependencies) can be installed from PyPI: pip install seaborn +It is also possible to include the optional dependencies: + + pip install seaborn[all] + You may instead want to use the development version from Github: - pip install git+https://github.com/mwaskom/seaborn.git#egg=seaborn + pip install git+https://github.com/mwaskom/seaborn.git + +Seaborn is also available from Anaconda and can be installed with conda: + conda install seaborn + +Note that the main anaconda repository typically lags PyPI in adding new releases. + +Citing +------ + +A paper describing seaborn has been published in the [Journal of Open Source Software](https://joss.theoj.org/papers/10.21105/joss.03021). The paper provides an introduction to the key features of the library, and it can be used as a citation if seaborn proves integral to a scientific publication. Testing ------- -To test seaborn, run `make test` in the source directory. +Testing seaborn requires installing additional packages listed in `ci/utils.txt`. + +To test the code, run `make test` in the source directory. This will exercise both the unit tests and docstring examples (using [pytest](https://docs.pytest.org/)) and generate a coverage report. 
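For reference, `make test` expands to the pytest invocation defined in the Makefile change earlier in this patch; here is a hedged sketch of launching the same run from Python, assuming the packages from `ci/utils.txt` (pytest, pytest-cov, pytest-xdist) are installed:

```python
# Programmatic equivalent of `make test`, i.e.
# pytest -n auto --doctest-modules --cov=seaborn --cov-config=.coveragerc seaborn
import pytest

pytest.main([
    "-n", "auto",              # parallel workers (pytest-xdist)
    "--doctest-modules",       # exercise docstring examples too
    "--cov=seaborn",           # coverage report (pytest-cov)
    "--cov-config=.coveragerc",
    "seaborn",
])
```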
+
+The doctests require a network connection (unless all example datasets are cached), but the unit tests can be run offline with `make unittests`.
 
-This will exercise both the unit tests and docstring examples (using `pytest`).
 
+Code style is enforced with `flake8` using the settings in the [`setup.cfg`](./setup.cfg) file. Run `make lint` to check.
 
 Development
 -----------
 
 Seaborn development takes place on Github: https://github.com/mwaskom/seaborn
 
-Please submit any reproducible bugs you encounter to the [issue tracker](https://github.com/mwaskom/seaborn/issues).
+Please submit bugs that you encounter to the [issue tracker](https://github.com/mwaskom/seaborn/issues) with a reproducible example demonstrating the problem. Questions about usage are more at home on StackOverflow, where there is a [seaborn tag](https://stackoverflow.com/questions/tagged/seaborn).
diff --git a/ci/cache_test_datasets.py b/ci/cache_test_datasets.py
new file mode 100644
index 0000000000..7bcd8b0d2f
--- /dev/null
+++ b/ci/cache_test_datasets.py
@@ -0,0 +1,19 @@
+"""
+Cache test datasets before running test suites to avoid
+race conditions due to test parallelization
+"""
+import seaborn as sns
+
+datasets = (
+    "anscombe",
+    "attention",
+    "dots",
+    "exercise",
+    "flights",
+    "fmri",
+    "iris",
+    "planets",
+    "tips",
+    "titanic"
+)
+list(map(sns.load_dataset, datasets))
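As background for why this warm-up script works: `load_dataset` downloads each dataset once into a local cache directory and reads from the cached copy afterwards. A brief sketch — the `SEABORN_DATA` override is inferred from `get_data_home`'s behavior and should be treated as an assumption:

```python
import seaborn as sns

# The first call downloads "tips" into the cache; subsequent calls,
# including those from parallel pytest-xdist workers, read the local copy.
tips = sns.load_dataset("tips", cache=True)

# Where the cache lives: defaults to ~/seaborn-data and (assumption, per
# get_data_home) can be redirected with the SEABORN_DATA environment variable.
print(sns.get_data_home())
```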
diff --git a/ci/check_gallery.py b/ci/check_gallery.py
new file mode 100644
index 0000000000..60db2e12c6
--- /dev/null
+++ b/ci/check_gallery.py
@@ -0,0 +1,14 @@
+"""Execute the scripts that comprise the example gallery in the online docs."""
+from glob import glob
+import matplotlib.pyplot as plt
+
+if __name__ == "__main__":
+
+    fnames = sorted(glob("examples/*.py"))
+
+    for fname in fnames:
+
+        print(f"- {fname}")
+        with open(fname) as fid:
+            exec(fid.read())
+        plt.close("all")
diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt
new file mode 100644
index 0000000000..9949d00c47
--- /dev/null
+++ b/ci/deps_pinned.txt
@@ -0,0 +1,5 @@
+numpy~=1.16.0
+pandas~=0.24.0
+matplotlib~=3.0.0
+scipy~=1.2.0
+statsmodels~=0.9.0
diff --git a/testing/getmsfonts.sh b/ci/getmsfonts.sh
similarity index 100%
rename from testing/getmsfonts.sh
rename to ci/getmsfonts.sh
diff --git a/ci/utils.txt b/ci/utils.txt
new file mode 100644
index 0000000000..99f8cc215f
--- /dev/null
+++ b/ci/utils.txt
@@ -0,0 +1,4 @@
+pytest!=5.3.4
+pytest-cov
+pytest-xdist
+flake8
diff --git a/doc/.gitignore b/doc/.gitignore
index ea122c6ca9..5cb06a8e24 100644
--- a/doc/.gitignore
+++ b/doc/.gitignore
@@ -4,13 +4,5 @@ generated/
 examples/
 example_thumbs/
 introduction.rst
-aesthetics.rst
-relational.rst
-color_palettes.rst
-distributions.rst
-regression.rst
-categorical.rst
-plotting_distributions.rst
-dataset_exploration.rst
-timeseries_plots.rst
-axis_grids.rst
+tutorial/*.rst
+docstrings/*.rst
diff --git a/doc/Makefile b/doc/Makefile
index 9a5de4aa74..8ed9484cd5 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -44,17 +44,24 @@ clean:
 	-rm -rf $(BUILDDIR)/*
 	-rm -rf examples/*
 	-rm -rf example_thumbs/*
-	-rm -rf tutorial/*_files/
-	-rm -rf tutorial/*.rst
 	-rm -rf generated/*
+	-rm -rf introduction_files/*
+	-rm introduction.rst
+	-make -C docstrings clean
+	-make -C tutorial clean
+
+.PHONY: tutorials
 tutorials:
 	make -C tutorial
 
-introduction: introduction.ipynb
-	tools/nb_to_doc.py introduction
+.PHONY: docstrings
+docstrings:
+	make -C docstrings
 
-notebooks: tutorials introduction
+introduction.rst: introduction.ipynb
+	tools/nb_to_doc.py ./introduction.ipynb
+
+notebooks: tutorials docstrings introduction.rst
 
 html:
 	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@@ -166,11 +173,3 @@ doctest:
 	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
 	@echo "Testing of doctests in the sources finished, look at the " \
 	      "results in $(BUILDDIR)/doctest/output.txt."
-
-upload:
-	rsync -azP $(BUILDDIR)/html/ mwaskom@cardinal.stanford.edu:WWW/software/seaborn
-	@echo "Uploaded to Stanford webspace"
-
-upload-dev:
-	rsync -azP $(BUILDDIR)/html/ mwaskom@cardinal.stanford.edu:WWW/software/seaborn-dev
-	@echo "Uploaded to Stanford webspace (development page)"
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 0000000000..69ad2b97e7
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,12 @@
+Building the seaborn docs
+=========================
+
+Building the docs requires additional dependencies listed in [`./requirements.txt`](./requirements.txt).
+
+The build process involves conversion of Jupyter notebooks to `rst` files. To facilitate this, you may need to set the `NB_KERNEL` environment variable to the name of a kernel on your machine (e.g. `export NB_KERNEL="python3"`). To get a list of available Python kernels, run `jupyter kernelspec list`.
+
+After you're set up, run `make notebooks html` from the `doc` directory to convert all notebooks, generate all gallery examples, and build the documentation itself. The site will live in `_build/html`.
+
+Run `make clean` to delete the built site and all intermediate files. Run `make -C docstrings clean` or `make -C tutorial clean` to remove intermediate files for the API or tutorial components.
+
+If your goal is to obtain an offline copy of the docs for a released version, it may be easier to clone the [website repository](https://github.com/seaborn/seaborn.github.io) or to download a zipfile corresponding to a [specific version](https://github.com/seaborn/seaborn.github.io/tags).
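Since finding the right kernel name is a common stumbling block, here is a small sketch of the same check as `jupyter kernelspec list` done from Python, assuming the `jupyter_client` package that ships with the Jupyter tooling is available:

```python
# List the installed kernel names that NB_KERNEL may be set to
from jupyter_client.kernelspec import KernelSpecManager

for name in sorted(KernelSpecManager().find_kernel_specs()):
    print(name)
```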
diff --git a/doc/_static/favicon.ico b/doc/_static/favicon.ico
old mode 100644
new mode 100755
index 1145b96d4e..fac1e28c2c
Binary files a/doc/_static/favicon.ico and b/doc/_static/favicon.ico differ
diff --git a/doc/_static/favicon_old.ico b/doc/_static/favicon_old.ico
new file mode 100644
index 0000000000..1145b96d4e
Binary files /dev/null and b/doc/_static/favicon_old.ico differ
diff --git a/doc/_static/logo-mark-darkbg.png b/doc/_static/logo-mark-darkbg.png
new file mode 100644
index 0000000000..d585461137
Binary files /dev/null and b/doc/_static/logo-mark-darkbg.png differ
diff --git a/doc/_static/logo-mark-darkbg.svg b/doc/_static/logo-mark-darkbg.svg
new file mode 100644
index 0000000000..4b06364224
--- /dev/null
+++ b/doc/_static/logo-mark-darkbg.svg
@@ -0,0 +1,4946 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-mark-lightbg.png b/doc/_static/logo-mark-lightbg.png
new file mode 100644
index 0000000000..378044557f
Binary files /dev/null and b/doc/_static/logo-mark-lightbg.png differ
diff --git a/doc/_static/logo-mark-lightbg.svg b/doc/_static/logo-mark-lightbg.svg
new file mode 100644
index 0000000000..1405269edc
--- /dev/null
+++ b/doc/_static/logo-mark-lightbg.svg
@@ -0,0 +1,4946 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-mark-whitebg.png b/doc/_static/logo-mark-whitebg.png
new file mode 100644
index 0000000000..2e022db5d7
Binary files /dev/null and b/doc/_static/logo-mark-whitebg.png differ
diff --git a/doc/_static/logo-tall-darkbg.png b/doc/_static/logo-tall-darkbg.png
new file mode 100644
index 0000000000..0a2e3c06d9
Binary files /dev/null and b/doc/_static/logo-tall-darkbg.png differ
diff --git a/doc/_static/logo-tall-darkbg.svg b/doc/_static/logo-tall-darkbg.svg
new file mode 100644
index 0000000000..3d7d910206
--- /dev/null
+++ b/doc/_static/logo-tall-darkbg.svg
@@ -0,0 +1,5206 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-tall-lightbg.png b/doc/_static/logo-tall-lightbg.png
new file mode 100644
index 0000000000..347dd9b344
Binary files /dev/null and b/doc/_static/logo-tall-lightbg.png differ
diff --git a/doc/_static/logo-tall-lightbg.svg b/doc/_static/logo-tall-lightbg.svg
new file mode 100644
index 0000000000..eb52f345c0
--- /dev/null
+++ b/doc/_static/logo-tall-lightbg.svg
@@ -0,0 +1,5206 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-tall-whitebg.png b/doc/_static/logo-tall-whitebg.png
new file mode 100644
index 0000000000..002c383a22
Binary files /dev/null and b/doc/_static/logo-tall-whitebg.png differ
diff --git a/doc/_static/logo-wide-darkbg.png b/doc/_static/logo-wide-darkbg.png
new file mode 100644
index 0000000000..e2d087b186
Binary files /dev/null and b/doc/_static/logo-wide-darkbg.png differ
diff --git a/doc/_static/logo-wide-darkbg.svg b/doc/_static/logo-wide-darkbg.svg
new file mode 100644
index 0000000000..83b0ef8289
--- /dev/null
+++ b/doc/_static/logo-wide-darkbg.svg
@@ -0,0 +1,5216 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-wide-lightbg.png b/doc/_static/logo-wide-lightbg.png
new file mode 100644
index 0000000000..ec249b06ca
Binary files /dev/null and b/doc/_static/logo-wide-lightbg.png differ
diff --git a/doc/_static/logo-wide-lightbg.svg b/doc/_static/logo-wide-lightbg.svg
new file mode 100644
index 0000000000..57f1f71345
--- /dev/null
+++ b/doc/_static/logo-wide-lightbg.svg
@@ -0,0 +1,5216 @@
+[SVG logo markup elided: Matplotlib v3.3.1 export, 2020-09-07]
diff --git a/doc/_static/logo-wide-whitebg.png b/doc/_static/logo-wide-whitebg.png
new file mode 100644
index 0000000000..4638939fab
Binary files /dev/null and b/doc/_static/logo-wide-whitebg.png differ
diff --git a/doc/_static/style.css b/doc/_static/style.css
index 27aea84f21..e48bb69cb1 100644
--- a/doc/_static/style.css
+++ b/doc/_static/style.css
@@ -26,12 +26,13 @@ blockquote {
 }
 
 pre {
+    margin-top: 11.5px !important;
     background-color: #f6f6f9 !important;
 }
 
 code {
     color: #49759c !important;
-    background-color: #ffffff !important;
+    background-color: transparent !important;
 }
 
 code.descclassname {
@@ -55,11 +56,21 @@ ul.dropdown-menu {
 }
 
 .alert-info {
-    background-color: #adb8cb !important;
-    border-color: #adb8cb !important;
+    background-color: #bbd2ea !important;
+    border-color: #bbd2ea !important;
     color: #2c3e50 !important;
 }
 
+.alert-warning {
+    background-color: #e09572 !important;
+    border-color: #e09572 !important;
+    color: #222222 !important;
+}
+
+img {
+    margin-bottom: 10px !important;
+}
+ /* From https://github.com/twbs/bootstrap/issues/1768 */ *[id]:before { display: block; @@ -69,7 +80,7 @@ ul.dropdown-menu { visibility: hidden; } -table { +.dataframe table { /*Uncomment to center tables horizontally*/ /* margin-left: auto; */ /* margin-right: auto; */ @@ -80,13 +91,13 @@ table { table-layout: fixed; } -thead { +.dataframe thead { border-bottom: 1px solid; vertical-align: bottom; } -tr, th, td { - text-align: right; +.dataframe tr, th, td { + text-align: left; vertical-align: middle; padding: 0.5em 0.5em; line-height: normal; @@ -95,10 +106,14 @@ tr, th, td { border: none; } -th { +.dataframe th { font-weight: bold; } +table { + margin-bottom: 20px; +} + tbody tr:nth-child(odd) { background: #f5f5f5; } @@ -106,3 +121,48 @@ tbody tr:nth-child(odd) { tbody tr:hover { background: rgba(66, 165, 245, 0.2); } + +.label, +.badge { + display: inline-block; + padding: 2px 4px; + font-size: 11.844px; + /* font-weight: bold; */ + line-height: 13px; + color: #ffffff; + vertical-align: baseline; + white-space: nowrap; + /* text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); */ + background-color: #999999; +} +.badge { + padding-left: 9px; + padding-right: 9px; + -webkit-border-radius: 9px; + -moz-border-radius: 9px; + border-radius: 9px; + opacity: 70%; +} +.badge-api { + background-color: #c44e52; +} +.badge-defaults { + background-color: #dd8452; +} +.badge-docs { + background-color: #8172b3; +} +.badge-feature { + background-color: #55a868; +} +.badge-enhancement { + background-color: #4c72b0; +} +.badge-fix { + background-color: #ccb974; +} + +.navbar-brand { + padding-top: 16px; + padding-bottom: 16px; +} \ No newline at end of file diff --git a/doc/_templates/layout.html b/doc/_templates/layout.html new file mode 100644 index 0000000000..7a41d911f9 --- /dev/null +++ b/doc/_templates/layout.html @@ -0,0 +1,23 @@ +{% extends "!layout.html" %} +{%- block footer %} +
+[footer container markup; HTML tags lost in extraction]
+ Back to top
+ {% if theme_source_link_position == "footer" %}
+ {% include "sourcelink.html" %}
+ {% endif %}
+ {% trans copyright=copyright|e %}© Copyright {{ copyright }}, Michael Waskom.{% endtrans %}
+ {%- if last_updated %}
+ {% trans last_updated=last_updated|e %}Last updated on {{ last_updated }}.{% endtrans %}
+ {%- endif %}
+ {%- if show_sphinx %}
+ {% trans sphinx_version=sphinx_version|e %}Created using Sphinx {{ sphinx_version }}.{% endtrans %}
+ {%- endif %}
+[closing footer markup; HTML tags lost in extraction]
+{%- endblock %}
diff --git a/doc/api.rst b/doc/api.rst
index eace53a953..4d30f9dfe8 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -11,12 +11,29 @@ Relational plots
 ----------------
 
 .. autosummary::
-    :toctree: generated
+    :toctree: generated/
+    :nosignatures:
 
     relplot
     scatterplot
     lineplot
 
+.. _distribution_api:
+
+Distribution plots
+------------------
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    displot
+    histplot
+    kdeplot
+    ecdfplot
+    rugplot
+    distplot
+
 .. _categorical_api:
 
 Categorical plots
@@ -24,6 +41,7 @@ Categorical plots
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
     catplot
     stripplot
     swarmplot
     boxplot
     violinplot
     boxenplot
     pointplot
     barplot
     countplot
 
-.. _distribution_api:
-
-Distribution plots
-------------------
-
-.. autosummary::
-    :toctree: generated/
-
-    distplot
-    kdeplot
-    rugplot
-
 .. _regression_api:
 
 Regression plots
@@ -54,6 +60,7 @@ Regression plots
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
     lmplot
     regplot
     residplot
 
 .. _matrix_api:
 
 Matrix plots
 ------------
 
 .. autosummary::
-    :toctree: generated/
+    :toctree: generated/
+    :nosignatures:
 
     heatmap
     clustermap
 
@@ -79,7 +87,8 @@ Facet grids
 ~~~~~~~~~~~
 
 .. autosummary::
-    :toctree: generated/
+    :toctree: generated/
+    :nosignatures:
 
     FacetGrid
     FacetGrid.map
@@ -89,7 +98,8 @@ Pair grids
 ~~~~~~~~~~
 
 .. autosummary::
-    :toctree: generated/
+    :toctree: generated/
+    :nosignatures:
 
     pairplot
     PairGrid
@@ -103,7 +113,8 @@ Joint grids
 ~~~~~~~~~~~
 
 .. autosummary::
-    :toctree: generated/
+    :toctree: generated/
+    :nosignatures:
 
     jointplot
     JointGrid
@@ -113,13 +124,14 @@
 
 .. _style_api:
 
-Style control
--------------
+Theming
+-------
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
-    set
+    set_theme
     axes_style
     set_style
     plotting_context
@@ -127,6 +139,7 @@ Style control
     set_color_codes
     reset_defaults
     reset_orig
+    set
 
 .. _palette_api:
 
 Color palettes
@@ -135,6 +148,7 @@ Color palettes
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
     set_palette
     color_palette
 
@@ -154,6 +168,7 @@ Palette widgets
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
     choose_colorbrewer_palette
     choose_cubehelix_palette
 
@@ -167,8 +182,11 @@ Utility functions
 
 .. autosummary::
     :toctree: generated/
+    :nosignatures:
 
     load_dataset
+    get_dataset_names
+    get_data_home
     despine
     desaturate
     saturate
diff --git a/doc/archive.rst b/doc/archive.rst
new file mode 100644
index 0000000000..83598a0be7
--- /dev/null
+++ b/doc/archive.rst
@@ -0,0 +1,7 @@
+.. _archive:
+
+Documentation archive
+=====================
+
+- `Version 0.10 <./archive/0.10/index.html>`_
+- `Version 0.9 <./archive/0.9/index.html>`_
\ No newline at end of file
diff --git a/doc/citing.rst b/doc/citing.rst
new file mode 100644
index 0000000000..e71c488f1f
--- /dev/null
+++ b/doc/citing.rst
@@ -0,0 +1,57 @@
+.. _citing:
+
+Citing and logo
+===============
+
+Citing seaborn
+--------------
+
+If seaborn is integral to a scientific publication, please cite it.
+A paper describing seaborn has been published in the `Journal of Open Source Software <https://joss.theoj.org/papers/10.21105/joss.03021>`_.
+Here is a ready-made BibTeX entry:
+
+.. highlight:: none
+
+::
+
+    @article{Waskom2021,
+        doi = {10.21105/joss.03021},
+        url = {https://doi.org/10.21105/joss.03021},
+        year = {2021},
+        publisher = {The Open Journal},
+        volume = {6},
+        number = {60},
+        pages = {3021},
+        author = {Michael L. Waskom},
+        title = {seaborn: statistical data visualization},
+        journal = {Journal of Open Source Software}
+    }
+
+In most situations where seaborn is cited, a citation to `matplotlib <https://matplotlib.org>`_ would also be appropriate.
+ +Logo files +---------- + +Additional logo files, including hi-res PNGs and images suitable for use over a dark background, are available +`on GitHub `_. + +Wide logo +~~~~~~~~~ + +.. image:: _static/logo-wide-lightbg.svg + :width: 400px + +Tall logo +~~~~~~~~~ + +.. image:: _static/logo-tall-lightbg.svg + :width: 150px + +Logo mark +~~~~~~~~~ + +.. image:: _static/logo-mark-lightbg.svg + :width: 150px + +Credit to `Matthias Bussonnier `_ for the initial design +and implementation of the logo. \ No newline at end of file diff --git a/doc/conf.py b/doc/conf.py index ef3ad76d93..1f3bd100f0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,8 +13,6 @@ import sys, os import sphinx_bootstrap_theme -import matplotlib as mpl -mpl.use("Agg") # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -39,8 +37,12 @@ 'matplotlib.sphinxext.plot_directive', 'gallery_generator', 'numpydoc', + 'sphinx_issues', ] +# Sphinx-issues configuration +issues_github_path = 'mwaskom/seaborn' + # Generate the API documentation when building autosummary_generate = True numpydoc_show_class_members = False @@ -66,7 +68,7 @@ # General information about the project. project = u'seaborn' import time -copyright = u'2012-{}, Michael Waskom'.format(time.strftime("%Y")) +copyright = u'2012-{}'.format(time.strftime("%Y")) # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -91,10 +93,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ['_build', 'docstrings'] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +default_role = 'literal' # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True @@ -126,15 +128,18 @@ html_theme_options = { 'source_link_position': "footer", 'bootswatch_theme': "paper", + 'navbar_title': " ", 'navbar_sidebarrel': False, 'bootstrap_version': "3", + 'nosidebar': True, + 'body_max_width': '100%', 'navbar_links': [ - ("Gallery", "examples/index"), - ("Tutorial", "tutorial"), - ("API", "api"), - ], + ("Gallery", "examples/index"), + ("Tutorial", "tutorial"), + ("API", "api"), + ], - } +} # Add any paths that contain custom themes here, relative to this directory. html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() @@ -148,7 +153,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "_static/logo-wide-lightbg.svg" # The name of an image file (within the static path) to use as favicon of the # docs. 
This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -288,8 +293,10 @@ def setup(app): # -- Intersphinx ------------------------------------------------ -intersphinx_mapping = {'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), - 'statsmodels': ('http://www.statsmodels.org/stable/', None)} \ No newline at end of file +intersphinx_mapping = { + 'numpy': ('https://numpy.org/doc/stable/', None), + 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), + 'matplotlib': ('https://matplotlib.org/stable', None), + 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), + 'statsmodels': ('https://www.statsmodels.org/stable/', None) +} diff --git a/doc/docstrings/FacetGrid.ipynb b/doc/docstrings/FacetGrid.ipynb new file mode 100644 index 0000000000..b69bfa59ca --- /dev/null +++ b/doc/docstrings/FacetGrid.ipynb @@ -0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"ticks\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling the constructor requires a long-form data object. This initializes the grid, but doesn't plot anything on it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.FacetGrid(tips)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assign column and/or row variables to add more subplots to the figure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.FacetGrid(tips, col=\"time\", row=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To draw a plot on every facet, pass a function and the name of one or more columns in the dataframe to :meth:`FacetGrid.map`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", + "g.map(sns.scatterplot, \"total_bill\", \"tip\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The variable specification in :meth:`FacetGrid.map` requires a positional argument mapping, but if the function has a ``data`` parameter and accepts named variable assignments, you can also use :meth:`FacetGrid.map_dataframe`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", + "g.map_dataframe(sns.histplot, x=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "One difference between the two methods is that :meth:`FacetGrid.map_dataframe` does not add axis labels. There is a dedicated method to do that:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", + "g.map_dataframe(sns.histplot, x=\"total_bill\")\n", + "g.set_axis_labels(\"Total bill\", \"Count\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Notice how the bins have different widths in each facet. 
A separate plot is drawn on each facet, so if the plotting function derives any parameters from the data, they may not be shared across facets. You can pass additional keyword arguments to synchronize them. But when possible, using a figure-level function like :func:`displot` will take care of this bookkeeping for you:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", + "g.map_dataframe(sns.histplot, x=\"total_bill\", binwidth=2)\n", + "g.set_axis_labels(\"Total bill\", \"Count\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The :class:`FacetGrid` constructor accepts a ``hue`` parameter. Setting this will condition the data on another variable and make multiple plots in different colors. Where possible, label information is tracked so that a single legend can be drawn:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\", hue=\"sex\")\n", + "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", + "g.set_axis_labels(\"Total bill\", \"Tip\")\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When ``hue`` is set on the :class:`FacetGrid`, however, a separate plot is drawn for each level of the variable. If the plotting function understands ``hue``, it is better to let it handle that logic. It is important, however, to ensure that each facet will use the same hue mapping. In the sample ``tips`` data, the ``sex`` column has a categorical datatype, which ensures this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"time\")\n", + "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\", hue=\"sex\")\n", + "g.set_axis_labels(\"Total bill\", \"Tip\")\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The size and shape of the plot is specified at the level of each subplot using the ``height`` and ``aspect`` parameters:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Change the height and aspect ratio of each facet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"day\", height=3.5, aspect=.65)\n", + "g.map(sns.histplot, \"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the variable assigned to ``col`` has many levels, it is possible to \"wrap\" it so that it spans multiple rows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"size\", height=2, col_wrap=3)\n", + "g.map(sns.histplot, \"total_bill\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can pass custom functions to plot with, or to annotate each facet. 
Your custom function must use the matplotlib state-machine interface to plot on the \"current\" axes, and it should catch additional keyword arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "def annotate(data, **kws):\n", + " n = len(data)\n", + " ax = plt.gca()\n", + " ax.text(.1, .6, f\"N = {n}\", transform=ax.transAxes)\n", + "\n", + "g = sns.FacetGrid(tips, col=\"time\")\n", + "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", + "g.set_axis_labels(\"Total bill\", \"Tip\")\n", + "g.map_dataframe(annotate)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The :class:`FacetGrid` object has some other useful parameters and methods for tweaking the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"sex\", row=\"time\", margin_titles=True)\n", + "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", + "g.set_axis_labels(\"Total bill\", \"Tip\")\n", + "g.set_titles(col_template=\"{col_name} patrons\", row_template=\"{row_name}\")\n", + "g.set(xlim=(0, 60), ylim=(0, 12), xticks=[10, 30, 50], yticks=[2, 6, 10])\n", + "g.tight_layout()\n", + "g.savefig(\"facet_plot.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "if os.path.exists(\"facet_plot.png\"):\n", + " os.remove(\"facet_plot.png\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "You also have access to the underlying matplotlib objects for additional tweaking:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(tips, col=\"sex\", row=\"time\", margin_titles=True, despine=False)\n", + "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", + "g.set_axis_labels(\"Total bill\", \"Tip\")\n", + "g.fig.subplots_adjust(wspace=0, hspace=0)\n", + "for (row_val, col_val), ax in g.axes_dict.items():\n", + " if row_val == \"Lunch\" and col_val == \"Female\":\n", + " ax.set_facecolor(\".95\")\n", + " else:\n", + " ax.set_facecolor((0, 0, 0, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/JointGrid.ipynb b/doc/docstrings/JointGrid.ipynb new file mode 100644 index 0000000000..f01b3e2f01 --- /dev/null +++ b/doc/docstrings/JointGrid.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Calling the constructor initializes the figure, but it does not plot anything:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = 
sns.load_dataset(\"penguins\")\n", + "sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The simplest plotting method, :meth:`JointGrid.plot` accepts a pair of functions (one for the joint axes and one for both marginal axes):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot(sns.scatterplot, sns.histplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The :meth:`JointGrid.plot` function also accepts additional keyword arguments, but it passes them to both functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot(sns.scatterplot, sns.histplot, alpha=.7, edgecolor=\".2\", linewidth=.5)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If you need to pass different keyword arguments to each function, you'll have to invoke :meth:`JointGrid.plot_joint` and :meth:`JointGrid.plot_marginals`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot_joint(sns.scatterplot, s=100, alpha=.5)\n", + "g.plot_marginals(sns.histplot, kde=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "You can also set up the grid without assigning any data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "You can then plot by accessing the ``ax_joint``, ``ax_marg_x``, and ``ax_marg_y`` attributes, which are :class:`matplotlib.axes.Axes` objects:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid()\n", + "x, y = penguins[\"bill_length_mm\"], penguins[\"bill_depth_mm\"]\n", + "sns.scatterplot(x=x, y=y, ec=\"b\", fc=\"none\", s=100, linewidth=1.5, ax=g.ax_joint)\n", + "sns.histplot(x=x, fill=False, linewidth=2, ax=g.ax_marg_x)\n", + "sns.kdeplot(y=y, linewidth=2, ax=g.ax_marg_y)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The plotting methods can use any seaborn functions that accept ``x`` and ``y`` variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot(sns.regplot, sns.boxplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the functions accept a ``hue`` variable, you can use it by assigning ``hue`` when you call the constructor:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")\n", + "g.plot(sns.scatterplot, sns.histplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The figure will always be square (unless you resize it at the matplotlib layer), but its overall size and layout are configurable. The size is controlled by the ``height`` parameter. 
The relative ratio between the joint and marginal axes is controlled by ``ratio``, and the amount of space between the plots is controlled by ``space``:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.JointGrid(height=4, ratio=2, space=.05)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the ticks on the density axis of the marginal plots are turned off, but this is configurable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.JointGrid(marginal_ticks=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Limits on the two data axes (which are shared across plots) can also be defined when setting up the figure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.JointGrid(xlim=(-2, 5), ylim=(0, 10))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/Makefile b/doc/docstrings/Makefile new file mode 100644 index 0000000000..dd04013de1 --- /dev/null +++ b/doc/docstrings/Makefile @@ -0,0 +1,15 @@ +rst_files := $(patsubst %.ipynb,%.rst,$(wildcard *.ipynb)) + +docstrings: ${rst_files} + +%.rst: %.ipynb + @../tools/nb_to_doc.py $*.ipynb + @cp -r $*_files ../generated/ + @if [ -f ../generated/seaborn.$*.rst ]; then \ + touch ../generated/seaborn.$*.rst; \ + fi + +clean: + rm -rf *.rst + rm -rf *_files/ + rm -rf .ipynb_checkpoints/ diff --git a/doc/docstrings/PairGrid.ipynb b/doc/docstrings/PairGrid.ipynb new file mode 100644 index 0000000000..82a5b9b118 --- /dev/null +++ b/doc/docstrings/PairGrid.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme()\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Calling the constructor sets up a blank grid of subplots with each row and column corresponding to a numeric variable in the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "g = sns.PairGrid(penguins)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing a bivariate function to :meth:`PairGrid.map` will draw a bivariate plot on every axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins)\n", + "g.map(sns.scatterplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing separate functions to :meth:`PairGrid.map_diag` and :meth:`PairGrid.map_offdiag` will show each variable's marginal distribution on the diagonal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins)\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)" + ] + }, + {
"cell_type": "raw", + "metadata": {}, + "source": [ + "It's also possible to use different functions on the upper and lower triangles of the plot (which are otherwise redundant):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, diag_sharey=False)\n", + "g.map_upper(sns.scatterplot)\n", + "g.map_lower(sns.kdeplot)\n", + "g.map_diag(sns.kdeplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or to avoid the redundancy altogether:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, diag_sharey=False, corner=True)\n", + "g.map_lower(sns.scatterplot)\n", + "g.map_diag(sns.kdeplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The :class:`PairGrid` constructor accepts a ``hue`` variable. This variable is passed directly to functions that understand it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "But you can also pass matplotlib functions, in which case a groupby is performed internally and a separate plot is drawn for each level:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(plt.hist)\n", + "g.map_offdiag(plt.scatter)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Additional semantic variables can be assigned by passing data vectors directly while mapping the function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot, size=penguins[\"sex\"])\n", + "g.add_legend(title=\"\", adjust_subtitles=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When using seaborn functions that can implement a numeric hue mapping, you will want to disable mapping of the variable on the diagonal axes. 
Note that the ``hue`` variable is excluded from the list of variables shown by default:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"body_mass_g\")\n", + "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``vars`` parameter can be used to control exactly which variables are used:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "variables = [\"body_mass_g\", \"bill_length_mm\", \"flipper_length_mm\"]\n", + "g = sns.PairGrid(penguins, hue=\"body_mass_g\", vars=variables)\n", + "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The plot need not be square: separate variables can be used to define the rows and columns:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_vars = [\"body_mass_g\", \"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"]\n", + "y_vars = [\"body_mass_g\"]\n", + "g = sns.PairGrid(penguins, hue=\"species\", x_vars=x_vars, y_vars=y_vars)\n", + "g.map_diag(sns.histplot, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It can be useful to explore different approaches to resolving multiple distributions on the diagonal axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot, multiple=\"stack\", element=\"step\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/color_palette.ipynb b/doc/docstrings/color_palette.ipynb new file mode 100644 index 0000000000..0ba638c57f --- /dev/null +++ b/doc/docstrings/color_palette.ipynb @@ -0,0 +1,198 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "# Add colormap display methods to matplotlib colormaps.\n", + "# These are forthcoming in matplotlib 3.4, but the matplotlib display\n", + "# method includes the colormap name, which is redundant here.\n", + "def _repr_png_(self):\n", + "    \"\"\"Generate a PNG representation of the Colormap.\"\"\"\n", + "    import io\n", + "    from PIL import Image\n", + "    import numpy as np\n", + "    IMAGE_SIZE = (400, 50)\n", + "    X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))\n", + "    pixels = self(X, 
bytes=True)\n", + "    png_bytes = io.BytesIO()\n", + "    Image.fromarray(pixels).save(png_bytes, format='png')\n", + "    return png_bytes.getvalue()\n", + "    \n", + "def _repr_html_(self):\n", + "    \"\"\"Generate an HTML representation of the Colormap.\"\"\"\n", + "    import base64\n", + "    png_bytes = self._repr_png_()\n", + "    png_base64 = base64.b64encode(png_bytes).decode('ascii')\n", + "    # Embed the PNG in an <img> tag via a base64 data URI\n", + "    return ('<img src=\"data:image/png;base64,' + png_base64 + '\">')\n", + "    \n", + "import matplotlib as mpl\n", + "mpl.colors.Colormap._repr_png_ = _repr_png_\n", + "mpl.colors.Colormap._repr_html_ = _repr_html_" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Calling with no arguments returns all colors from the current default\n", + "color cycle:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Other variants on the seaborn categorical color palette can be referenced by name:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"pastel\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Return a specified number of evenly spaced hues in the \"HUSL\" system:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"husl\", 9)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Return all unique colors in a categorical Color Brewer palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"Set2\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Return one of the perceptually-uniform colormaps included in seaborn:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"flare\", as_cmap=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Return a customized cubehelix color palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"ch:s=.25,rot=-.25\", as_cmap=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Return a light-themed sequential colormap blending to a seed color:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"light:#5A9\", as_cmap=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/displot.ipynb b/doc/docstrings/displot.ipynb new file mode 100644 index 0000000000..c8d2024d8c --- /dev/null +++ b/doc/docstrings/displot.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme(style=\"ticks\")" + ] + }, + 
{ + "cell_type": "raw", + "metadata": {}, + "source": [ + "The default plot kind is a histogram:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.displot(data=penguins, x=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use the ``kind`` parameter to select a different representation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", kind=\"kde\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are three main plot kinds; in addition to histograms and kernel density estimates (KDEs), you can also draw empirical cumulative distribution functions (ECDFs):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", kind=\"ecdf\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While in histogram mode, it is also possible to add a KDE curve:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", kde=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To draw a bivariate plot, assign both ``x`` and ``y``:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Currently, bivariate plots are available only for histograms and KDEs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", kind=\"kde\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For each kind of plot, you can also show individual observations with a marginal \"rug\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", kind=\"kde\", rug=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Each kind of plot can be drawn separately for subsets of data using ``hue`` mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additional keyword arguments are passed to the appropriate underlying plotting function, allowing for further customization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The figure is constructed using a :class:`FacetGrid`, meaning that you can also show subsets on distinct subplots, or \"facets\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, 
x=\"flipper_length_mm\", hue=\"species\", col=\"sex\", kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Because the figure is drawn with a :class:`FacetGrid`, you control its size and shape with the ``height`` and ``aspect`` parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(\n", + " data=penguins, y=\"flipper_length_mm\", hue=\"sex\", col=\"species\",\n", + " kind=\"ecdf\", height=4, aspect=.7,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The function returns the :class:`FacetGrid` object with the plot, and you can use the methods on this object to customize it further:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.displot(\n", + " data=penguins, y=\"flipper_length_mm\", hue=\"sex\", col=\"species\",\n", + " kind=\"kde\", height=4, aspect=.7,\n", + ")\n", + "g.set_axis_labels(\"Density (a.u.)\", \"Flipper length (mm)\")\n", + "g.set_titles(\"{col_name} penguins\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/ecdfplot.ipynb b/doc/docstrings/ecdfplot.ipynb new file mode 100644 index 0000000000..3775a4ecdf --- /dev/null +++ b/doc/docstrings/ecdfplot.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot a univariate distribution along the x axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.ecdfplot(data=penguins, x=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Flip the plot by assigning the data variable to the y axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.ecdfplot(data=penguins, y=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If neither `x` nor `y` is assigned, the dataset is treated as wide-form, and a histogram is drawn for each numeric column:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.ecdfplot(data=penguins.filter(like=\"bill_\", axis=\"columns\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also draw multiple histograms from a long-form dataset with hue mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The default distribution statistic is normalized to show a proportion, but you can show absolute counts instead:" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\", stat=\"count\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's also possible to plot the empirical complementary CDF (1 - CDF):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\", complementary=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-refactor (py38)", + "language": "python", + "name": "seaborn-refactor" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/histplot.ipynb b/doc/docstrings/histplot.ipynb new file mode 100644 index 0000000000..99ed6c551d --- /dev/null +++ b/doc/docstrings/histplot.ipynb @@ -0,0 +1,483 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"white\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assign a variable to ``x`` to plot a univariate distribution along the x axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.histplot(data=penguins, x=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Flip the plot by assigning the data variable to the y axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, y=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check how well the histogram represents the data by specifying a different bin width:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, x=\"flipper_length_mm\", binwidth=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also define the total number of bins to use:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, x=\"flipper_length_mm\", bins=30)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add a kernel density estimate to smooth the histogram, providing complementary information about the shape of the distribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, x=\"flipper_length_mm\", kde=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If neither `x` nor `y` is assigned, the dataset is treated as wide-form, and a histogram is drawn for each numeric column:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "You can otherwise draw multiple histograms from a long-form dataset with hue mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The default approach to plotting multiple distributions is to \"layer\" them, but you can also \"stack\" them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Overlapping bars can be hard to visually resolve. A different approach would be to draw a step function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"step\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can move even farther away from bars by drawing a polygon with vertices in the center of each bin. This may make it easier to see the shape of the distribution, but use with caution: it will be less obvious to your audience that they are looking at a histogram:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"poly\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To compare the distribution of subsets that differ substantially in size, use indepdendent density normalization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " penguins, x=\"bill_length_mm\", hue=\"island\", element=\"step\",\n", + " stat=\"density\", common_norm=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's also possible to normalize so that each bar's height shows a probability, which make more sense for discrete variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.histplot(data=tips, x=\"size\", stat=\"probability\", discrete=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can even draw a histogram over categorical variables (although this is an experimental feature):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=tips, x=\"day\", shrink=.8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using a ``hue`` semantic with discrete data, it can make sense to \"dodge\" the levels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=tips, x=\"day\", hue=\"sex\", multiple=\"dodge\", shrink=.8)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Real-world data is often skewed. For heavily skewed distributions, it's better to define the bins in log space. 
Compare:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "planets = sns.load_dataset(\"planets\")\n", + "sns.histplot(data=planets, x=\"distance\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To the log-scale version:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=planets, x=\"distance\", log_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are also a number of options for how the histogram appears. You can show unfilled bars:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=planets, x=\"distance\", log_scale=True, fill=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or an unfilled step function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(data=planets, x=\"distance\", log_scale=True, element=\"step\", fill=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Step functions, esepcially when unfilled, make it easy to compare cumulative histograms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " data=planets, x=\"distance\", hue=\"method\",\n", + " hue_order=[\"Radial Velocity\", \"Transit\"],\n", + " log_scale=True, element=\"step\", fill=False,\n", + " cumulative=True, stat=\"density\", common_norm=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When both ``x`` and ``y`` are assigned, a bivariate histogram is computed and shown as a heatmap:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(penguins, x=\"bill_depth_mm\", y=\"body_mass_g\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's possible to assign a ``hue`` variable too, although this will not work well if data from the different levels have substantial overlap:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(penguins, x=\"bill_depth_mm\", y=\"body_mass_g\", hue=\"species\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Multiple color maps can make sense when one of the variables is discrete:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " penguins, x=\"bill_depth_mm\", y=\"species\", hue=\"species\", legend=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The bivariate histogram accepts all of the same options for computation as its univariate counterpart, using tuples to parametrize ``x`` and ``y`` independently:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " planets, x=\"year\", y=\"distance\",\n", + " bins=30, discrete=(True, False), log_scale=(False, True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The default behavior makes cells with no observations transparent, although this can be disabled: " + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " planets, x=\"year\", y=\"distance\",\n", + " bins=30, discrete=(True, False), log_scale=(False, True),\n", + " thresh=None,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's also possible to set the threshold and colormap saturation point in terms of the proportion of cumulative counts:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " planets, x=\"year\", y=\"distance\",\n", + " bins=30, discrete=(True, False), log_scale=(False, True),\n", + " pthresh=.05, pmax=.9,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To annotate the colormap, add a colorbar:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.histplot(\n", + " planets, x=\"year\", y=\"distance\",\n", + " bins=30, discrete=(True, False), log_scale=(False, True),\n", + " cbar=True, cbar_kws=dict(shrink=.75),\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/jointplot.ipynb b/doc/docstrings/jointplot.ipynb new file mode 100644 index 0000000000..201877c36a --- /dev/null +++ b/doc/docstrings/jointplot.ipynb @@ -0,0 +1,194 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"white\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "In the simplest invocation, assign ``x`` and ``y`` to create a scatterplot (using :func:`scatterplot`) with marginal histograms (using :func:`histplot`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a ``hue`` variable will add conditional colors to the scatterplot and draw separate density curves (using :func:`kdeplot`) on the marginal axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Several different approaches to plotting are available through the ``kind`` parameter. 
Setting ``kind=\"kde\"`` will draw both bivariate and univariate KDEs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\", kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Set ``kind=\"reg\"`` to add a linear regression fit (using :func:`regplot`) and univariate KDE curves:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"reg\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "There are also two options for bin-based visualization of the joint distribution. The first, with ``kind=\"hist\"``, uses :func:`histplot` on all of the axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"hist\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Alternatively, setting ``kind=\"hex\"`` will use :meth:`matplotlib.axes.Axes.hexbin` to compute a bivariate histogram using hexagonal bins:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"hex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Additional keyword arguments can be passed down to the underlying plots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(\n", + " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", + " marker=\"+\", s=100, marginal_kws=dict(bins=25, fill=False),\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use :class:`JointGrid` parameters to control the size and layout of the figure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", height=5, ratio=2, marginal_ticks=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To add more layers onto the plot, use the methods on the :class:`JointGrid` object that :func:`jointplot` returns:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot_joint(sns.kdeplot, color=\"r\", zorder=0, levels=6)\n", + "g.plot_marginals(sns.rugplot, color=\"r\", height=-.15, clip_on=False)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/kdeplot.ipynb b/doc/docstrings/kdeplot.ipynb new file mode 100644 index 0000000000..5a355bbea2 --- /dev/null +++ b/doc/docstrings/kdeplot.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + 
{ + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot a univariate distribution along the x axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.kdeplot(data=tips, x=\"total_bill\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Flip the plot by assigning the data variable to the y axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, y=\"total_bill\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot distributions for each column of a wide-form dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "iris = sns.load_dataset(\"iris\")\n", + "sns.kdeplot(data=iris)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use less smoothing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, x=\"total_bill\", bw_adjust=.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use more smoothing, but don't smooth past the extreme data points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ax= sns.kdeplot(data=tips, x=\"total_bill\", bw_adjust=5, cut=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot conditional distributions with hue mapping of a second variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Stack\" the conditional distributions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\", multiple=\"stack\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Normalize the stacked distribution at each value in the grid:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\", multiple=\"fill\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Estimate the cumulative distribution function(s), normalizing each subset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(\n", + " data=tips, x=\"total_bill\", hue=\"time\",\n", + " cumulative=True, common_norm=False, common_grid=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Estimate distribution from aggregated data, using weights:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips_agg = (tips\n", + " .groupby(\"size\")\n", + " .agg(total_bill=(\"total_bill\", \"mean\"), n=(\"total_bill\", \"count\"))\n", + ")\n", + "sns.kdeplot(data=tips_agg, x=\"total_bill\", weights=\"n\")" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "Map the data variable with log scaling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "diamonds = sns.load_dataset(\"diamonds\")\n", + "sns.kdeplot(data=diamonds, x=\"price\", log_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use numeric hue mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"size\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Modify the appearance of the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(\n", + " data=tips, x=\"total_bill\", hue=\"size\",\n", + " fill=True, common_norm=False, palette=\"crest\",\n", + " alpha=.5, linewidth=0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot a bivariate distribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "geyser = sns.load_dataset(\"geyser\")\n", + "sns.kdeplot(data=geyser, x=\"waiting\", y=\"duration\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Map a third variable with a hue semantic to show conditional distributions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show filled contours:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(\n", + " data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\", fill=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show fewer contour levels, covering less of the distribution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(\n", + " data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\",\n", + " levels=5, thresh=.2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Fill the axes extent with a smooth distribution, using a different colormap:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(\n", + " data=geyser, x=\"waiting\", y=\"duration\",\n", + " fill=True, thresh=0, levels=100, cmap=\"mako\",\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/lineplot.ipynb b/doc/docstrings/lineplot.ipynb new file mode 100644 index 0000000000..9ae0a2477a --- /dev/null +++ b/doc/docstrings/lineplot.ipynb @@ -0,0 +1,434 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import 
seaborn as sns\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``flights`` dataset has 10 years of monthly airline passenger data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights = sns.load_dataset(\"flights\")\n", + "flights.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To draw a line plot using long-form data, assign the ``x`` and ``y`` variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "may_flights = flights.query(\"month == 'May'\")\n", + "sns.lineplot(data=may_flights, x=\"year\", y=\"passengers\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Pivot the dataframe to a wide-form representation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights_wide = flights.pivot(\"year\", \"month\", \"passengers\")\n", + "flights_wide.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To plot a single vector, pass it to ``data``. If the vector is a :class:`pandas.Series`, it will be plotted against its index:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=flights_wide[\"May\"])" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing the entire wide-form dataset to ``data`` plots a separate line for each column:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=flights_wide)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing the entire dataset in long-form mode will aggregate over repeated values (each year) to show the mean and 95% confidence interval:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=flights, x=\"year\", y=\"passengers\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assign a grouping semantic (``hue``, ``size``, or ``style``) to plot separate lines:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The same column can be assigned to multiple semantic variables, which can increase the accessibility of the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\", style=\"month\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Each semantic variable can also represent a different column. 
For that, we'll need a more complex dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fmri = sns.load_dataset(\"fmri\")\n", + "fmri.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Repeated observations are aggregated even when semantic grouping is used:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=fmri, x=\"timepoint\", y=\"signal\", hue=\"event\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assign both ``hue`` and ``style`` to represent two different grouping variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(data=fmri, x=\"timepoint\", y=\"signal\", hue=\"region\", style=\"event\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When assigning a ``style`` variable, markers can be used instead of (or along with) dashes to distinguish the groups:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + " data=fmri,\n", + " x=\"timepoint\", y=\"signal\", hue=\"event\", style=\"event\",\n", + " markers=True, dashes=False\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Show error bars instead of error bands and extend them to two standard error widths:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + " data=fmri, x=\"timepoint\", y=\"signal\", hue=\"event\", err_style=\"bars\", errorbar=(\"se\", 2),\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning the ``units`` variable will plot multiple lines without applying a semantic mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + " data=fmri.query(\"region == 'frontal'\"),\n", + " x=\"timepoint\", y=\"signal\", hue=\"event\", units=\"subject\",\n", + " estimator=None, lw=1,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Load another dataset with a numeric grouping variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dots = sns.load_dataset(\"dots\").query(\"align == 'dots'\")\n", + "dots.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a numeric variable to ``hue`` maps it differently, using a different default palette and a quantitative color mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + " data=dots, x=\"time\", y=\"firing_rate\", hue=\"coherence\", style=\"choice\",\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Control the color mapping by setting the ``palette`` and passing a :class:`matplotlib.colors.Normalize` object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + " data=dots.query(\"coherence > 0\"),\n", + " x=\"time\", y=\"firing_rate\", hue=\"coherence\", style=\"choice\",\n", + " palette=\"flare\", hue_norm=mpl.colors.LogNorm(),\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or pass specific 
colors, either as a Python list or dictionary:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "palette = sns.color_palette(\"mako_r\", 6)\n", + "sns.lineplot(\n", + "    data=dots, x=\"time\", y=\"firing_rate\",\n", + "    hue=\"coherence\", style=\"choice\",\n", + "    palette=palette\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assign the ``size`` semantic to map the width of the lines with a numeric variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + "    data=dots, x=\"time\", y=\"firing_rate\",\n", + "    size=\"coherence\", hue=\"choice\",\n", + "    legend=\"full\"\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Pass a tuple, ``sizes=(smallest, largest)``, to control the range of linewidths used to map the ``size`` semantic:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.lineplot(\n", + "    data=dots, x=\"time\", y=\"firing_rate\",\n", + "    size=\"coherence\", hue=\"choice\",\n", + "    sizes=(.25, 2.5)\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the observations are sorted by ``x``. Disable this to plot a line in the order that observations appear in the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1)\n", + "sns.lineplot(x=x, y=y, sort=False, lw=1)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use :func:`relplot` to combine :func:`lineplot` and :class:`FacetGrid`. This allows grouping within additional categorical variables. 
Using :func:`relplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of the semantic mappings across facets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(\n", + " data=fmri, x=\"timepoint\", y=\"signal\",\n", + " col=\"region\", hue=\"event\", style=\"event\",\n", + " kind=\"line\"\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/pairplot.ipynb b/doc/docstrings/pairplot.ipynb new file mode 100644 index 0000000000..af3d6d6685 --- /dev/null +++ b/doc/docstrings/pairplot.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"ticks\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The simplest invocation uses :func:`scatterplot` for each pairing of the variables and :func:`histplot` for the marginal plots along the diagonal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.pairplot(penguins)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a ``hue`` variable adds a semantic mapping and changes the default marginal plot to a layered kernel density estimate (KDE):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It's possible to force marginal histograms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\", diag_kind=\"hist\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``kind`` parameter determines both the diagonal and off-diagonal plotting style. Several options are available, including using :func:`kdeplot` to draw KDEs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or :func:`histplot` to draw both bivariate and univariate histograms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, kind=\"hist\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``markers`` parameter applies a style mapping on the off-diagonal axes. 
Currently, it will be redundant with the ``hue`` variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\", markers=[\"o\", \"s\", \"D\"])" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "As with other figure-level functions, the size of the figure is controlled by setting the ``height`` of each individual subplot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, height=1.5)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use ``vars`` or ``x_vars`` and ``y_vars`` to select the variables to plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(\n", + " penguins,\n", + " x_vars=[\"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"],\n", + " y_vars=[\"bill_length_mm\", \"bill_depth_mm\"],\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Set ``corner=True`` to plot only the lower triangle:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, corner=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``plot_kws`` and ``diag_kws`` parameters accept dicts of keyword arguments to customize the off-diagonal and diagonal plots, respectively:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(\n", + " penguins,\n", + " plot_kws=dict(marker=\"+\", linewidth=1),\n", + " diag_kws=dict(fill=False),\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The return object is the underlying :class:`PairGrid`, which can be used to further customize the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.pairplot(penguins, diag_kind=\"kde\")\n", + "g.map_lower(sns.kdeplot, levels=4, color=\".2\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/relplot.ipynb b/doc/docstrings/relplot.ipynb new file mode 100644 index 0000000000..ac5a3a2fd8 --- /dev/null +++ b/doc/docstrings/relplot.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "These examples will illustrate only some of the functionality that :func:`relplot` is capable of. For more information, consult the examples for :func:`scatterplot` and :func:`lineplot`, which are used when ``kind=\"scatter\"`` or ``kind=\"line\"``, respectively." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme(style=\"ticks\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To illustrate ``kind=\"scatter\"`` (the default style of plot), we will use the \"tips\" dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "tips.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning ``x`` and ``y`` and any semantic mapping variables will draw a single plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a ``col`` variable creates a faceted figure with multiple subplots arranged across the columns of the grid:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", col=\"time\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Different variables can be assigned to facet on both the columns and rows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", col=\"time\", row=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When the variable assigned to ``col`` has many levels, it can be \"wrapped\" across multiple rows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\", col=\"day\", col_wrap=2)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning multiple semantic variables can show multi-dimensional relationships, but be mindful to avoid making an overly complicated plot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(\n", + "    data=tips, x=\"total_bill\", y=\"tip\", col=\"time\",\n", + "    hue=\"time\", size=\"size\", style=\"sex\",\n", + "    palette=[\"b\", \"r\"], sizes=(10, 100)\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When there is a natural continuity to one of the variables, it makes more sense to show lines instead of points. To draw the figure using :func:`lineplot`, set ``kind=\"line\"``. We will illustrate this effect with the \"fmri\" dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fmri = sns.load_dataset(\"fmri\")\n", + "fmri.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Using ``kind=\"line\"`` offers the same flexibility for semantic mappings as ``kind=\"scatter\"``, but :func:`lineplot` transforms the data more before plotting. Observations are sorted by their ``x`` value, and repeated observations are aggregated. 
By default, the resulting plot shows the mean and a 95% confidence interval at each ``x`` position:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(\n", + "    data=fmri, x=\"timepoint\", y=\"signal\", col=\"region\",\n", + "    hue=\"event\", style=\"event\", kind=\"line\",\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The size and shape of the figure are parametrized by the ``height`` and ``aspect`` parameters of each individual facet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(\n", + "    data=fmri,\n", + "    x=\"timepoint\", y=\"signal\",\n", + "    hue=\"event\", style=\"event\", col=\"region\",\n", + "    height=4, aspect=.7, kind=\"line\"\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The object returned by :func:`relplot` is always a :class:`FacetGrid`, which has several methods that allow you to quickly tweak the title, labels, and other aspects of the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.relplot(\n", + "    data=fmri,\n", + "    x=\"timepoint\", y=\"signal\",\n", + "    hue=\"event\", style=\"event\", col=\"region\",\n", + "    height=4, aspect=.7, kind=\"line\"\n", + ")\n", + "(g.map(plt.axhline, y=0, color=\".7\", dashes=(2, 1), zorder=0)\n", + "  .set_axis_labels(\"Timepoint\", \"Percent signal change\")\n", + "  .set_titles(\"Region: {col_name} cortex\")\n", + "  .tight_layout(w_pad=0))" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It is also possible to use wide-form data with :func:`relplot`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights_wide = sns.load_dataset(\"flights\").pivot(\"year\", \"month\", \"passengers\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Faceting is not an option in this case, but the plot will still take advantage of the external legend offered by :class:`FacetGrid`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=flights_wide, kind=\"line\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/rugplot.ipynb b/doc/docstrings/rugplot.ipynb new file mode 100644 index 0000000000..789102d927 --- /dev/null +++ b/doc/docstrings/rugplot.ipynb @@ -0,0 +1,137 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add a rug along one of the axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set_theme()\n", + "tips = sns.load_dataset(\"tips\")\n", + "sns.kdeplot(data=tips, x=\"total_bill\")\n", + "sns.rugplot(data=tips, x=\"total_bill\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add a rug along both axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", + "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Represent a third variable with hue mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\")\n", + "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Draw a taller rug:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", + "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\", height=.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Put the rug outside the axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", + "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\", height=-.02, clip_on=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show the density of a larger dataset using thinner lines and alpha blending:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "diamonds = sns.load_dataset(\"diamonds\")\n", + "sns.scatterplot(data=diamonds, x=\"carat\", y=\"price\", s=5)\n", + "sns.rugplot(data=diamonds, x=\"carat\", y=\"price\", lw=1, alpha=.005)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-refactor (py38)", + "language": "python", + "name": "seaborn-refactor" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/scatterplot.ipynb b/doc/docstrings/scatterplot.ipynb new file mode 100644 index 0000000000..db93462b70 --- /dev/null +++ b/doc/docstrings/scatterplot.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "These examples will use the \"tips\" dataset, which has a mixture of numeric and categorical variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "tips.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing long-form data and assigning ``x`` and ``y`` will draw a scatter plot between two variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a variable to ``hue`` will map its levels to the color of the points:" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning the same variable to ``style`` will also vary the markers and create a more accessible plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\", style=\"time\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning ``hue`` and ``style`` to different variables will vary colors and markers independently:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", style=\"time\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the variable assigned to ``hue`` is numeric, the semantic mapping will be quantitative and use a different default palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Pass the name of a categorical palette or explicit colors (as a Python list of dictionary) to force categorical mapping of the ``hue`` variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", palette=\"deep\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If there are a large number of unique numeric values, the legend will show a representative, evenly-spaced set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tip_rate = tips.eval(\"tip / total_bill\").rename(\"tip_rate\")\n", + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=tip_rate)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "A numeric variable can also be assigned to ``size`` to apply a semantic mapping to the areas of the points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Control the range of marker areas with ``sizes``, and set ``lengend=\"full\"`` to force every unique value to appear in the legend:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(\n", + " data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\",\n", + " sizes=(20, 200), legend=\"full\"\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Pass a tuple of values or a :class:`matplotlib.colors.Normalize` object to ``hue_norm`` to control the quantitative hue mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(\n", + " data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\",\n", + " sizes=(20, 200), hue_norm=(0, 7), legend=\"full\"\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ 
+ "Control the specific markers used to map the ``style`` variable by passing a Python list or dictionary of marker codes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "markers = {\"Lunch\": \"s\", \"Dinner\": \"X\"}\n", + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", style=\"time\", markers=markers)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Additional keyword arguments are passed to :meth:`matplotlib.axes.Axes.scatter`, allowing you to directly set the attributes of the plot that are not semantically mapped:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", s=100, color=\".2\", marker=\"+\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The previous examples used a long-form dataset. When working with wide-form data, each column will be plotted against its index using both ``hue`` and ``style`` mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "index = pd.date_range(\"1 1 2000\", periods=100, freq=\"m\", name=\"date\")\n", + "data = np.random.randn(100, 4).cumsum(axis=0)\n", + "wide_df = pd.DataFrame(data, index, [\"a\", \"b\", \"c\", \"d\"])\n", + "sns.scatterplot(data=wide_df)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use :func:`relplot` to combine :func:`scatterplot` and :class:`FacetGrid`. This allows grouping within additional categorical variables, and plotting them across multiple subplots.\n", + "\n", + "Using :func:`relplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of the semantic mappings across facets." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(\n", + " data=tips, x=\"total_bill\", y=\"tip\",\n", + " col=\"time\", hue=\"day\", style=\"day\",\n", + " kind=\"scatter\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/stripplot.ipynb b/doc/docstrings/stripplot.ipynb new file mode 100644 index 0000000000..5d787e9e78 --- /dev/null +++ b/doc/docstrings/stripplot.ipynb @@ -0,0 +1,313 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"whitegrid\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a single numeric variable shows its univariate distribution with points randomly \"jittered\" on the other axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.stripplot(data=tips, x=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a second variable splits the strips of points to compare categorical levels of that variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Show vertically-oriented strips by swapping the assignment of the categorical and numerical variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"day\", y=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Prior to version 0.12, the levels of the categorical variable had different colors by default. 
To get the same effect, assign the `hue` variable explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or you can assign a distinct variable to `hue` to show a multidimensional relationship:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the `hue` variable is numeric, it will be mapped with a quantitative palette by default (note that this was not the case prior to version 0.12):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use `palette` to control the color mapping, including forcing a categorical mapping by passing the name of a qualitative palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\", palette=\"deep\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the different levels of the `hue` variable are intermingled in each strip, but setting `dodge=True` will split them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The random jitter can be disabled by setting `jitter=False`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True, jitter=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If plotting in wide-form mode, each numeric column of the dataframe will be mapped to both `x` and `hue`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To change the orientation while in wide-form mode, pass `orient` explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, orient=\"h\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The `orient` parameter is also useful when both axis variables are numeric, as it will resolve ambiguity about which dimension to group (and jitter) along:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the categorical variable will be mapped to discrete indices with a fixed scale (0, 1, ...), even when it is numeric:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " 
x=\"total_bill\", y=\"size\", orient=\"h\",\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To disable this behavior and use the original scale of the variable, set `fixed_scale=False`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " x=\"total_bill\", y=\"size\", orient=\"h\",\n", + " fixed_scale=False,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Further visual customization can be achieved by passing keyword arguments for :func:`matplotlib.axes.Axes.scatter`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.stripplot(\n", + " data=tips, x=\"total_bill\", y=\"day\", hue=\"time\",\n", + " jitter=False, s=20, marker=\"D\", linewidth=1, alpha=.1,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To make a plot with multiple facets, it is safer to use :func:`catplot` than to work with :class:`FacetGrid` directly, because :func:`catplot` will ensure that the categorical and hue variables are properly synchronized in each facet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.catplot(data=tips, x=\"time\", y=\"total_bill\", hue=\"sex\", col=\"day\", aspect=.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/swarmplot.ipynb b/doc/docstrings/swarmplot.ipynb new file mode 100644 index 0000000000..ebc74b92ab --- /dev/null +++ b/doc/docstrings/swarmplot.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme(style=\"whitegrid\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a single numeric variable shows its univariate distribution with points adjusted along on the other axis such that they don't overlap:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "sns.swarmplot(data=tips, x=\"total_bill\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a second variable splits the groups of points to compare categorical levels of that variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Show vertically-oriented swarms by swapping the assignment of the categorical and numerical variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"day\", y=\"total_bill\")" + ] + }, + { + "cell_type": 
"raw", + "metadata": {}, + "source": [ + "Prior to version 0.12, the levels of the categorical variable had different colors by default. To get the same effect, assign the `hue` variable explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"day\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or you can assign a distinct variable to `hue` to show a multidimensional relationship:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "If the `hue` variable is numeric, it will be mapped with a quantitative palette by default (note that this was not the case prior to version 0.12):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use `palette` to control the color mapping, including forcing a categorical mapping by passing the name of a qualitative palette:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\", palette=\"deep\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the different levels of the `hue` variable are intermingled in each swarm, but setting `dodge=True` will split them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The \"orientation\" of the plot (defined as the direction along which quantitative relationships are preserved) is usualy inferred automatically. But in ambiguous cases, such as when both axis variables are numeric, it can be specified:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When the local density of points is too high, they will be forced to overlap in the \"gutters\" of each swarm and a warning will be issued. 
Decreasing the size of the points can help to avoid this problem:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\", size=3)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the categorical variable will be mapped to discrete indices with a fixed scale (0, 1, ...), even when it is numeric:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " x=\"total_bill\", y=\"size\", orient=\"h\",\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To disable this behavior and use the original scale of the variable, set `fixed_scale=False` (notice how this also changes the order of the variables on the y axis):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(\n", + " data=tips.query(\"size in [2, 3, 5]\"),\n", + " x=\"total_bill\", y=\"size\", orient=\"h\",\n", + " fixed_scale=False,\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Further visual customization can be achieved by passing keyword arguments for :func:`matplotlib.axes.Axes.scatter`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.swarmplot(\n", + " data=tips, x=\"total_bill\", y=\"day\", hue=\"time\",\n", + " marker=\"x\", linewidth=1, \n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To make a plot with multiple facets, it is safer to use :func:`catplot` with `kind=\"swarm\"` than to work with :class:`FacetGrid` directly, because :func:`catplot` will ensure that the categorical and hue variables are properly synchronized in each facet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.catplot(\n", + " data=tips, kind=\"swarm\",\n", + " x=\"time\", y=\"total_bill\", hue=\"sex\", col=\"day\",\n", + " aspect=.5\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/index.rst b/doc/index.rst index bb75b0fc25..97d93cf5e2 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -37,9 +37,14 @@ seaborn: statistical data visualization - +
[front-page thumbnail image markup added in this hunk was stripped during extraction and is not recoverable]
@@ -52,11 +57,6 @@ seaborn: statistical data visualization
[front-page thumbnail image markup removed in this hunk was stripped during extraction and is not recoverable]

@@ -70,14 +70,19 @@ Seaborn is a Python data visualization library based on `matplotlib attractive and informative statistical graphics. For a brief introduction to the ideas behind the library, you can read the -:doc:`introductory notes `. Visit the :doc:`installation page -` to see how you can download the package. You can browse the -:doc:`example gallery ` to see what you can do with seaborn, -and then check out the :doc:`tutorial ` and :doc:`API reference -` to find out how. - -To see the code or report a bug, please visit the `github repository -`_. General support issues are most at home on `stackoverflow `_, where there is a seaborn tag. +:doc:`introductory notes ` or the `paper +`_. Visit the +:doc:`installation page ` to see how you can download the package +and get started with it. You can browse the :doc:`example gallery +` to see some of the things that you can do with seaborn, +and then check out the :doc:`tutorial ` or :doc:`API reference ` +to find out how. + +To see the code or report a bug, please visit the `GitHub repository +`_. General support questions are most at home +on `stackoverflow `_ or +`discourse `_, which +have dedicated channels for seaborn. .. raw:: html @@ -100,6 +105,12 @@ To see the code or report a bug, please visit the `github repository Tutorial API reference +.. toctree:: + :hidden: + + Citing + Archive + .. raw:: html @@ -113,9 +124,9 @@ To see the code or report a bug, please visit the `github repository
* Relational: :ref:`API ` | :doc:`Tutorial ` +* Distribution: :ref:`API ` | :doc:`Tutorial ` * Categorical: :ref:`API ` | :doc:`Tutorial ` -* Distributions: :ref:`API ` | :doc:`Tutorial ` -* Regressions: :ref:`API ` | :doc:`Tutorial ` +* Regression: :ref:`API ` | :doc:`Tutorial ` * Multiples: :ref:`API ` | :doc:`Tutorial ` * Style: :ref:`API ` | :doc:`Tutorial ` * Color: :ref:`API ` | :doc:`Tutorial ` diff --git a/doc/installing.rst b/doc/installing.rst index 9e2576f503..80ed1e7049 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -9,67 +9,131 @@ Installing and getting started
-To install the latest release of seaborn, you can use ``pip``:: +Official releases of seaborn can be installed from `PyPI `_:: pip install seaborn -It's also possible to install the released version using ``conda``:: +The basic invocation of `pip` will install seaborn and, if necessary, its mandatory dependencies. +It is possible to include optional dependencies that give access to a few advanced features:: - conda install seaborn - +Alternatively, you can use ``pip`` to install the development version directly from github:: + pip install seaborn[all] - pip install git+https://github.com/mwaskom/seaborn.git +The library is also included as part of the `Anaconda `_ distribution, +and it can be installed with `conda`:: - pip install . + conda install seaborn Dependencies ~~~~~~~~~~~~ -- Python 2.7 or 3.5+ +Supported Python versions +^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Python 3.7+ Mandatory dependencies ^^^^^^^^^^^^^^^^^^^^^^ -- `numpy `__ (>= 1.9.3) +- `numpy `__ + +- `pandas `__ + +- `matplotlib `__ -- `scipy `__ (>= 0.14.0) +Optional dependencies +^^^^^^^^^^^^^^^^^^^^^ -- `matplotlib `__ (>= 1.4.3) +- `statsmodels `__, for advanced regression plots -- `pandas `__ (>= 0.15.2) +- `scipy `__, for clustering matrices and some advanced options -Recommended dependencies -^^^^^^^^^^^^^^^^^^^^^^^^ +- `fastcluster `__, for faster clustering of large matrices -- `statsmodels `__ (>= 0.5.0) +Quickstart +~~~~~~~~~~ -Testing -~~~~~~~ +Once you have seaborn installed, you're ready to get started. +To test it out, you could load and plot one of the example datasets:: -To test seaborn, run ``make test`` in the root directory of the source -distribution. This runs the unit test suite (using ``pytest``, but many older -tests use ``nose`` asserts). It also runs the example code in function -docstrings to smoke-test a broader and more realistic range of example usage. + import seaborn as sns + df = sns.load_dataset("penguins") + sns.pairplot(df, hue="species") -The full set of tests requires an internet connection to download the example -datasets (if they haven't been previously cached), but the unit tests should -be possible to run offline. +If you're working in a Jupyter notebook or an IPython terminal with +`matplotlib mode `_ +enabled, you should immediately see :ref:`the plot `. +Otherwise, you may need to explicitly call :func:`matplotlib.pyplot.show`:: + import matplotlib.pyplot as plt + plt.show() -Bugs -~~~~ +While you can get pretty far with only seaborn imported, having access to +matplotlib functions is often useful. The tutorials and API documentation +typically assume the following imports:: + + import numpy as np + import pandas as pd + import seaborn as sns + import matplotlib.pyplot as plt + +Debugging install issues +~~~~~~~~~~~~~~~~~~~~~~~~ + +The seaborn codebase is pure Python, and the library should generally install +without issue. Occasionally, difficulties will arise because the dependencies +include compiled code and link to system libraries. These difficulties +typically manifest as errors on import with messages such as ``"DLL load +failed"``. To debug such problems, read through the exception trace to +figure out which specific library failed to import, and then consult the +installation docs for that package to see if they have tips for your particular +system.
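+
+As a quick check (a sketch, not an official procedure), you can also confirm
+which interpreter is actually running and import each mandatory dependency on
+its own to isolate the one that fails::
+
+    import sys
+    print(sys.executable)  # path of the Python environment that is running
+
+    # import the mandatory dependencies one at a time; the broken one raises
+    import numpy
+    import pandas
+    import matplotlib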
+ +In some cases, an installation of seaborn will appear to succeed, but trying +to import it will raise an error with the message ``"No module named +seaborn"``. This usually means that you have multiple Python installations on +your system and that your ``pip`` or ``conda`` points towards a different +installation than where your interpreter lives. Resolving this issue +will involve sorting out the paths on your system, but it can sometimes be +avoided by invoking ``pip`` with ``python -m pip install seaborn``. + +Getting help +~~~~~~~~~~~~ -Please report any bugs you encounter through the github `issue tracker -`_. It will be most helpful to -include a reproducible example on one of the example datasets (accessed through -:func:`load_dataset`). It is difficult debug any issues without knowing the -versions of seaborn and matplotlib you are using, as well as what `matplotlib -backend `__ you -are using to draw the plots, so please include those in your bug report. +If you think you've encountered a bug in seaborn, please report it on the +`GitHub issue tracker `_. +To be useful, bug reports must include the following information: + +- A reproducible code example that demonstrates the problem +- The output that you are seeing (an image of a plot, or the error message) +- A clear explanation of why you think something is wrong +- The specific versions of seaborn and matplotlib that you are working with + +Bug reports are easiest to address if they can be demonstrated using one of the +example datasets from the seaborn docs (i.e. with :func:`load_dataset`). +Otherwise, it is preferable that your example generate synthetic data to +reproduce the problem. If you can only demonstrate the issue with your +actual dataset, you will need to share it, ideally as a csv. + +If you've encountered an error, searching the specific text of the message +before opening a new issue can often help you solve the problem quickly and +avoid making a duplicate report. + +Because matplotlib handles the actual rendering, errors or incorrect outputs +may be due to a problem in matplotlib rather than one in seaborn. It can save time +if you try to reproduce the issue in an example that uses only matplotlib, +so that you can report it in the right place. But it is alright to skip this +step if it's not obvious how to do it. + +General support questions are more at home on either `stackoverflow +`_ or `discourse +`_, which have a larger +audience of people who will see your post and may be able to offer +assistance. StackOverflow is better for specific issues, while discourse is +better for more open-ended discussion. Your chance of getting a quick answer +will be higher if you include `runnable code +`_, a precise +statement of what you are hoping to achieve, and a clear explanation of the +problems that you have encountered. .. raw:: html diff --git a/doc/introduction.ipynb b/doc/introduction.ipynb index 05794e3e65..3f4807a344 100644 --- a/doc/introduction.ipynb +++ b/doc/introduction.ipynb @@ -15,35 +15,14 @@ "\n", "
\n", "\n", - "Seaborn is a library for making statistical graphics in Python. It is built on top of `matplotlib `_ and closely integrated with `pandas `_ data structures.\n", + "Seaborn is a library for making statistical graphics in Python. It builds on top of `matplotlib `_ and integrates closely with `pandas `_ data structures.\n", "\n", - "Here is some of the functionality that seaborn offers:\n", + "Seaborn helps you explore and understand your data. Its plotting functions operate on dataframes and arrays containing whole datasets and internally perform the necessary semantic mapping and statistical aggregation to produce informative plots. Its dataset-oriented, declarative API lets you focus on what the different elements of your plots mean, rather than on the details of how to draw them.\n", "\n", - "- A dataset-oriented API for examining :ref:`relationships ` between :ref:`multiple variables `\n", - "- Specialized support for using categorical variables to show :ref:`observations ` or :ref:`aggregate statistics ` \n", - "- Options for visualizing :ref:`univariate ` or :ref:`bivariate ` distributions and for :ref:`comparing ` them between subsets of data\n", - "- Automatic estimation and plotting of :ref:`linear regression ` models for different kinds :ref:`dependent ` variables\n", - "- Convenient views onto the overall :ref:`structure ` of complex datasets\n", - "- High-level abstractions for structuring :ref:`multi-plot grids ` that let you easily build :ref:`complex ` visualizations\n", - "- Concise control over matplotlib figure styling with several :ref:`built-in themes `\n", - "- Tools for choosing :ref:`color palettes ` that faithfully reveal patterns in your data\n", + "Our first seaborn plot\n", + "----------------------\n", "\n", - "Seaborn aims to make visualization a central part of exploring and understanding data. Its dataset-oriented plotting functions operate on dataframes and arrays containing whole datasets and internally perform the necessary semantic mapping and statistical aggregation to produce informative plots.\n", - "\n", - "Here's an example of what this means:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide" - ] - }, - "outputs": [], - "source": [ - "%matplotlib inline" + "Here's an example of what seaborn can do:" ] }, { @@ -52,21 +31,28 @@ "metadata": {}, "outputs": [], "source": [ + "# Import seaborn\n", "import seaborn as sns\n", - "sns.set()\n", + "\n", + "# Apply the default theme\n", + "sns.set_theme()\n", + "\n", + "# Load an example dataset\n", "tips = sns.load_dataset(\"tips\")\n", - "sns.relplot(x=\"total_bill\", y=\"tip\", col=\"time\",\n", - " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", - " data=tips);" + "\n", + "# Create a visualization\n", + "sns.relplot(\n", + " data=tips,\n", + " x=\"total_bill\", y=\"tip\", col=\"time\",\n", + " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", + ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "A few things have happened here. Let's go through them one by one:\n", - "\n", - "1. We import seaborn, which is the only library necessary for this simple example." + "A few things have happened here. Let's go through them one by one:" ] }, { @@ -79,6 +65,7 @@ }, "outputs": [], "source": [ + "# Import seaborn\n", "import seaborn as sns" ] }, @@ -86,9 +73,9 @@ "cell_type": "raw", "metadata": {}, "source": [ - "Behind the scenes, seaborn uses matplotlib to draw plots. 
Many tasks can be accomplished with only seaborn functions, but further customization might require using matplotlib directly. This is explained in more detail :ref:`below `. For interactive work, it's recommended to use a Jupyter/IPython interface in `matplotlib mode `_, or else you'll have to call :ref:`matplotlib.pyplot.show` when you want to see the plot.\n", + "Seaborn is the only library we need to import for this simple example. By convention, it is imported with the shorthand ``sns``.\n", "\n", - "2. We apply the default default seaborn theme, scaling, and color palette." + "Behind the scenes, seaborn uses matplotlib to draw its plots. For interactive work, it's recommended to use a Jupyter/IPython interface in `matplotlib mode `_, or else you'll have to call :func:`matplotlib.pyplot.show` when you want to see the plot." ] }, { @@ -101,16 +88,15 @@ }, "outputs": [], "source": [ - "sns.set()" + "# Apply the default theme\n", + "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "This uses the `matplotlib rcParam system `_ and will affect how all matplotlib plots look, even if you don't make them with seaborn. Beyond the default theme, there are :ref:`several other options `, and you can independently control the style and scaling of the plot to quickly translate your work between presentation contexts (e.g., making a plot that will have readable fonts when projected during a talk). If you like the matplotlib defaults or prefer a different theme, you can skip this step and still use the seaborn plotting functions.\n", - "\n", - "3. We load one of the example datasets." + "This uses the `matplotlib rcParam system `_ and will affect how all matplotlib plots look, even if you don't make them with seaborn. Beyond the default theme, there are :doc:`several other options `, and you can independently control the style and scaling of the plot to quickly translate your work between presentation contexts (e.g., making a version of your figure that will have readable fonts when projected during a talk). If you like the matplotlib defaults or prefer a different theme, you can skip this step and still use the seaborn plotting functions." ] }, { @@ -123,6 +109,7 @@ }, "outputs": [], "source": [ + "# Load an example dataset\n", "tips = sns.load_dataset(\"tips\")" ] }, @@ -130,9 +117,7 @@ "cell_type": "raw", "metadata": {}, "source": [ - "Most code in the docs will use the :func:`load_dataset` function to get quick access to an example dataset. There's nothing particularly special about these datasets; they are just pandas dataframes, and we could have loaded them with :ref:`pandas.read_csv` or build them by hand. Many examples use the \"tips\" dataset, which is very boring but quite useful for demonstration. The tips dataset illustrates the \"tidy\" approach to organizing a dataset. You'll get the most out of seaborn if your datasets are organized this way, and it is explained in more detail :ref:`below `.\n", - "\n", - "4. We draw a faceted scatter plot with multiple semantic variables." + "Most code in the docs will use the :func:`load_dataset` function to get quick access to an example dataset. There's nothing special about these datasets: they are just pandas dataframes, and we could have loaded them with :func:`pandas.read_csv` or built them by hand. Most of the examples in the documentation will specify data using pandas dataframes, but seaborn is very flexible about the :doc:`data structures ` that it accepts." 
] }, { @@ -145,27 +130,28 @@ }, "outputs": [], "source": [ - "sns.relplot(x=\"total_bill\", y=\"tip\", col=\"time\",\n", - " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", - " data=tips)" + "# Create a visualization\n", + "sns.relplot(\n", + " data=tips,\n", + " x=\"total_bill\", y=\"tip\", col=\"time\",\n", + " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", + ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "This particular plot shows the relationship between five variables in the tips dataset. Three are numeric, and two are categorical. Two numeric variables (``total_bill`` and ``tip``) determined the position of each point on the axes, and the third (``size``) determined the size of each point. One categorical variable split the dataset onto two different axes (facets), and the other determined the color and shape of each point.\n", - "\n", - "All of this was accomplished using a single call to the seaborn function :func:`relplot`. Notice how we only provided the names of the variables in the dataset and the roles that we wanted them to play in the plot. Unlike when using matplotlib directly, it wasn't necessary to translate the variables into parameters of the visualization (e.g., the specific color or marker to use for each category). That translation was done automatically by seaborn. This lets the user stay focused on the question they want the plot to answer.\n", + "This plot shows the relationship between five variables in the tips dataset using a single call to the seaborn function :func:`relplot`. Notice how we provided only the names of the variables and their roles in the plot. Unlike when using matplotlib directly, it wasn't necessary to specify attributes of the plot elements in terms of the color values or marker codes. Behind the scenes, seaborn handled the translation from values in the dataframe to arguments that matplotlib understands. This declarative approach lets you stay focused on the questions that you want to answer, rather than on the details of how to control matplotlib.\n", "\n", ".. _intro_api_abstraction:\n", "\n", "API abstraction across visualizations\n", "-------------------------------------\n", "\n", - "There is no universal best way to visualize data. Different questions are best answered by different kinds of visualizations. Seaborn tries to make it easy to switch between different visual representations that can be parameterized with the same dataset-oriented API.\n", + "There is no universally best way to visualize data. Different questions are best answered by different plots. Seaborn makes it easy to switch between different visual representations by using a consistent dataset-oriented API.\n", "\n", - "The function :func:`relplot` is named that way because it is designed to visualize many different statistical *relationships*. While scatter plots are a highly effective way of doing this, relationships where one variable represents a measure of time are better represented by a line. The :func:`relplot` function has a convenient ``kind`` parameter to let you easily switch to this alternate representation:" + "The function :func:`relplot` is named that way because it is designed to visualize many different statistical *relationships*. While scatter plots are often effective, relationships where one variable represents a measure of time are better represented by a line. 
The :func:`relplot` function has a convenient ``kind`` parameter that lets you easily switch to this alternate representation:" ] }, { @@ -175,24 +161,26 @@ "outputs": [], "source": [ "dots = sns.load_dataset(\"dots\")\n", - "sns.relplot(x=\"time\", y=\"firing_rate\", col=\"align\",\n", - " hue=\"choice\", size=\"coherence\", style=\"choice\",\n", - " facet_kws=dict(sharex=False),\n", - " kind=\"line\", legend=\"full\", data=dots);" + "sns.relplot(\n", + " data=dots, kind=\"line\",\n", + " x=\"time\", y=\"firing_rate\", col=\"align\",\n", + " hue=\"choice\", size=\"coherence\", style=\"choice\",\n", + " facet_kws=dict(sharex=False),\n", + ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Notice how the ``size`` and ``style`` parameters are shared across the scatter and line plots, but they affect the two visualizations differently (changing marker area and symbol vs line width and dashing). We did not need to keep those details in mind, letting us focus on the overall structure of the plot and the information we want it to convey.\n", + "Notice how the ``size`` and ``style`` parameters are used in both the scatter and line plots, but they affect the two visualizations differently: changing the marker area and symbol in the scatter plot vs the line width and dashing in the line plot. We did not need to keep those details in mind, letting us focus on the overall structure of the plot and the information we want it to convey.\n", "\n", ".. _intro_stat_estimation:\n", "\n", "Statistical estimation and error bars\n", "-------------------------------------\n", "\n", - "Often we are interested in the average value of one variable as a function of other variables. Many seaborn functions can automatically perform the statistical estimation that is necessary to answer these questions:" + "Often, we are interested in the *average* value of one variable as a function of other variables. Many seaborn functions will automatically perform the statistical estimation that is necessary to answer these questions:" ] }, { @@ -202,9 +190,11 @@ "outputs": [], "source": [ "fmri = sns.load_dataset(\"fmri\")\n", - "sns.relplot(x=\"timepoint\", y=\"signal\", col=\"region\",\n", - " hue=\"event\", style=\"event\",\n", - " kind=\"line\", data=fmri);" + "sns.relplot(\n", + " data=fmri, kind=\"line\",\n", + " x=\"timepoint\", y=\"signal\", col=\"region\",\n", + " hue=\"event\", style=\"event\",\n", + ")" ] }, { @@ -213,7 +203,7 @@ "source": [ "When statistical values are estimated, seaborn will use bootstrapping to compute confidence intervals and draw error bars representing the uncertainty of the estimate.\n", "\n", - "Statistical estimation in seaborn goes beyond descriptive statistics. For example, it is also possible to enhance a scatterplot to include a linear regression model (and its uncertainty) using :func:`lmplot`:" + "Statistical estimation in seaborn goes beyond descriptive statistics. For example, it is possible to enhance a scatterplot by including a linear regression model (and its uncertainty) using :func:`lmplot`:" ] }, { @@ -222,22 +212,20 @@ "metadata": {}, "outputs": [], "source": [ - "sns.lmplot(x=\"total_bill\", y=\"tip\", col=\"time\", hue=\"smoker\",\n", - " data=tips);" + "sns.lmplot(data=tips, x=\"total_bill\", y=\"tip\", col=\"time\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - ".. _intro_categorical:\n", + ".. 
_intro_distributions:\n", "\n", - "Specialized categorical plots\n", - "-----------------------------\n", "\n", - "Standard scatter and line plots visualize relationships between numerical variables, but many data analyses involve categorical variables. There are several specialized plot types in seaborn that are optimized for visualizing this kind of data. They can be accessed through :func:`catplot`. Similar to :func:`relplot`, the idea of :func:`catplot` is that it exposes a common dataset-oriented API that generalizes over different representations of the relationship between one numeric variable and one (or more) categorical variables.\n", + "Informative distributional summaries\n", + "------------------------------------\n", "\n", - "These representations offer different levels of granularity in their presentation of the underlying data. At the finest level, you may wish to see every observation by drawing a scatter plot that adjusts the positions of the points along the categorical axis so that they don't overlap:" + "Statistical analyses require knowledge about the distribution of variables in your dataset. The seaborn function :func:`displot` supports several approaches to visualizing distributions. These include classic techniques like histograms and computationally-intensive approaches like kernel density estimation:" ] }, { @@ -246,15 +234,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\",\n", - " kind=\"swarm\", data=tips);" + "sns.displot(data=tips, x=\"total_bill\", col=\"time\", kde=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Alternately, you could use kernel density estimation to represent the underlying distribution that the points are sampled from:" + "Seaborn also tries to promote techniques that are powerful but less familiar, such as calculating and plotting the empirical cumulative distribution function of the data:" ] }, { @@ -263,15 +250,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\",\n", - " kind=\"violin\", split=True, data=tips);" + "sns.displot(data=tips, kind=\"ecdf\", x=\"total_bill\", col=\"time\", hue=\"smoker\", rug=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Or you could show the only mean value and its confidence interval within each nested category:" + ".. _intro_categorical:\n", + "\n", + "Specialized plots for categorical data\n", + "--------------------------------------\n", + "\n", + "Several specialized plot types in seaborn are oriented towards visualizing categorical data. They can be accessed through :func:`catplot`. These plots offer different levels of granularity. At the finest level, you may wish to see every observation by drawing a \"swarm\" plot: a scatter plot that adjusts the positions of the points along the categorical axis so that they don't overlap:" ] }, { @@ -280,24 +271,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\",\n", - " kind=\"bar\", data=tips);" + "sns.catplot(data=tips, kind=\"swarm\", x=\"day\", y=\"total_bill\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - ".. _intro_func_types:\n", - "\n", - "Figure-level and axes-level functions\n", - "-------------------------------------\n", - "\n", - "How do these tools work? It's important to know about a major distinction between seaborn plotting functions. All of the plots shown so far have been made with \"figure-level\" functions. 
These are optimized for exploratory analysis because they set up the matplotlib figure containing the plot(s) and make it easy to spread out the visualization across multiple axes. They also handle some tricky business like putting the legend outside the axes. To do these things, they use a seaborn :class:`FacetGrid`.\n", - "\n", - "Each different figure-level plot ``kind`` combines a particular \"axes-level\" function with the :class:`FacetGrid` object. For example, the scatter plots are drawn using the :func:`scatterplot` function, and the bar plots are drawn using the :func:`barplot` function. These functions are called \"axes-level\" because they draw onto a single matplotlib axes and don't otherwise affect the rest of the figure.\n", - "\n", - "The upshot is that the figure-level function needs to control the figure it lives in, while axes-level functions can be combined into a more complex matplotlib figure with other axes that may or may not have seaborn plots on them:" + "Alternately, you could use kernel density estimation to represent the underlying distribution that the points are sampled from:" ] }, { @@ -306,17 +287,14 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "f, axes = plt.subplots(1, 2, sharey=True, figsize=(6, 4))\n", - "sns.boxplot(x=\"day\", y=\"tip\", data=tips, ax=axes[0])\n", - "sns.scatterplot(x=\"total_bill\", y=\"tip\", hue=\"day\", data=tips, ax=axes[1]);" + "sns.catplot(data=tips, kind=\"violin\", x=\"day\", y=\"total_bill\", hue=\"smoker\", split=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Controlling the size of the figure-level functions works a little bit differently than it does for other matplotlib figures. Instead of setting the overall figure size, the figure-level functions are parameterized by the size of each facet. And instead of setting the height and width of each facet, you control the height and *aspect* ratio (ratio of width to height). This parameterization makes it easy to control the size of the graphic without thinking about exactly how many rows and columns it will have, although it can be a source of confusion:" + "Or you could show only the mean value and its confidence interval within each nested category:" ] }, { @@ -325,26 +303,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.relplot(x=\"time\", y=\"firing_rate\", col=\"align\",\n", - " hue=\"choice\", size=\"coherence\", style=\"choice\",\n", - " height=4.5, aspect=2 / 3,\n", - " facet_kws=dict(sharex=False),\n", - " kind=\"line\", legend=\"full\", data=dots);" + "sns.catplot(data=tips, kind=\"bar\", x=\"day\", y=\"total_bill\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The way you can tell whether a function is \"figure-level\" or \"axes-level\" is whether it takes an ``ax=`` parameter. You can also distinguish the two classes by their output type: axes-level functions return the matplotlib ``axes``, while figure-level functions return the :class:`FacetGrid`.\n", - "\n", - "\n", ".. _intro_dataset_funcs:\n", "\n", - "Visualizing dataset structure\n", - "-----------------------------\n", + "Composite views onto multivariate datasets\n", + "------------------------------------------\n", "\n", - "There are two other kinds of figure-level functions in seaborn that can be used to make visualizations with multiple plots. They are each oriented towards illuminating the structure of a dataset. 
One, :func:`jointplot`, focuses on a single relationship:" + "Some seaborn functions combine multiple kinds of plots to quickly give informative summaries of a dataset. One, :func:`jointplot`, focuses on a single relationship. It plots the joint distribution between two variables along with each variable's marginal distribution:" ] }, { @@ -353,15 +324,15 @@ "metadata": {}, "outputs": [], "source": [ - "iris = sns.load_dataset(\"iris\")\n", - "sns.jointplot(x=\"sepal_length\", y=\"petal_length\", data=iris);" + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The other, :func:`pairplot`, takes a broader view, showing all pairwise relationships and the marginal distributions, optionally conditioned on a categorical variable :" + "The other, :func:`pairplot`, takes a broader view: it shows joint and marginal distributions for all pairwise relationships and for each variable, respectively:" ] }, { @@ -370,23 +341,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.pairplot(data=iris, hue=\"species\");" + "sns.pairplot(data=penguins, hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Both :func:`jointplot` and :func:`pairplot` have a few different options for visual representation, and they are built on top of classes that allow more thoroughly customized multi-plot figures (:class:`JointGrid` and :class:`PairGrid`, respectively).\n", - "\n", - ".. _intro_plot_customization:\n", + ".. _intro_figure_classes:\n", "\n", - "Customizing plot appearance\n", - "---------------------------\n", + "Classes and functions for making complex graphics\n", + "-------------------------------------------------\n", "\n", - "The plotting functions try to use good default aesthetics and add informative labels so that their output is immediately useful. But defaults can only go so far, and creating a fully-polished custom plot will require additional steps. Several levels of additional customization are possible. \n", - "\n", - "The first way is to use one of the alternate seaborn themes to give your plots a different look. Setting a different theme or color palette will make it take effect for all plots:" + "These tools work by combining :doc:`axes-level ` plotting functions with objects that manage the layout of the figure, linking the structure of a dataset to a :doc:`grid of axes `. Both elements are part of the public API, and you can use them directly to create complex figures with only a few more lines of code:" ] }, { @@ -395,19 +362,26 @@ "metadata": {}, "outputs": [], "source": [ - "sns.set(style=\"ticks\", palette=\"muted\")\n", - "sns.relplot(x=\"total_bill\", y=\"tip\", col=\"time\",\n", - " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", - " data=tips);" + "g = sns.PairGrid(penguins, hue=\"species\", corner=True)\n", + "g.map_lower(sns.kdeplot, hue=None, levels=5, color=\".2\")\n", + "g.map_lower(sns.scatterplot, marker=\"+\")\n", + "g.map_diag(sns.histplot, element=\"step\", linewidth=0, kde=True)\n", + "g.add_legend(frameon=True)\n", + "g.legend.set_bbox_to_anchor((.61, .6))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "For figure-specific customization, all seaborn functions accept a number of optional parameters for switching to non-default semantic mappings, such as different colors. 
(Appropriate use of color is critical for effective data visualization, and seaborn has :ref:`extensive support ` for customizing color palettes).\n", + ".. _intro_defaults:\n", + "\n", + "Opinionated defaults and flexible customization\n", + "-----------------------------------------------\n", "\n", - "Finally, where there is a direct correspondence with an underlying matplotlib function (like :func:`scatterplot` and ``plt.scatter``), additional keyword arguments will be passed through to the matplotlib layer:" + "Seaborn creates complete graphics with a single function call: when possible, its functions will automatically add informative axis labels and legends that explain the semantic mappings in the plot.\n", + "\n", + "In many cases, seaborn will also choose default values for its parameters based on characteristics of the data. For example, the :doc:`color mappings ` that we have seen so far used distinct hues (blue, orange, and sometimes green) to represent different levels of the categorical variables assigned to ``hue``. When mapping a numeric variable, some functions will switch to a continuous gradient:" ] }, { @@ -416,20 +390,17 @@ "metadata": {}, "outputs": [], "source": [ - "sns.relplot(x=\"total_bill\", y=\"tip\", col=\"time\",\n", - " hue=\"size\", style=\"smoker\", size=\"size\",\n", - " palette=\"YlGnBu\", markers=[\"D\", \"o\"], sizes=(10, 125),\n", - " edgecolor=\".2\", linewidth=.5, alpha=.75,\n", - " data=tips);" + "sns.relplot(\n", + " data=penguins,\n", + " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"body_mass_g\"\n", + ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "In the case of :func:`relplot` and other figure-level functions, that means there are a few levels of indirection because :func:`relplot` passes its exta keyword arguments to the underlying seaborn axes-level function, which passes *its* extra keyword arguments to the underlying matplotlib function. So it might take some effort to find the right documentation for the parameters you'll need to use, but in principle an extremely detailed level of customization is possible.\n", - "\n", - "Some customization of figure-level functions can be accomplished through additional parameters that get passed to :class:`FacetGrid`, and you can use the methods on that object to control many other properties of the figure. For even more tweaking, you can access the matplotlib objects that the plot is drawn onto, which are stored as attributes:" + "When you're ready to share or publish your work, you'll probably want to polish the figure beyond what the defaults achieve. Seaborn allows for several levels of customization. It defines multiple built-in :doc:`themes ` that apply to all figures, its functions have standardized parameters that can modify the semantic mappings for each plot, and additional keyword arguments are passed down to the underlying matplotlib artists, allowing even more control. 
Once you've created a plot, its properties can be modified both through the seaborn API and by dropping down to the matplotlib layer for fine-grained tweaking:" ] }, { @@ -438,74 +409,45 @@ "metadata": {}, "outputs": [], "source": [ - "g = sns.catplot(x=\"total_bill\", y=\"day\", hue=\"time\",\n", - " height=3.5, aspect=1.5,\n", - " kind=\"box\", legend=False, data=tips);\n", - "g.add_legend(title=\"Meal\")\n", - "g.set_axis_labels(\"Total bill ($)\", \"\")\n", - "g.set(xlim=(0, 60), yticklabels=[\"Thursday\", \"Friday\", \"Saturday\", \"Sunday\"])\n", - "g.despine(trim=True)\n", - "g.fig.set_size_inches(6.5, 3.5)\n", - "g.ax.set_xticks([5, 15, 25, 35, 45, 55], minor=True);\n", - "plt.setp(g.ax.get_yticklabels(), rotation=30);" + "sns.set_theme(style=\"ticks\", font_scale=1.25)\n", + "g = sns.relplot(\n", + " data=penguins,\n", + " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"body_mass_g\",\n", + " palette=\"crest\", marker=\"x\", s=100,\n", + ")\n", + "g.set_axis_labels(\"Bill length (mm)\", \"Bill depth (mm)\", labelpad=10)\n", + "g.legend.set_title(\"Body mass (g)\")\n", + "g.fig.set_size_inches(6.5, 4.5)\n", + "g.ax.margins(.15)\n", + "g.despine(trim=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Because the figure-level functions are oriented towards efficient exploration, using them to manage a figure that you need to be precisely sized and organized may take more effort than setting up the figure directly in matplotlib and using the corresponding axes-level seaborn function. Matplotlib has a comprehensive and powerful API; just about any attribute of the figure can be changed to your liking. The hope is that a combination of seaborn's high-level interface and matplotlib's deep customizability will allow you to quickly explore your data and create graphics that can be tailored into a `publication quality `_ final product.\n", + ".. _intro_matplotlib:\n", "\n", - ".. _intro_tidy_data:\n", + "Relationship to matplotlib\n", + "--------------------------\n", "\n", - "Organizing datasets\n", - "-------------------\n", + "Seaborn's integration with matplotlib allows you to use it across the many environments that matplotlib supports, including exploratory analysis in notebooks, real-time interaction in GUI applications, and archival output in a number of raster and vector formats.\n", "\n", - "As mentioned above, seaborn will be most powerful when your datasets have a particular organization. This format is alternately called \"long-form\" or \"tidy\" data and is described in detail by Hadley Wickham in this `academic paper `_. The rules can be simply stated:\n", + "While you can be productive using only seaborn functions, full customization of your graphics will require some knowledge of matplotlib's concepts and API. One aspect of the learning curve for new users of seaborn will be knowing when dropping down to the matplotlib layer is necessary to achieve a particular customization. On the other hand, users coming from matplotlib will find that much of their knowledge transfers.\n", "\n", - "1. Each variable is a column\n", - "2. Each observation is a row\n", - "\n", - "A helpful mindset for determining whether your data are tidy is to think backwards from the plot you want to draw. From this perspective, a \"variable\" is something that will be assigned a role in the plot. It may be useful to look at the example datasets and see how they are structured. 
For example, the first five rows of the \"tips\" dataset look like this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tips.head()" + "Matplotlib has a comprehensive and powerful API; just about any attribute of the figure can be changed to your liking. A combination of seaborn's high-level interface and matplotlib's deep customizability will allow you both to quickly explore your data and to create graphics that can be tailored into a `publication quality `_ final product." ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "In some domains, the tidy format might feel awkward at first. Timeseries data, for example, are sometimes stored with every timepoint as part of the same observational unit and appearing in the columns. The \"fmri\" dataset that we used :ref:`above ` illustrates how a tidy timeseries dataset has each timepoint in a different row:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fmri.head()" - ] - }, - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "Many seaborn functions can plot wide-form data, but only with limited functionality. To take advantage of the features that depend on tidy-formatted data, you'll likely find the ``pandas.melt`` function useful for \"un-pivoting\" a wide-form dataframe. More information and useful examples can be found `in this blog post `_ by one of the pandas developers.\n", - "\n", ".. _intro_next_steps:\n", "\n", "Next steps\n", "----------\n", "\n", - "You have a few options for where to go next. You might first want to learn how to :ref:`install seaborn `. Once that's done, you can browse the :ref:`example gallery ` to get a broader sense for what kind of graphics seaborn can produce. Or you can read through the :ref:`official tutorial ` for a deeper discussion of the different tools and what they are designed to accomplish. If you have a specific plot in mind and want to know how to make it, you could check out the :ref:`API reference `, which documents each function's parameters and shows many examples to illustrate usage." + "You have a few options for where to go next. You might first want to learn how to :doc:`install seaborn `. Once that's done, you can browse the :doc:`example gallery ` to get a broader sense for what kind of graphics seaborn can produce. Or you can read through the :doc:`user guide and tutorial ` for a deeper discussion of the different tools and what they are designed to accomplish. If you have a specific plot in mind and want to know how to make it, you could check out the :doc:`API reference `, which documents each function's parameters and shows many examples to illustrate usage." 
] }, { @@ -521,9 +463,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -535,9 +477,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/doc/matplotlibrc b/doc/matplotlibrc new file mode 100644 index 0000000000..67a95bbfd0 --- /dev/null +++ b/doc/matplotlibrc @@ -0,0 +1 @@ +savefig.bbox : tight diff --git a/doc/releases/v0.10.0.txt b/doc/releases/v0.10.0.txt new file mode 100644 index 0000000000..3340f5b215 --- /dev/null +++ b/doc/releases/v0.10.0.txt @@ -0,0 +1,28 @@ + +v0.10.0 (January 2020) +---------------------- + +This is a major update that is being released simultaneously with version 0.9.1. It has all of the same features (and bugs!) as 0.9.1, but there are important changes to the dependencies. + +Most notably, all support for Python 2 has now been dropped. Support for Python 3.5 has also been dropped. Seaborn is now strictly compatible with Python 3.6+. + +Minimally supported versions of the dependent PyData libraries have also been increased, in some cases substantially. While seaborn has tended to be very conservative about maintaining compatibility with older dependencies, this was causing increasing pain during development. At the same time, these libraries are now much easier to install. Going forward, seaborn will likely stay close to the `Numpy community guidelines `_ for version support. + +This release also removes a few previously-deprecated features: + +- The ``tsplot`` function and ``seaborn.timeseries`` module have been removed. Recall that ``tsplot`` was replaced with :func:`lineplot`. + +- The ``seaborn.apionly`` entry-point has been removed. + +- The ``seaborn.linearmodels`` module (previously renamed to ``seaborn.regression``) has been removed. + +Looking forward +~~~~~~~~~~~~~~~ + +Now that seaborn is a Python 3 library, it can take advantage of `keyword-only arguments `_. It is likely that future versions will introduce this syntax, potentially in a breaking way. For guidance, most seaborn functions have a signature that looks like + +:: + + func(x, y, ..., data=None, **kwargs) + +where the ``**kwargs`` are specified in the function. Going forward it will likely be necessary to specify ``data`` and all subsequent arguments with an explicit ``key=value`` mapping. This style has long been used throughout the documentation, and the formal requirement will not be introduced until at least the next major release. Adding this feature will make it possible to enhance some older functions with more modern capabilities (e.g., adding a native ``hue`` semantic within functions like :func:`jointplot` and :func:`regplot`) and will allow parameters that control new features to be situated nearby related parameters, making them more discoverable. diff --git a/doc/releases/v0.10.1.txt b/doc/releases/v0.10.1.txt new file mode 100644 index 0000000000..fc7622446d --- /dev/null +++ b/doc/releases/v0.10.1.txt @@ -0,0 +1,25 @@ + +v0.10.1 (April 2020) +-------------------- + +This is a minor release with bug fixes for issues identified since 0.10.0. + +- Fixed a bug that appeared within the bootstrapping algorithm on 32-bit systems. + +- Fixed a bug where :func:`regplot` would crash on singleton inputs.
Now a crash is avoided and regression estimation/plotting is skipped. + +- Fixed a bug where :func:`heatmap` would ignore user-specified under/over/bad values when recentering a colormap. + +- Fixed a bug where :func:`heatmap` would use values from masked cells when computing default colormap limits. + +- Fixed a bug where :func:`despine` would cause an error when trying to trim spines on a matplotlib categorical axis. + +- Adapted to a change in matplotlib that caused problems with single swarm plots. + +- Added the ``showfliers`` parameter to :func:`boxenplot` to suppress plotting of outlier data points, matching the API of :func:`boxplot`. + +- Avoided seeing an error from statsmodels when data with an IQR of 0 is passed to :func:`kdeplot`. + +- Added the ``legend.title_fontsize`` to the :func:`plotting_context` definition. + +- Deprecated several utility functions that are no longer used internally (``percentiles``, ``sig_stars``, ``pmf_hist``, and ``sort_df``). diff --git a/doc/releases/v0.11.0.txt b/doc/releases/v0.11.0.txt new file mode 100644 index 0000000000..a88e25a7cd --- /dev/null +++ b/doc/releases/v0.11.0.txt @@ -0,0 +1,212 @@ + +v0.11.0 (September 2020) +------------------------ + +This is a major release with several important new features, enhancements to existing functions, and changes to the library. Highlights include an overhaul and modernization of the distributions plotting functions, more flexible data specification, new colormaps, and better narrative documentation. + +For an overview of the new features and a guide to updating, see `this Medium post `_. + +Required keyword arguments +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +|API| + +Most plotting functions now require all of their parameters to be specified using keyword arguments. To ease adaptation, code without keyword arguments will trigger a ``FutureWarning`` in v0.11. In a future release (v0.12 or v0.13, depending on release cadence), this will become an error. Once keyword arguments are fully enforced, the signature of the plotting functions will be reorganized to accept ``data`` as the first and only positional argument (:pr:`2052,2081`). + +Modernization of distribution functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The distribution module has been completely overhauled, modernizing the API and introducing several new functions and features within existing functions. Some new features are explained here; the :doc:`tutorial documentation ` has also been rewritten and serves as a good introduction to the functions. + +New plotting functions +^^^^^^^^^^^^^^^^^^^^^^ + +|Feature| |Enhancement| + +First, three new functions, :func:`displot`, :func:`histplot`, and :func:`ecdfplot`, have been added (:pr:`2157`, :pr:`2125`, :pr:`2141`). + +The figure-level :func:`displot` function is an interface to the various distribution plots (analogous to :func:`relplot` or :func:`catplot`). It can draw univariate or bivariate histograms, density curves, ECDFs, and rug plots on a :class:`FacetGrid`. + +The axes-level :func:`histplot` function draws univariate or bivariate histograms with a number of features (a brief sketch combining several of them follows this list), including: + +- mapping multiple distributions with a ``hue`` semantic +- normalization to show density, probability, or frequency statistics +- flexible parameterization of bin size, including proper bins for discrete variables +- adding a KDE fit to show a smoothed distribution over all bin statistics +- experimental support for histograms over categorical and datetime variables.
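As a rough sketch of how several of these :func:`histplot` options can combine (the ``penguins`` example dataset and the specific parameter values here are purely illustrative)::

    import seaborn as sns

    penguins = sns.load_dataset("penguins")

    # Map a third variable with hue, normalize bar heights to a density,
    # set an explicit bin width, and overlay a smoothed KDE fit.
    # (All parameter values here are illustrative choices.)
    sns.histplot(
        data=penguins, x="flipper_length_mm", hue="species",
        stat="density", binwidth=2, kde=True,
    )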
+ +The axes-level :func:`ecdfplot` function draws univariate empirical cumulative distribution functions, using a similar interface. + +Changes to existing functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +|API| |Feature| |Enhancement| |Defaults| + +Second, the existing functions :func:`kdeplot` and :func:`rugplot` have been completely overhauled (:pr:`2060,2104`). + +The overhauled functions now share a common API with the rest of seaborn, they can show conditional distributions by mapping a third variable with a ``hue`` semantic, and they have been improved in numerous other ways. The github pull request (:pr:`2104`) has a longer explanation of the changes and the motivation behind them. + +This is a necessarily API-breaking change. The parameter names for the positional variables are now ``x`` and ``y``, and the old names have been deprecated. Efforts were made to handle and warn when using the deprecated API, but it is strongly suggested to check your plots carefully. + +Additionally, the statsmodels-based computation of the KDE has been removed. Because there were some inconsistencies between the way different parameters (specifically, ``bw``, ``clip``, and ``cut``) were implemented by each backend, this may cause plots to look different with non-default parameters. Support for using non-Gaussian kernels, which was available only in the statsmodels backend, has been removed. + +Other new features include: + +- several options for representing multiple densities (using the ``multiple`` and ``common_norm`` parameters) +- weighted density estimation (using the new ``weights`` parameter) +- better control over the smoothing bandwidth (using the new ``bw_adjust`` parameter) +- more meaningful parameterization of the contours that represent a bivariate density (using the ``thresh`` and ``levels`` parameters) +- log-space density estimation (using the new ``log_scale`` parameter, or by scaling the data axis before plotting) +- "bivariate" rug plots with a single function call (by assigning both ``x`` and ``y``) + +Deprecations +^^^^^^^^^^^^ + +|API| + +Finally, the :func:`distplot` function is now formally deprecated. Its features have been subsumed by :func:`displot` and :func:`histplot`. Some effort was made to gradually transition :func:`distplot` by adding the features in :func:`displot` and handling backwards compatibility, but this proved to be too difficult. The similarity in the names will likely cause some confusion during the transition, which is regrettable. + +Related enhancements and changes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +|API| |Feature| |Enhancement| |Defaults| + +These additions facilitated new features (and forced changes) in :func:`jointplot` and :class:`JointGrid` (:pr:`2210`) and in :func:`pairplot` and :class:`PairGrid` (:pr:`2234`). + +- Added support for the ``hue`` semantic in :func:`jointplot`/:class:`JointGrid`. This support is lightweight and simply delegates the mapping to the underlying axes-level functions. + +- Delegated the handling of ``hue`` in :class:`PairGrid`/:func:`pairplot` to the plotting function when it understands ``hue``, meaning that (1) the zorder of scatterplot points will be determined by their row in the dataframe, (2) additional options for resolving hue (e.g. the ``multiple`` parameter) can be used, and (3) numeric hue variables can be naturally mapped when using :func:`scatterplot`.
+ +- Added ``kind="hist"`` to :func:`jointplot`, which draws a bivariate histogram on the joint axes and univariate histograms on the marginal axes, as well as both ``kind="hist"`` and ``kind="kde"`` to :func:`pairplot`, which behaves likewise. + +- The various modes of :func:`jointplot` that plot marginal histograms now use :func:`histplot` rather than :func:`distplot`. This slightly changes the default appearance and affects the valid keyword arguments that can be passed to customize the plot. Likewise, the marginal histogram plots in :func:`pairplot` now use :func:`histplot`. + +Standardization and enhancements of data ingest +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +|Feature| |Enhancement| |Docs| + +The code that processes input data has been refactored and enhanced. In v0.11, this new code takes effect for the relational and distribution modules; other modules will be refactored to use it in future releases (:pr:`2071`). + +These changes should be transparent for most use-cases, although they allow a few new features: + +- Named variables for long-form data can refer to the named index of a :class:`pandas.DataFrame` or to levels in the case of a multi-index. Previously, it was necessary to call :meth:`pandas.DataFrame.reset_index` before using index variables (e.g., after a groupby operation). +- :func:`relplot` now has the same flexibility as the axes-level functions to accept data in long- or wide-format and to accept data vectors (rather than named variables) in long-form mode. +- The data parameter can now be a Python ``dict`` or an object that implements that interface. This is a new feature for wide-form data. For long-form data, it was previously supported but not documented. +- A wide-form data object can have a mixture of types; the non-numeric types will be removed before plotting. Previously, this caused an error. +- There are better error messages for other instances of data mis-specification. + +See the new user guide chapter on :doc:`data formats ` for more information about what is supported. + +Other changes +~~~~~~~~~~~~~ + +Documentation improvements +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- |Docs| Added two new chapters to the user guide, one giving an overview of the :doc:`types of functions in seaborn `, and one discussing the different :doc:`data formats ` that seaborn understands. + +- |Docs| Expanded the :doc:`color palette tutorial ` to give more background on color theory and better motivate the use of color in statistical graphics. + +- |Docs| Added more information to the :doc:`installation guidelines ` and streamlined the :doc:`introduction ` page. + +- |Docs| Improved cross-linking within the seaborn docs and between the seaborn and matplotlib docs. + +Theming +^^^^^^^ + +- |API| The :func:`set` function has been renamed to :func:`set_theme` for more clarity about what it does. For the foreseeable future, :func:`set` will remain as an alias, but it is recommended to update your code. + +Relational plots +^^^^^^^^^^^^^^^^ + +- |Enhancement| |Defaults| Reduced some of the surprising behavior of relational plot legends when using a numeric hue or size mapping (:pr:`2229`): + + - Added an "auto" mode (the new default) that chooses between "brief" and "full" legends based on the number of unique levels of each variable. + - Modified the ticking algorithm for a "brief" legend to show up to 6 values and not to show values outside the limits of the data. 
+ - Changed the approach to the legend title: the normal matplotlib legend title is used when only one variable is assigned a semantic mapping, whereas the old approach of adding an invisible legend artist with a subtitle label is used only when multiple semantic variables are defined. + - Modified legend subtitles to be left-aligned and to be drawn in the default legend title font size. + +- |Enhancement| |Defaults| Changed how functions that use different representations for numeric and categorical data handle vectors with an ``object`` data type. Previously, data was considered numeric if it could be coerced to a float representation without error. Now, object-typed vectors are considered numeric only when their contents are themselves numeric. As a consequence, numbers that are encoded as strings will now be treated as categorical data (:pr:`2084`). + +- |Enhancement| |Defaults| Plots with a ``style`` semantic can now generate an infinite number of unique dashes and/or markers by default. Previously, an error would be raised if the ``style`` variable had more levels than could be mapped using the default lists. The existing defaults were slightly modified as part of this change; if you need to exactly reproduce plots from earlier versions, refer to the `old defaults `_ (:pr:`2075`). + +- |Defaults| Changed how :func:`scatterplot` sets the default linewidth for the edges of the scatter points. New behavior is to scale with the point sizes themselves (on a plot-wise, not point-wise basis). This change also slightly reduces the default width when point sizes are not varied. Set ``linewidth=0.75`` to reproduce the previous behavior (:pr:`2708`). + +- |Enhancement| Improved support for datetime variables in :func:`scatterplot` and :func:`lineplot` (:pr:`2138`). + +- |Fix| Fixed a bug where :func:`lineplot` did not pass the ``linestyle`` parameter down to matplotlib (:pr:`2095`). + +- |Fix| Adapted to a change in matplotlib that prevented passing vectors of literal values to ``c`` and ``s`` in :func:`scatterplot` (:pr:`2079`). + +Categorical plots +^^^^^^^^^^^^^^^^^ + +- |Enhancement| |Defaults| |Fix| Fixed a few computational issues in :func:`boxenplot` and improved its visual appearance (:pr:`2086`; a brief sketch of the new ``k_depth`` options appears after this section's list): + + - Changed the default method for computing the number of boxes to ``k_depth="tukey"``, as the previous default (``k_depth="proportion"``) is based on a heuristic that produces too many boxes for small datasets. + - Added the option to specify the specific number of boxes (e.g. ``k_depth=6``) or to plot boxes that will cover most of the data points (``k_depth="full"``). + - Added a new parameter, ``trust_alpha``, to control the number of boxes when ``k_depth="trustworthy"``. + - Changed the visual appearance of :func:`boxenplot` to more closely resemble :func:`boxplot`. Notably, thin boxes will remain visible when the edges are white. + +- |Enhancement| Allowed :func:`catplot` to use different values on the categorical axis of each facet when axis sharing is turned off (e.g. by specifying ``sharex=False``) (:pr:`2196`). + +- |Enhancement| Improved the error messages produced when categorical plots process the orientation parameter. + +- |Enhancement| Added an explicit warning in :func:`swarmplot` when more than 5% of the points overlap in the "gutters" of the swarm (:pr:`2045`).
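As a minimal sketch of the new ``k_depth`` behavior in :func:`boxenplot` (the ``tips`` example dataset and the particular value used here are illustrative only)::

    import seaborn as sns

    tips = sns.load_dataset("tips")

    # k_depth now also accepts an exact box count, in addition to the
    # "tukey" (new default), "proportion", "trustworthy", and "full"
    # methods; the value 6 is an arbitrary illustrative choice.
    sns.boxenplot(data=tips, x="day", y="total_bill", k_depth=6)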
+ +Multi-plot grids +^^^^^^^^^^^^^^^^ + +- |Feature| |Enhancement| |Defaults| A few small changes to make life easier when using :class:`PairGrid` (:pr:`2234`): + + - Added public access to the legend object through the ``legend`` attribute (also affects :class:`FacetGrid`). + - The ``color`` and ``label`` parameters are no longer passed to the plotting functions when ``hue`` is not used. + - The data is no longer converted to a numpy object before plotting on the marginal axes. + - It is possible to specify only one of ``x_vars`` or ``y_vars``, using all variables for the unspecified dimension. + - The ``layout_pad`` parameter is stored and used every time you call the :meth:`PairGrid.tight_layout` method. + +- |Feature| Added a ``tight_layout`` method to :class:`FacetGrid` and :class:`PairGrid`, which runs the :func:`matplotlib.pyplot.tight_layout` algorithm without interference from the external legend (:pr:`2073`). + +- |Feature| Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes (:pr:`2046`). + +- |Enhancement| Made :meth:`FacetGrid.set_axis_labels` clear labels from "interior" axes (:pr:`2046`). + +- |Feature| Added the ``marginal_ticks`` parameter to :class:`JointGrid` which, if set to ``True``, will show ticks on the count/density axis of the marginal plots (:pr:`2210`). + +- |Enhancement| Improved :meth:`FacetGrid.set_titles` with ``margin_titles=True``, such that texts representing the original row titles are removed before adding new ones (:pr:`2083`). + +- |Defaults| Changed the default value for ``dropna`` to ``False`` in :class:`FacetGrid`, :class:`PairGrid`, :class:`JointGrid`, and corresponding functions. As all or nearly all seaborn and matplotlib plotting functions handle missing data well, this option is no longer useful, but it causes problems in some edge cases. It may be deprecated in the future (:pr:`2204`). + +- |Fix| Fixed a bug in :class:`PairGrid` that appeared when setting ``corner=True`` and ``despine=False`` (:pr:`2203`). + +Color palettes +^^^^^^^^^^^^^^ + +- |Docs| Improved and modernized the :doc:`color palettes chapter ` of the seaborn tutorial. + +- |Feature| Added two new perceptually-uniform colormaps: "flare" and "crest". The new colormaps are similar to "rocket" and "mako", but their luminance range is reduced. This makes them well suited to numeric mappings of line or scatter plots, which need contrast with the axes background at the extremes (:pr:`2237`). + +- |Enhancement| |Defaults| Enhanced numeric colormap functionality in several ways (:pr:`2237`): + + - Added string-based access within the :func:`color_palette` interface to :func:`dark_palette`, :func:`light_palette`, and :func:`blend_palette`. This means that anywhere you specify a palette in seaborn, a name like ``"dark:blue"`` will use :func:`dark_palette` with the input ``"blue"``. + - Added the ``as_cmap`` parameter to :func:`color_palette` and changed internal code that uses a continuous colormap to take this route. + - Tweaked the :func:`light_palette` and :func:`dark_palette` functions to use an endpoint that is a very desaturated version of the input color, rather than a pure gray. This produces smoother ramps. To exactly reproduce previous plots, use :func:`blend_palette` with ``".13"`` for dark or ``".95"`` for light. + - Changed :func:`diverging_palette` to have a default value of ``sep=1``, which gives better results. + +- |Enhancement| Added a rich HTML representation to the object returned by :func:`color_palette` (:pr:`2225`).
+ +- |Fix| Fixed the ``"{palette}_d"`` logic to modify reversed colormaps and to use the correct direction of the luminance ramp in both cases. + +Deprecations and removals +^^^^^^^^^^^^^^^^^^^^^^^^^ + +- |Enhancement| Removed an optional (and undocumented) dependency on BeautifulSoup (:pr:`2190`) in :func:`get_dataset_names`. + +- |API| Deprecated the ``axlabel`` function; use ``ax.set(xlabel=, ylabel=)`` instead. + +- |API| Deprecated the ``iqr`` function; use :func:`scipy.stats.iqr` instead. + +- |API| Final removal of the previously-deprecated ``annotate`` method on :class:`JointGrid`, along with related parameters. + +- |API| Final removal of the ``lvplot`` function (the previously-deprecated name for :func:`boxenplot`). diff --git a/doc/releases/v0.11.1.txt b/doc/releases/v0.11.1.txt new file mode 100644 index 0000000000..20fe08108c --- /dev/null +++ b/doc/releases/v0.11.1.txt @@ -0,0 +1,37 @@ + +v0.11.1 (December 2020) +----------------------- + +This is a bug fix release and is a recommended upgrade for all users on v0.11.0. + +- |Enhancement| Reduced the use of matplotlib global state in the :ref:`multi-grid classes ` (:pr:`2388`). + +- |Fix| Restored support for using tuples or numeric keys to reference fields in a long-form `data` object (:pr:`2386`). + +- |Fix| Fixed a bug in :func:`lineplot` where NAs were propagating into the confidence interval, sometimes erasing it from the plot (:pr:`2273`). + +- |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` where diagonal axes would be empty when the grid was not square and the diagonal axes did not contain the marginal plots (:pr:`2270`). + +- |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` where off-diagonal plots would not appear when column names in `data` had non-string type (:pr:`2368`). + +- |Fix| Fixed a bug where categorical dtype information was ignored when data consisted of boolean or boolean-like values (:pr:`2379`). + +- |Fix| Fixed a bug in :class:`FacetGrid` where interior tick labels would be hidden when only the orthogonal axis was shared (:pr:`2347`). + +- |Fix| Fixed a bug in :class:`FacetGrid` that caused an error when `legend_out=False` was set (:pr:`2304`). + +- |Fix| Fixed a bug in :func:`kdeplot` where ``common_norm=True`` was ignored if ``hue`` was not assigned (:pr:`2378`). + +- |Fix| Fixed a bug in :func:`displot` where the ``row_order`` and ``col_order`` parameters were not used (:pr:`2262`). + +- |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` that caused an exception when using `corner=True` and `diag_kind=None` (:pr:`2382`). + +- |Fix| Fixed a bug in :func:`clustermap` where `annot=False` was ignored (:pr:`2323`). + +- |Fix| Fixed a bug in :func:`clustermap` where row/col color annotations could not have a categorical dtype (:pr:`2389`). + +- |Fix| Fixed a bug in :func:`boxenplot` where the `linewidth` parameter was ignored (:pr:`2287`). + +- |Fix| Raise a more informative error in :class:`PairGrid`/:func:`pairplot` when no variables can be found to define the rows/columns of the grid (:pr:`2382`). + +- |Fix| Raise a more informative error from :func:`clustermap` if row/col color objects have semantic index but data object does not (:pr:`2313`). \ No newline at end of file diff --git a/doc/releases/v0.12.0.txt b/doc/releases/v0.12.0.txt new file mode 100644 index 0000000000..fba59d1750 --- /dev/null +++ b/doc/releases/v0.12.0.txt @@ -0,0 +1,45 @@ + +v0.12.0 (Unreleased) +-------------------- + +A paper describing seaborn was published in the `Journal of Open Source Software `_.
The paper serves as an introduction to the library and can be used to cite seaborn if it has been integral to a scientific publication. + +- |API| |Feature| |Enhancement| TODO (Flesh this out further). Increased flexibility of what can be shown by the internally-calculated errorbars (:pr:`2407`). + +- |Fix| |Enhancement| Improved robustness to missing data, including additional support for the `pd.NA` type (:pr:`2417`). + +- TODO function-specific categorical enhancements, including: + + - In :func:`stripplot`, a "strip" with a single observation will be plotted without jitter (:pr:`2413`) + + - In :func:`swarmplot`, the points are now swarmed at draw time, meaning that the plot will adapt to further changes in axes scaling or tweaks to the plot layout (:pr:`2443`). + + - In :func:`swarmplot`, the order of the points in each swarm now matches the order in the original dataset; previously they were sorted. This affects only the underlying data stored in the matplotlib artist, not the visual representation (:pr:`2443`). + + - In :func:`swarmplot`, the proportion of points that must overlap before issuing a warning can now be controlled with the `warn_thresh` parameter (:pr:`2447`). + +- |Enhancement| In :func:`histplot`, added `stat="percent"` as an option for normalization such that bar heights sum to 100 (:pr:`2461`). + +- |Enhancement| |Fix| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`). + +- |Fix| In :func:`lineplot`, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`). + +- |Fix| In :func:`rugplot`, fixed a bug that prevented the use of datetime data (:pr:`2458`). + +- |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `alpha` parameter was ignored when `fill=False` (:pr:`2460`). + +- |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `multiple` parameter was ignored when `hue` was provided as a vector without a name (:pr:`2462`). + +- |Fix| In :func:`histplot`, fixed a bug where using `shrink` with non-discrete bins shifted bar positions inaccurately (:pr:`2477`). + +- |Fix| In :func:`histplot`, fixed two bugs where automatically computed edge widths were too thick for log-scaled histograms and categorical histograms on the y axis (:pr:`2522`). + +- |Fix| In :func:`displot`, fixed a bug where `common_norm` was ignored when `kind="hist"` and faceting was used without assigning `hue` (:pr:`2468`). + +- |Defaults| In :func:`displot`, the default alpha value now adjusts to a provided `multiple` parameter even when `hue` is not assigned (:pr:`2462`). + +- Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`). + +- Following `NEP29 `_, dropped support for Python 3.6 and bumped the minimally-supported versions of the library dependencies. + +- Removed several previously-deprecated utility functions (`iqr`, `percentiles`, `pmf_hist`, and `sort_df`). diff --git a/doc/releases/v0.4.0.txt b/doc/releases/v0.4.0.txt index 129c7b4b65..39299c14fc 100644 --- a/doc/releases/v0.4.0.txt +++ b/doc/releases/v0.4.0.txt @@ -26,7 +26,7 @@ Style and color palettes - Added the :func:`cubehelix_palette` function for generating sequential palettes from the cubehelix system.
See the :ref:`palette docs ` for more information on how these palettes can be used. There is also the :func:`choose_cubehelix` which will launch an interactive app to select cubehelix parameters in the notebook. -- Added the :func:`xkcd_palette` and the ``xkcd_rgb`` dictionary so that colors :ref:`can be specified ` with names from the `xkcd color survey `_. +- Added the :func:`xkcd_palette` and the ``xkcd_rgb`` dictionary so that colors can be specified with names from the `xkcd color survey `_. - Added the ``font_scale`` option to :func:`plotting_context`, :func:`set_context`, and :func:`set`. ``font_scale`` can independently increase or decrease the size of the font elements in the plot. diff --git a/doc/releases/v0.6.0.txt b/doc/releases/v0.6.0.txt index b4002d53f7..0f2bb122c6 100644 --- a/doc/releases/v0.6.0.txt +++ b/doc/releases/v0.6.0.txt @@ -2,9 +2,6 @@ v0.6.0 (June 2015) ------------------ -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.19108.svg - :target: https://doi.org/10.5281/zenodo.19108 - This is a major release from 0.5. The main objective of this release was to unify the API for categorical plots, which means that there are some relatively large API changes in some of the older functions. See below for details of those changes, which may break code written for older versions of seaborn. There are also some new functions (:func:`stripplot`, and :func:`countplot`), numerous enhancements to existing functions, and bug fixes. Additionally, the documentation has been completely revamped and expanded for the 0.6 release. Now, the API docs page for each function has multiple examples with embedded plots showing how to use the various options. These pages should be considered the most comprehensive resource for examples, and the tutorial pages are now streamlined and oriented towards a higher-level overview of the various features. diff --git a/doc/releases/v0.7.0.txt b/doc/releases/v0.7.0.txt index 0a6abbe08c..d809c10320 100644 --- a/doc/releases/v0.7.0.txt +++ b/doc/releases/v0.7.0.txt @@ -2,9 +2,6 @@ v0.7.0 (January 2016) --------------------- -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.45133.svg - :target: https://doi.org/10.5281/zenodo.45133 - This is a major release from 0.6. The main new feature is :func:`swarmplot` which implements the beeswarm approach for drawing categorical scatterplots. There are also some performance improvements, bug fixes, and updates for compatibility with new versions of dependencies. - Added the :func:`swarmplot` function, which draws beeswarm plots. These are categorical scatterplots, similar to those produced by :func:`stripplot`, but position of the points on the categorical axis is chosen to avoid overlapping points. See the :ref:`categorical plot tutorial ` for more information. diff --git a/doc/releases/v0.7.1.txt b/doc/releases/v0.7.1.txt index a802771b27..809358e05b 100644 --- a/doc/releases/v0.7.1.txt +++ b/doc/releases/v0.7.1.txt @@ -2,9 +2,6 @@ v0.7.1 (June 2016) ------------------- -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.54844.svg - :target: https://doi.org/10.5281/zenodo.54844 - - Added the ability to put "caps" on the error bars that are drawn by :func:`barplot` or :func:`pointplot` (and, by extension, ``factorplot``). Additionally, the line width of the error bars can now be controlled. These changes involve the new parameters ``capsize`` and ``errwidth``. See the `github pull request (#898) `_ for examples of usage. - Improved the row and column colors display in :func:`clustermap`. 
It is now possible to pass Pandas objects for these elements and, when possible, the semantic information in the Pandas objects will be used to add labels to the plot. When Pandas objects are used, the color data is matched against the main heatmap based on the index, not on position. This is more accurate, but it may lead to different results if current code assumed positional matching. diff --git a/doc/releases/v0.8.0.txt b/doc/releases/v0.8.0.txt index 58c9ec6d07..073d81a862 --- a/doc/releases/v0.8.0.txt +++ b/doc/releases/v0.8.0.txt @@ -2,9 +2,6 @@ v0.8.0 (July 2017) ------------------ -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.824567.svg - :target: https://doi.org/10.5281/zenodo.824567 - - The default style is no longer applied when seaborn is imported. It is now necessary to explicitly call :func:`set` or one or more of :func:`set_style`, :func:`set_context`, and :func:`set_palette`. Correspondingly, the ``seaborn.apionly`` module has been deprecated. - Changed the behavior of :func:`heatmap` (and by extension :func:`clustermap`) when plotting divergent datasets (i.e. when the ``center`` parameter is used). Instead of extending the lower and upper limits of the colormap to be symmetrical around the ``center`` value, the colormap is modified so that its middle color corresponds to ``center``. This means that the full range of the colormap will not be used (unless the data or specified ``vmin`` and ``vmax`` are symmetric), but the upper and lower limits of the colorbar will correspond to the range of the data. See the Github pull request `(#1184) `_ for examples of the behavior. @@ -39,6 +36,6 @@ v0.8.0 (July 2017) - Some modules and functions have been internally reorganized; there should be no effect on code that uses the ``seaborn`` namespace. -- Added a deprecation warning to :func:`tsplot` function to indicate that it will be removed or replaced with a substantially altered version in a future release. +- Added a deprecation warning to the ``tsplot`` function to indicate that it will be removed or replaced with a substantially altered version in a future release. - The ``interactplot`` and ``coefplot`` functions are officially deprecated and will be removed in a future release. diff --git a/doc/releases/v0.8.1.txt b/doc/releases/v0.8.1.txt index ade4c857bf..5c8a1b75ce 100644 --- a/doc/releases/v0.8.1.txt +++ b/doc/releases/v0.8.1.txt @@ -2,9 +2,6 @@ v0.8.1 (September 2017) ----------------------- -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.883859.svg - :target: https://doi.org/10.5281/zenodo.883859 - - Added a warning in :class:`FacetGrid` when passing a categorical plot function without specifying ``order`` (or ``hue_order`` when ``hue`` is used), which is likely to produce a plot that is incorrect. - Improved compatibility between :class:`FacetGrid` or :class:`PairGrid` and interactive matplotlib backends so that the legend no longer remains inside the figure when using ``legend_out=True``. diff --git a/doc/releases/v0.9.0.txt b/doc/releases/v0.9.0.txt index 8d7f696998..83084859b9 --- a/doc/releases/v0.9.0.txt +++ b/doc/releases/v0.9.0.txt @@ -2,15 +2,12 @@ v0.9.0 (July 2018) ------------------ -.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.1313201.svg - :target: https://doi.org/10.5281/zenodo.1313201 - This is a major release with several substantial and long-desired new features. There are also updates/modifications to the themes and color palettes that give better consistency with matplotlib 2.0 and some notable API changes. 
New relational plots ~~~~~~~~~~~~~~~~~~~~ -Three completely new plotting functions have been added: :func:`catplot`, :func:`scatterplot`, and :func:`lineplot`. The first is a figure-level interface to the latter two that combines them with a :class:`FacetGrid`. The functions bring the high-level, dataset-oriented API of the seaborn categorical plotting functions to more general plots (scatter plots and line plots). +Three completely new plotting functions have been added: :func:`relplot`, :func:`scatterplot`, and :func:`lineplot`. The first is a figure-level interface to the latter two that combines them with a :class:`FacetGrid`. The functions bring the high-level, dataset-oriented API of the seaborn categorical plotting functions to more general plots (scatter plots and line plots). These functions can visualize a relationship between two numeric variables while mapping up to three additional variables by modifying ``hue``, ``size``, and/or ``style`` semantics. The common high-level API is implemented differently in the two functions. For example, the size semantic in :func:`scatterplot` scales the area of scatter plot points, but in :func:`lineplot` it scales the width of the line plot lines. The API is dataset-oriented, meaning that in both cases you pass the variable in your dataset rather than directly specifying the matplotlib parameters to use for point area or line width. @@ -52,7 +49,7 @@ A few functions have been renamed or have had changes to their default parameter - Renamed the ``size`` parameter to ``height`` in multi-plot grid objects (:class:`FacetGrid`, :class:`PairGrid`, and :class:`JointGrid`) along with functions that use them (``factorplot``, :func:`lmplot`, :func:`pairplot`, and :func:`jointplot`) to avoid conflicts with the ``size`` parameter that is used in ``scatterplot`` and ``lineplot`` (necessary to make :func:`relplot` work) and also makes the meaning of the parameter a bit more clear. -- Changed the default diagonal plots in :func:`pairplot` to use `func`:kdeplot` when a ``"hue"`` dimension is used. +- Changed the default diagonal plots in :func:`pairplot` to use :func:`kdeplot` when a ``"hue"`` dimension is used. - Deprecated the statistical annotation component of :class:`JointGrid`. The method is still available but will be removed in a future version. diff --git a/doc/releases/v0.9.1.txt b/doc/releases/v0.9.1.txt index 012b9c669f..24efff3de8 --- a/doc/releases/v0.9.1.txt +++ b/doc/releases/v0.9.1.txt @@ -1,29 +1,81 @@ -v0.9.1 (Unreleased) ------------------- +v0.9.1 (January 2020) +--------------------- -This is a minor release, comprising mostly bug fixes and compatibility with dependency updates. +This is a minor release with a number of bug fixes and adaptations to changes in seaborn's dependencies. There are also several new features. -- Added the ``corner`` option to :class:`PairGrid` and :func:`pairplot` to make a grid with only the lower triangle of bivariate axes. +This is the final version of seaborn that will support Python 2.7 or 3.5. + +New features +~~~~~~~~~~~~ + +- Added more control over the arrangement of the elements drawn by :func:`clustermap` with the ``{dendrogram,colors}_ratio`` and ``cbar_pos`` parameters. Additionally, the default organization and scaling with different figure sizes has been improved. + +- Added the ``corner`` option to :class:`PairGrid` and :func:`pairplot` to make a grid without the upper triangle of bivariate axes.
+ +- Added the ability to seed the random number generator for the bootstrap used to define error bars in several plots. Relevant functions now have a ``seed`` parameter, which can take either a fixed seed (typically an ``int``) or a numpy random number generator object (either the newer :class:`numpy.random.Generator` or the older :class:`numpy.random.mtrand.RandomState`). - Generalized the idea of "diagonal" axes in :class:`PairGrid` to any axes that share an x and y variable. - In :class:`PairGrid`, the ``hue`` variable is now excluded from the default list of variables that make up the rows and columns of the grid. -- Fixed the behavior of ``dropna`` in :class:`PairGrid` to properly exclude null datapoints from each plot when set to ``True``. - - Exposed the ``layout_pad`` parameter in :class:`PairGrid` and set a smaller default than what matplotlib sets for more efficient use of space in dense grids. +- It is now possible to force a categorical interpretation of the ``hue`` variable in a relational plot by passing the name of a categorical palette (e.g. ``"deep"`` or ``"Set2"``). This complements the (previously supported) option of passing a list/dict of colors. + +- Added the ``tree_kws`` parameter to :func:`clustermap` to control the properties of the lines in the dendrogram. + +- Added the ability to pass hierarchical label names to the :class:`FacetGrid` legend, which also fixes a bug in :func:`relplot` when the same label appeared in different semantics. + +- Improved support for grouping observations based on pandas index information in categorical plots. + +Bug fixes and adaptations +~~~~~~~~~~~~~~~~~~~~~~~~~ + - Avoided an error when singular data is passed to :func:`kdeplot`, issuing a warning instead. This makes :func:`pairplot` more robust. +- Fixed the behavior of ``dropna`` in :class:`PairGrid` to properly exclude null datapoints from each plot when set to ``True``. + - Fixed an issue where :func:`regplot` could interfere with other axes in a multi-plot matplotlib figure. - Semantic variables with a ``category`` data type will always be treated as categorical in relational plots. - Avoided a warning about color specifications that arose from :func:`boxenplot` on newer matplotlibs. +- Adapted to a change in how matplotlib scales axis margins, which caused multiple calls to :func:`regplot` with ``truncate=False`` to progressively expand the x axis limits. Because there are currently limitations on how autoscaling works in matplotlib, the default value for ``truncate`` in seaborn has also been changed to ``True``. + +- Relational plots no longer error when hue/size data are inferred to be numeric but stored with a string datatype. + +- Relational plots now consider semantics with only a single value that can be interpreted as boolean (0 or 1) to be categorical, not numeric. + +- Relational plots now handle list or dict specifications for ``sizes`` correctly. + +- Fixed an issue in :func:`pointplot` where missing levels of a hue variable would cause an exception after a recent update in matplotlib. + - Fixed a bug when setting the rotation of x tick labels on a :class:`FacetGrid`. +- Fixed a bug where values would be excluded from categorical plots when only one variable was a pandas ``Series`` with a non-default index. + +- Fixed a bug when using ``Series`` objects as arguments for ``x_partial`` or ``y_partial`` in :func:`regplot`. + - Fixed a bug when passing a ``norm`` object and using color annotations in :func:`clustermap`. 
+- Fixed a bug where annotations were not rearranged to match the clustering in :func:`clustermap`. + - Fixed a bug when trying to call :func:`set` while specifying a list of colors for the palette. + +- Fixed a bug when resetting the color code short-hands to the matplotlib default. + +- Avoided errors from stricter type checking in upcoming ``numpy`` changes. + +- Avoided error/warning in :func:`lineplot` when plotting categoricals with empty levels. + +- Allowed ``colors`` to be passed through to a bivariate :func:`kdeplot`. + +- Standardized the output format of custom color palette functions. + +- Fixed a bug where legends for numerical variables in a relational plot could show a surprisingly large number of decimal places. + +- Improved robustness to missing values in distribution plots. + +- Made it possible to specify the location of the :class:`FacetGrid` legend using matplotlib keyword arguments. diff --git a/doc/requirements.txt b/doc/requirements.txt new file mode 100644 index 0000000000..c6157fa9c4 --- /dev/null +++ b/doc/requirements.txt @@ -0,0 +1,6 @@ +sphinx==3.3.1 +sphinx_bootstrap_theme==0.7.1 +numpydoc +nbconvert +ipykernel +sphinx-issues diff --git a/doc/sphinxext/gallery_generator.py b/doc/sphinxext/gallery_generator.py index b3314c303f..ff546a8419 100644 --- a/doc/sphinxext/gallery_generator.py +++ b/doc/sphinxext/gallery_generator.py @@ -4,7 +4,6 @@ Lightly modified from the mpld3 project. """ -from __future__ import division import os import os.path as op import re @@ -12,31 +11,30 @@ import token import tokenize import shutil - -from seaborn.external import six +import warnings import matplotlib matplotlib.use('Agg') -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # noqa: E402 -from matplotlib import image -if six.PY3: - # Python 3 has no execfile - def execfile(filename, globals=None, locals=None): - with open(filename, "rb") as fp: - six.exec_(compile(fp.read(), filename, 'exec'), globals, locals) +# Python 3 has no execfile +def execfile(filename, globals=None, locals=None): + with open(filename, "rb") as fp: + exec(compile(fp.read(), filename, 'exec'), globals, locals) RST_TEMPLATE = """ +.. currentmodule:: seaborn + .. _{sphinx_tag}: {docstring} .. image:: {img_file} -**Python source code:** :download:`[download source: {fname}]<{fname}>` +**seaborn components used:** {components} .. raw:: html @@ -138,7 +136,7 @@ def create_thumbnail(infile, thumbfile, cx=0.5, cy=0.5, border=4): baseout, extout = op.splitext(thumbfile) - im = image.imread(infile) + im = matplotlib.image.imread(infile) rows, cols = im.shape[:2] x0 = int(cx * cols - .5 * width) y0 = int(cy * rows - .5 * height) @@ -153,8 +151,13 @@ def create_thumbnail(infile, thumbfile, ax = fig.add_axes([0, 0, 1, 1], aspect='auto', frameon=False, xticks=[], yticks=[]) - ax.imshow(thumb, aspect='auto', resample=True, - interpolation='bilinear') + if all(thumb.shape): + ax.imshow(thumb, aspect='auto', resample=True, + interpolation='bilinear') + else: + warnings.warn( + f"Bad thumbnail crop. {thumbfile} will be empty." 
+ ) fig.savefig(thumbfile, dpi=dpi) return fig @@ -178,12 +181,10 @@ def __init__(self, filename, target_dir): # Only actually run it if the output RST file doesn't # exist or it was modified less recently than the example - if (not op.exists(outfilename) - or (op.getmtime(outfilename) < op.getmtime(filename))): - + file_mtime = op.getmtime(filename) + if not op.exists(outfilename) or op.getmtime(outfilename) < file_mtime: self.exec_file() else: - print("skipping {0}".format(self.filename)) @property @@ -241,6 +242,19 @@ def plotfunc(self): return match.group(1) return "" + @property + def components(self): + + objects = re.findall(r"sns\.(\w+)\(", self.filetext) + + refs = [] + for obj in objects: + if obj[0].isupper(): + refs.append(f":class:`{obj}`") + else: + refs.append(f":func:`{obj}`") + return ", ".join(refs) + def extract_docstring(self): """ Extract a module-level docstring """ @@ -326,8 +340,7 @@ def main(app): target_dir = op.join(app.builder.srcdir, 'examples') image_dir = op.join(app.builder.srcdir, 'examples/_images') thumb_dir = op.join(app.builder.srcdir, "example_thumbs") - source_dir = op.abspath(op.join(app.builder.srcdir, - '..', 'examples')) + source_dir = op.abspath(op.join(app.builder.srcdir, '..', 'examples')) if not op.exists(static_dir): os.makedirs(static_dir) @@ -351,7 +364,7 @@ def main(app): contents = "\n\n" # Write individual example files - for filename in glob.glob(op.join(source_dir, "*.py")): + for filename in sorted(glob.glob(op.join(source_dir, "*.py"))): ex = ExampleGenerator(filename, target_dir) @@ -362,6 +375,7 @@ def main(app): output = RST_TEMPLATE.format(sphinx_tag=ex.sphinxtag, docstring=ex.docstring, end_line=ex.end_line, + components=ex.components, fname=ex.pyfilename, img_file=ex.pngfilename) with open(op.join(target_dir, ex.rstfilename), 'w') as f: diff --git a/doc/tools/extract_examples.py b/doc/tools/extract_examples.py new file mode 100644 index 0000000000..36b0eff626 --- /dev/null +++ b/doc/tools/extract_examples.py @@ -0,0 +1,73 @@ +"""Turn the examples section of a function docstring into a notebook.""" +import re +import sys +import pydoc +import seaborn +from seaborn.external.docscrape import NumpyDocString +import nbformat + + +def line_type(line): + + if line.startswith(" "): + return "code" + else: + return "markdown" + + +def add_cell(nb, lines, cell_type): + + cell_objs = { + "code": nbformat.v4.new_code_cell, + "markdown": nbformat.v4.new_markdown_cell, + } + text = "\n".join(lines) + cell = cell_objs[cell_type](text) + nb["cells"].append(cell) + + +if __name__ == "__main__": + + _, name = sys.argv + + # Parse the docstring and get the examples section + obj = getattr(seaborn, name) + if obj.__class__.__name__ != "function": + obj = obj.__init__ + lines = NumpyDocString(pydoc.getdoc(obj))["Examples"] + + # Remove code indentation, the prompt, and mpl return variable + pat = re.compile(r"\s{4}[>\.]{3} (ax = ){0,1}(g = ){0,1}") + + nb = nbformat.v4.new_notebook() + + # We always start with at least one line of text + cell_type = "markdown" + cell = [] + + for line in lines: + + # Ignore matplotlib plot directive + if ".. 
plot" in line or ":context:" in line: + continue + + # Ignore blank lines + if not line: + continue + + if line_type(line) != cell_type: + # We are on the first line of the next cell, + # so package up the last cell + add_cell(nb, cell, cell_type) + cell_type = line_type(line) + cell = [] + + if line_type(line) == "code": + line = re.sub(pat, "", line) + + cell.append(line) + + # Package the final cell + add_cell(nb, cell, cell_type) + + nbformat.write(nb, f"docstrings/{name}.ipynb") diff --git a/doc/tools/generate_logos.py b/doc/tools/generate_logos.py new file mode 100644 index 0000000000..3e1477a9bb --- /dev/null +++ b/doc/tools/generate_logos.py @@ -0,0 +1,224 @@ +import numpy as np +import seaborn as sns +from matplotlib import patches +import matplotlib.pyplot as plt +from scipy.signal import gaussian +from scipy.spatial import distance + + +XY_CACHE = {} + +STATIC_DIR = "_static" +plt.rcParams["savefig.dpi"] = 300 + + +def poisson_disc_sample(array_radius, pad_radius, candidates=100, d=2, seed=None): + """Find positions using poisson-disc sampling.""" + # See http://bost.ocks.org/mike/algorithms/ + rng = np.random.default_rng(seed) + uniform = rng.uniform + randint = rng.integers + + # Cache the results + key = array_radius, pad_radius, seed + if key in XY_CACHE: + return XY_CACHE[key] + + # Start at a fixed point we know will work + start = np.zeros(d) + samples = [start] + queue = [start] + + while queue: + + # Pick a sample to expand from + s_idx = randint(len(queue)) + s = queue[s_idx] + + for i in range(candidates): + # Generate a candidate from this sample + coords = uniform(s - 2 * pad_radius, s + 2 * pad_radius, d) + + # Check the three conditions to accept the candidate + in_array = np.sqrt(np.sum(coords ** 2)) < array_radius + in_ring = np.all(distance.cdist(samples, [coords]) > pad_radius) + + if in_array and in_ring: + # Accept the candidate + samples.append(coords) + queue.append(coords) + break + + if (i + 1) == candidates: + # We've exhausted the particular sample + queue.pop(s_idx) + + samples = np.array(samples) + XY_CACHE[key] = samples + return samples + + +def logo( + ax, + color_kws, ring, ring_idx, edge, + pdf_means, pdf_sigma, dy, y0, w, h, + hist_mean, hist_sigma, hist_y0, lw, skip, + scatter, pad, scale, +): + + # Square, invisible axes with specified limits to center the logo + ax.set(xlim=(35 + w, 95 - w), ylim=(-3, 53)) + ax.set_axis_off() + ax.set_aspect('equal') + + # Magic numbers for the logo circle + radius = 27 + center = 65, 25 + + # Full x and y grids for a gaussian curve + x = np.arange(101) + y = gaussian(x.size, pdf_sigma) + + x0 = 30 # Magic number + xx = x[x0:] + + # Vertical distances between the PDF curves + n = len(pdf_means) + dys = np.linspace(0, (n - 1) * dy, n) - (n * dy / 2) + dys -= dys.mean() + + # Compute the PDF curves with vertical offsets + pdfs = [h * (y[x0 - m:-m] + y0 + dy) for m, dy in zip(pdf_means, dys)] + + # Add in constants to fill from bottom and to top + pdfs.insert(0, np.full(xx.shape, -h)) + pdfs.append(np.full(xx.shape, 50 + h)) + + # Color gradient + colors = sns.cubehelix_palette(n + 1 + bool(hist_mean), **color_kws) + + # White fill between curves and around edges + bg = patches.Circle( + center, radius=radius - 1 + ring, color="white", + transform=ax.transData, zorder=0, + ) + ax.add_artist(bg) + + # Clipping artist (not shown) for the interior elements + fg = patches.Circle(center, radius=radius - edge, transform=ax.transData) + + # Ring artist to surround the circle (optional) + if ring: + wedge = 
patches.Wedge( + center, r=radius + edge / 2, theta1=0, theta2=360, width=edge / 2, + transform=ax.transData, color=colors[ring_idx], alpha=1 + ) + ax.add_artist(wedge) + + # Add histogram bars + if hist_mean: + hist_color = colors.pop(0) + hist_y = gaussian(x.size, hist_sigma) + hist = 1.1 * h * (hist_y[x0 - hist_mean:-hist_mean] + hist_y0) + dx = x[skip] - x[0] + hist_x = xx[::skip] + hist_h = h + hist[::skip] + # Magic number to avoid tiny sliver of bar on edge + use = hist_x < center[0] + radius * .5 + bars = ax.bar( + hist_x[use], hist_h[use], bottom=-h, width=dx, + align="edge", color=hist_color, ec="w", lw=lw, + zorder=3, + ) + for bar in bars: + bar.set_clip_path(fg) + + # Add each smooth PDF "wave" + for i, pdf in enumerate(pdfs[1:], 1): + u = ax.fill_between(xx, pdfs[i - 1] + w, pdf, color=colors[i - 1], lw=0) + u.set_clip_path(fg) + + # Add scatterplot in top wave area + if scatter: + seed = sum(map(ord, "seaborn logo")) + xy = poisson_disc_sample(radius - edge - ring, pad, seed=seed) + clearance = distance.cdist(xy + center, np.c_[xx, pdfs[-2]]) + use = clearance.min(axis=1) > pad / 1.8 + x, y = xy[use].T + sizes = (x - y) % 9 + + points = ax.scatter( + x + center[0], y + center[1], s=scale * (10 + sizes * 5), + zorder=5, color=colors[-1], ec="w", lw=scale / 2, + ) + path = u.get_paths()[0] + points.set_clip_path(path, transform=u.get_transform()) + u.set_visible(False) + + +def savefig(fig, shape, variant): + + fig.subplots_adjust(0, 0, 1, 1, 0, 0) + + facecolor = (1, 1, 1, 1) if bg == "white" else (1, 1, 1, 0) + + for ext in ["png", "svg"]: + fig.savefig(f"{STATIC_DIR}/logo-{shape}-{variant}bg.{ext}", facecolor=facecolor) + + +if __name__ == "__main__": + + for bg in ["white", "light", "dark"]: + + color_idx = -1 if bg == "dark" else 0 + + kwargs = dict( + color_kws=dict(start=.3, rot=-.4, light=.8, dark=.3, reverse=True), + ring=True, ring_idx=color_idx, edge=1, + pdf_means=[8, 24], pdf_sigma=16, + dy=1, y0=1.8, w=.5, h=12, + hist_mean=2, hist_sigma=10, hist_y0=.6, lw=1, skip=6, + scatter=True, pad=1.8, scale=.5, + ) + color = sns.cubehelix_palette(**kwargs["color_kws"])[color_idx] + + # ------------------------------------------------------------------------ # + + fig, ax = plt.subplots(figsize=(2, 2), facecolor="w", dpi=100) + logo(ax, **kwargs) + savefig(fig, "mark", bg) + + # ------------------------------------------------------------------------ # + + fig, axs = plt.subplots(1, 2, figsize=(8, 2), dpi=100, + gridspec_kw=dict(width_ratios=[1, 3])) + logo(axs[0], **kwargs) + + font = { + "family": "avenir", + "color": color, + "weight": "regular", + "size": 120, + } + axs[1].text(.01, .35, "seaborn", ha="left", va="center", + fontdict=font, transform=axs[1].transAxes) + axs[1].set_axis_off() + savefig(fig, "wide", bg) + + # ------------------------------------------------------------------------ # + + fig, axs = plt.subplots(2, 1, figsize=(2, 2.5), dpi=100, + gridspec_kw=dict(height_ratios=[4, 1])) + + logo(axs[0], **kwargs) + + font = { + "family": "avenir", + "color": color, + "weight": "regular", + "size": 34, + } + axs[1].text(.5, 1, "seaborn", ha="center", va="top", + fontdict=font, transform=axs[1].transAxes) + axs[1].set_axis_off() + savefig(fig, "tall", bg) diff --git a/doc/tools/nb_to_doc.py b/doc/tools/nb_to_doc.py index bba1e6d13a..61ee7a0214 100755 --- a/doc/tools/nb_to_doc.py +++ b/doc/tools/nb_to_doc.py @@ -1,33 +1,177 @@ #! /usr/bin/env python -""" -Convert empty IPython notebook to a sphinx doc page. 
+"""Execute a .ipynb file, write out a processed .rst and clean .ipynb. + +Some functions in this script were copied from the nbstripout tool: + +Copyright (c) 2015 Min RK, Florian Rathgeber, Michael McNeil Forbes +2019 Casper da Costa-Luis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import os import sys -from subprocess import check_call as sh +import nbformat +from nbconvert import RSTExporter +from nbconvert.preprocessors import ( + ExecutePreprocessor, + TagRemovePreprocessor, + ExtractOutputPreprocessor +) +from traitlets.config import Config + + +class MetadataError(Exception): + pass + + +def pop_recursive(d, key, default=None): + """dict.pop(key) where `key` is a `.`-delimited list of nested keys. + >>> d = {'a': {'b': 1, 'c': 2}} + >>> pop_recursive(d, 'a.c') + 2 + >>> d + {'a': {'b': 1}} + """ + nested = key.split('.') + current = d + for k in nested[:-1]: + if hasattr(current, 'get'): + current = current.get(k, {}) + else: + return default + if not hasattr(current, 'pop'): + return default + return current.pop(nested[-1], default) + + +def strip_output(nb): + """ + Strip the outputs, execution count/prompt number and miscellaneous + metadata from a notebook object, unless specified to keep either the + outputs or counts. 
+ """ + keys = {'metadata': [], 'cell': {'metadata': []}} + nb.metadata.pop('signature', None) + nb.metadata.pop('widgets', None) -def convert_nb(nbname): + for field in keys['metadata']: + pop_recursive(nb.metadata, field) - # Execute the notebook - sh(["jupyter", "nbconvert", "--to", "notebook", - "--execute", "--inplace", nbname]) + for cell in nb.cells: - # Convert to .rst for Sphinx - sh(["jupyter", "nbconvert", "--to", "rst", nbname, - "--TagRemovePreprocessor.remove_cell_tags={'hide'}", - "--TagRemovePreprocessor.remove_input_tags={'hide-input'}", - "--TagRemovePreprocessor.remove_all_outputs_tags={'hide-output'}"]) + # Remove the outputs, unless directed otherwise + if 'outputs' in cell: - # Clear notebook output - sh(["jupyter", "nbconvert", "--to", "notebook", "--inplace", - "--ClearOutputPreprocessor.enabled=True", nbname]) + cell['outputs'] = [] - # Touch the .rst file so it has a later modify time than the source - sh(["touch", nbname + ".rst"]) + # Remove the prompt_number/execution_count, unless directed otherwise + if 'prompt_number' in cell: + cell['prompt_number'] = None + if 'execution_count' in cell: + cell['execution_count'] = None + + # Always remove this metadata + for output_style in ['collapsed', 'scrolled']: + if output_style in cell.metadata: + cell.metadata[output_style] = False + if 'metadata' in cell: + for field in ['collapsed', 'scrolled', 'ExecuteTime']: + cell.metadata.pop(field, None) + for (extra, fields) in keys['cell'].items(): + if extra in cell: + for field in fields: + pop_recursive(getattr(cell, extra), field) + return nb if __name__ == "__main__": - for nbname in sys.argv[1:]: - convert_nb(nbname) + # Get the desired ipynb file path and parse into components + _, fpath = sys.argv + basedir, fname = os.path.split(fpath) + fstem = fname[:-6] + + # Read the notebook + print(f"Executing {fpath} ...", end=" ", flush=True) + with open(fpath) as f: + nb = nbformat.read(f, as_version=4) + + # Run the notebook + kernel = os.environ.get("NB_KERNEL", None) + if kernel is None: + kernel = nb["metadata"]["kernelspec"]["name"] + ep = ExecutePreprocessor( + timeout=600, + kernel_name=kernel, + extra_arguments=["--InlineBackend.rc={'figure.dpi': 88}"] + ) + ep.preprocess(nb, {"metadata": {"path": basedir}}) + + # Remove plain text execution result outputs + for cell in nb.get("cells", {}): + fields = cell.get("outputs", []) + for field in fields: + if field["output_type"] == "execute_result": + data_keys = field["data"].keys() + for key in list(data_keys): + if key == "text/plain": + field["data"].pop(key) + if not field["data"]: + fields.remove(field) + + # Convert to .rst formats + exp = RSTExporter() + + c = Config() + c.TagRemovePreprocessor.remove_cell_tags = {"hide"} + c.TagRemovePreprocessor.remove_input_tags = {"hide-input"} + c.TagRemovePreprocessor.remove_all_outputs_tags = {"hide-output"} + c.ExtractOutputPreprocessor.output_filename_template = \ + f"{fstem}_files/{fstem}_" + "{cell_index}_{index}{extension}" + + exp.register_preprocessor(TagRemovePreprocessor(config=c), True) + exp.register_preprocessor(ExtractOutputPreprocessor(config=c), True) + + body, resources = exp.from_notebook_node(nb) + + # Clean the output on the notebook and save a .ipynb back to disk + print(f"Writing clean {fpath} ... 
", end=" ", flush=True) + nb = strip_output(nb) + with open(fpath, "wt") as f: + nbformat.write(nb, f) + + # Write the .rst file + rst_path = os.path.join(basedir, f"{fstem}.rst") + print(f"Writing {rst_path}") + with open(rst_path, "w") as f: + f.write(body) + + # Write the individual image outputs + imdir = os.path.join(basedir, f"{fstem}_files") + if not os.path.exists(imdir): + os.mkdir(imdir) + + for imname, imdata in resources["outputs"].items(): + if imname.startswith(fstem): + impath = os.path.join(basedir, f"{imname}") + with open(impath, "wb") as f: + f.write(imdata) diff --git a/doc/tutorial.rst b/doc/tutorial.rst index c207a21ccc..a56eb28a08 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1,29 +1,156 @@ .. _tutorial: -Official seaborn tutorial -========================= +User guide and tutorial +=============================== .. raw:: html
   [The remaining doc/tutorial.rst hunks consist of raw HTML card-grid markup
   that was stripped during extraction. Recoverable structure: the rewritten
   page wraps each tutorial toctree in a titled panel -- "API overview"
   (tutorial/function_overview, tutorial/data_structure, tutorial/error_bars),
   "Plotting functions" (tutorial/relational, tutorial/distributions,
   tutorial/categorical, tutorial/regression, with tutorial/categorical moved
   out of the relational toctree into its own), "Multi-plot grids" (existing
   toctree kept), and "Plot aesthetics" (tutorial/aesthetics, with
   tutorial/color_palettes split into its own toctree).]
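To make the build flow concrete before the Makefile diff below: each tutorial page is produced by running `../tools/nb_to_doc.py` (shown earlier in this diff) on the corresponding notebook. A minimal sketch of that invocation, assuming doc/tutorial as the working directory; the notebook name is just an illustrative choice, not mandated by the PR:

```python
# Sketch of what the doc/tutorial Makefile rule "%.rst: %.ipynb" effectively
# runs for one notebook; "aesthetics.ipynb" is only an example input.
import subprocess

# nb_to_doc.py executes the notebook, writes aesthetics.rst plus extracted
# image files, and then strips outputs from the .ipynb in place.
subprocess.run(
    ["python", "../tools/nb_to_doc.py", "aesthetics.ipynb"],
    check=True,
)
```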
diff --git a/doc/tutorial/Makefile b/doc/tutorial/Makefile index a77fda5054..fda6b816ee 100644 --- a/doc/tutorial/Makefile +++ b/doc/tutorial/Makefile @@ -3,4 +3,9 @@ rst_files := $(patsubst %.ipynb,%.rst,$(wildcard *.ipynb)) tutorial: ${rst_files} %.rst: %.ipynb - ../tools/nb_to_doc.py $* + ../tools/nb_to_doc.py $*.ipynb + +clean: + rm -rf *.rst + rm -rf *_files/ + rm -rf .ipynb_checkpoints/ diff --git a/doc/tutorial/aesthetics.ipynb b/doc/tutorial/aesthetics.ipynb index ed84e77d08..326979e4e8 100644 --- a/doc/tutorial/aesthetics.ipynb +++ b/doc/tutorial/aesthetics.ipynb @@ -94,7 +94,7 @@ "cell_type": "raw", "metadata": {}, "source": [ - "To switch to seaborn defaults, simply call the :func:`set` function." + "To switch to seaborn defaults, simply call the :func:`set_theme` function." ] }, { @@ -103,7 +103,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.set()\n", + "sns.set_theme()\n", "sinplot()" ] }, @@ -111,7 +111,7 @@ "cell_type": "raw", "metadata": {}, "source": [ - "(Note that in versions of seaborn prior to 0.8, :func:`set` was called on import. On later versions, it must be explicitly invoked).\n", + "(Note that in versions of seaborn prior to 0.8, :func:`set_theme` was called on import. On later versions, it must be explicitly invoked).\n", "\n", "Seaborn splits matplotlib parameters into two independent groups. The first group sets the aesthetic style of the plot, and the second scales various elements of the figure so that it can be easily incorporated into different contexts.\n", "\n", @@ -254,12 +254,26 @@ "metadata": {}, "outputs": [], "source": [ - "f = plt.figure()\n", + "f = plt.figure(figsize=(6, 6))\n", + "gs = f.add_gridspec(2, 2)\n", + "\n", "with sns.axes_style(\"darkgrid\"):\n", - " ax = f.add_subplot(1, 2, 1)\n", + " ax = f.add_subplot(gs[0, 0])\n", + " sinplot()\n", + " \n", + "with sns.axes_style(\"white\"):\n", + " ax = f.add_subplot(gs[0, 1])\n", + " sinplot()\n", + "\n", + "with sns.axes_style(\"ticks\"):\n", + " ax = f.add_subplot(gs[1, 0])\n", + " sinplot()\n", + "\n", + "with sns.axes_style(\"whitegrid\"):\n", + " ax = f.add_subplot(gs[1, 1])\n", " sinplot()\n", - "ax = f.add_subplot(1, 2, 2)\n", - "sinplot(-1)" + " \n", + "f.tight_layout()" ] }, { @@ -269,7 +283,7 @@ "Overriding elements of the seaborn styles\n", "-----------------------------------------\n", "\n", - "If you want to customize the seaborn styles, you can pass a dictionary of parameters to the ``rc`` argument of :func:`axes_style` and :func:`set_style`. Note that you can only override the parameters that are part of the style definition through this method. (However, the higher-level :func:`set` function takes a dictionary of any matplotlib parameters).\n", + "If you want to customize the seaborn styles, you can pass a dictionary of parameters to the ``rc`` argument of :func:`axes_style` and :func:`set_style`. Note that you can only override the parameters that are part of the style definition through this method. 
(However, the higher-level :func:`set_theme` function takes a dictionary of any matplotlib parameters).\n", "\n", "If you want to see what parameters are included, you can just call the function with no arguments, which will return the current settings:" ] @@ -311,7 +325,7 @@ "\n", "A separate set of parameters control the scale of plot elements, which should let you use the same code to make plots that are suited for use in settings where larger or smaller plots are appropriate.\n", "\n", - "First let's reset the default parameters by calling :func:`set`:" + "First let's reset the default parameters by calling :func:`set_theme`:" ] }, { @@ -320,7 +334,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.set()" + "sns.set_theme()" ] }, { @@ -403,9 +417,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -417,9 +431,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/axis_grids.ipynb b/doc/tutorial/axis_grids.ipynb index cb2aa6b556..7c5707348b 100644 --- a/doc/tutorial/axis_grids.ipynb +++ b/doc/tutorial/axis_grids.ipynb @@ -25,17 +25,19 @@ "cell_type": "raw", "metadata": {}, "source": [ - "When exploring medium-dimensional data, a useful approach is to draw multiple instances of the same plot on different subsets of your dataset. This technique is sometimes called either \"lattice\" or \"trellis\" plotting, and it is related to the idea of `\"small multiples\" `_. It allows a viewer to quickly extract a large amount of information about complex data. Matplotlib offers good support for making figures with multiple axes; seaborn builds on top of this to directly link the structure of the plot to the structure of your dataset.\n", + "When exploring multi-dimensional data, a useful approach is to draw multiple instances of the same plot on different subsets of your dataset. This technique is sometimes called either \"lattice\" or \"trellis\" plotting, and it is related to the idea of `\"small multiples\" `_. It allows a viewer to quickly extract a large amount of information about a complex dataset. Matplotlib offers good support for making figures with multiple axes; seaborn builds on top of this to directly link the structure of the plot to the structure of your dataset.\n", "\n", - "To use these features, your data has to be in a Pandas DataFrame and it must take the form of what Hadley Whickam calls `\"tidy\" data `_. In brief, that means your dataframe should be structured such that each column is a variable and each row is an observation.\n", - "\n", - "For advanced use, you can use the objects discussed in this part of the tutorial directly, which will provide maximum flexibility. Some seaborn functions (such as :func:`lmplot`, :func:`catplot`, and :func:`pairplot`) also use them behind the scenes. Unlike other seaborn functions that are \"Axes-level\" and draw onto specific (possibly already-existing) matplotlib ``Axes`` without otherwise manipulating the figure, these higher-level functions create a figure when called and are generally more strict about how it gets set up. 
In some cases, arguments either to those functions or to the constructor of the class they rely on will provide a different interface attributes like the figure size, as in the case of :func:`lmplot` where you can set the height and aspect ratio for each facet rather than the overall size of the figure. Any function that uses one of these objects will always return it after plotting, though, and most of these objects have convenience methods for changing how the plot is drawn, often in a more abstract and easy way." + "The :doc:`figure-level ` functions are built on top of the objects discussed in this chapter of the tutorial. In most cases, you will want to work with those functions. They take care of some important bookkeeping that synchronizes the multiple plots in each grid. This chapter explains how the underlying objects work, which may be useful for advanced applications." ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide" + ] + }, "outputs": [], "source": [ "import seaborn as sns\n", @@ -45,10 +47,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide" + ] + }, "outputs": [], "source": [ - "sns.set(style=\"ticks\")" + "sns.set_theme(style=\"ticks\")" ] }, { @@ -77,18 +83,9 @@ "\n", "The :class:`FacetGrid` class is useful when you want to visualize the distribution of a variable or the relationship between multiple variables separately within subsets of your dataset. A :class:`FacetGrid` can be drawn with up to three dimensions: ``row``, ``col``, and ``hue``. The first two have obvious correspondence with the resulting array of axes; think of the hue variable as a third dimension along a depth axis, where different levels are plotted with different colors.\n", "\n", - "The class is used by initializing a :class:`FacetGrid` object with a dataframe and the names of the variables that will form the row, column, or hue dimensions of the grid. These variables should be categorical or discrete, and then the data at each level of the variable will be used for a facet along that axis. For example, say we wanted to examine differences between lunch and dinner in the ``tips`` dataset.\n", + "Each of :func:`relplot`, :func:`displot`, :func:`catplot`, and :func:`lmplot` use this object internally, and they return the object when they are finished so that it can be used for further tweaking.\n", "\n", - "Additionally, each of :func:`relplot`, :func:`catplot`, and :func:`lmplot` use this object internally, and they return the object when they are finished so that it can be used for further tweaking." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tips = sns.load_dataset(\"tips\")" + "The class is used by initializing a :class:`FacetGrid` object with a dataframe and the names of the variables that will form the row, column, or hue dimensions of the grid. These variables should be categorical or discrete, and then the data at each level of the variable will be used for a facet along that axis. 
For example, say we wanted to examine differences between lunch and dinner in the ``tips`` dataset:" ] }, { @@ -97,6 +94,7 @@ "metadata": {}, "outputs": [], "source": [ + "tips = sns.load_dataset(\"tips\")\n", "g = sns.FacetGrid(tips, col=\"time\")" ] }, @@ -106,7 +104,7 @@ "source": [ "Initializing the grid like this sets up the matplotlib figure and axes, but doesn't draw anything on them.\n", "\n", - "The main approach for visualizing data on this grid is with the :meth:`FacetGrid.map` method. Provide it with a plotting function and the name(s) of variable(s) in the dataframe to plot. Let's look at the distribution of tips in each of these subsets, using a histogram." + "The main approach for visualizing data on this grid is with the :meth:`FacetGrid.map` method. Provide it with a plotting function and the name(s) of variable(s) in the dataframe to plot. Let's look at the distribution of tips in each of these subsets, using a histogram:" ] }, { @@ -116,7 +114,7 @@ "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\")\n", - "g.map(plt.hist, \"tip\");" + "g.map(sns.histplot, \"tip\")" ] }, { @@ -133,8 +131,8 @@ "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"sex\", hue=\"smoker\")\n", - "g.map(plt.scatter, \"total_bill\", \"tip\", alpha=.7)\n", - "g.add_legend();" + "g.map(sns.scatterplot, \"total_bill\", \"tip\", alpha=.7)\n", + "g.add_legend()" ] }, { @@ -151,7 +149,7 @@ "outputs": [], "source": [ "g = sns.FacetGrid(tips, row=\"smoker\", col=\"time\", margin_titles=True)\n", - "g.map(sns.regplot, \"size\", \"total_bill\", color=\".3\", fit_reg=False, x_jitter=.1);" + "g.map(sns.regplot, \"size\", \"total_bill\", color=\".3\", fit_reg=False, x_jitter=.1)" ] }, { @@ -170,7 +168,7 @@ "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"day\", height=4, aspect=.5)\n", - "g.map(sns.barplot, \"sex\", \"total_bill\");" + "g.map(sns.barplot, \"sex\", \"total_bill\", order=[\"Male\", \"Female\"])" ] }, { @@ -189,7 +187,7 @@ "ordered_days = tips.day.value_counts().index\n", "g = sns.FacetGrid(tips, row=\"day\", row_order=ordered_days,\n", " height=1.7, aspect=4,)\n", - "g.map(sns.distplot, \"total_bill\", hist=False, rug=True);" + "g.map(sns.kdeplot, \"total_bill\")" ] }, { @@ -205,28 +203,10 @@ "metadata": {}, "outputs": [], "source": [ - "pal = dict(Lunch=\"seagreen\", Dinner=\"gray\")\n", + "pal = dict(Lunch=\"seagreen\", Dinner=\".7\")\n", "g = sns.FacetGrid(tips, hue=\"time\", palette=pal, height=5)\n", - "g.map(plt.scatter, \"total_bill\", \"tip\", s=50, alpha=.7, linewidth=.5, edgecolor=\"white\")\n", - "g.add_legend();" - ] - }, - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "You can also let other aspects of the plot vary across levels of the hue variable, which can be helpful for making plots that will be more comprehensible when printed in black-and-white. To do this, pass a dictionary to ``hue_kws`` where keys are the names of plotting function keyword arguments and values are lists of keyword values, one for each level of the hue variable." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "g = sns.FacetGrid(tips, hue=\"sex\", palette=\"Set1\", height=5, hue_kws={\"marker\": [\"^\", \"v\"]})\n", - "g.map(plt.scatter, \"total_bill\", \"tip\", s=100, linewidth=.5, edgecolor=\"white\")\n", - "g.add_legend();" + "g.map(sns.scatterplot, \"total_bill\", \"tip\", s=100, alpha=.5)\n", + "g.add_legend()" ] }, { @@ -244,7 +224,7 @@ "source": [ "attend = sns.load_dataset(\"attention\").query(\"subject <= 12\")\n", "g = sns.FacetGrid(attend, col=\"subject\", col_wrap=4, height=2, ylim=(0, 10))\n", - "g.map(sns.pointplot, \"solutions\", \"score\", color=\".3\", ci=None);" + "g.map(sns.pointplot, \"solutions\", \"score\", order=[1, 2, 3], color=\".3\", ci=None)" ] }, { @@ -262,10 +242,10 @@ "source": [ "with sns.axes_style(\"white\"):\n", " g = sns.FacetGrid(tips, row=\"sex\", col=\"smoker\", margin_titles=True, height=2.5)\n", - "g.map(plt.scatter, \"total_bill\", \"tip\", color=\"#334488\", edgecolor=\"white\", lw=.5);\n", - "g.set_axis_labels(\"Total bill (US Dollars)\", \"Tip\");\n", - "g.set(xticks=[10, 30, 50], yticks=[2, 6, 10]);\n", - "g.fig.subplots_adjust(wspace=.02, hspace=.02);" + "g.map(sns.scatterplot, \"total_bill\", \"tip\", color=\"#334488\")\n", + "g.set_axis_labels(\"Total bill (US Dollars)\", \"Tip\")\n", + "g.set(xticks=[10, 30, 50], yticks=[2, 6, 10])\n", + "g.fig.subplots_adjust(wspace=.02, hspace=.02)" ] }, { @@ -284,8 +264,8 @@ "g = sns.FacetGrid(tips, col=\"smoker\", margin_titles=True, height=4)\n", "g.map(plt.scatter, \"total_bill\", \"tip\", color=\"#338844\", edgecolor=\"white\", s=50, lw=1)\n", "for ax in g.axes.flat:\n", - " ax.plot((0, 50), (0, .2 * 50), c=\".2\", ls=\"--\")\n", - "g.set(xlim=(0, 60), ylim=(0, 14));" + " ax.axline((0, 0), slope=.2, c=\".2\", ls=\"--\", zorder=0)\n", + "g.set(xlim=(0, 60), ylim=(0, 14))" ] }, { @@ -299,7 +279,7 @@ "\n", "You're not limited to existing matplotlib and seaborn functions when using :class:`FacetGrid`. However, to work properly, any function you use must follow a few rules:\n", "\n", - "1. It must plot onto the \"currently active\" matplotlib ``Axes``. This will be true of functions in the ``matplotlib.pyplot`` namespace, and you can call ``plt.gca`` to get a reference to the current ``Axes`` if you want to work directly with its methods.\n", + "1. It must plot onto the \"currently active\" matplotlib ``Axes``. This will be true of functions in the ``matplotlib.pyplot`` namespace, and you can call :func:`matplotlib.pyplot.gca` to get a reference to the current ``Axes`` if you want to work directly with its methods.\n", "2. It must accept the data that it plots in positional arguments. Internally, :class:`FacetGrid` will pass a ``Series`` of data for each of the named positional arguments passed to :meth:`FacetGrid.map`.\n", "3. It must be able to accept ``color`` and ``label`` keyword arguments, and, ideally, it will do something useful with them. 
In most cases, it's easiest to catch a generic dictionary of ``**kwargs`` and pass it along to the underlying plotting function.\n", "\n", @@ -314,11 +294,11 @@ "source": [ "from scipy import stats\n", "def quantile_plot(x, **kwargs):\n", - " qntls, xr = stats.probplot(x, fit=False)\n", - " plt.scatter(xr, qntls, **kwargs)\n", + " quantiles, xr = stats.probplot(x, fit=False)\n", + " plt.scatter(xr, quantiles, **kwargs)\n", " \n", "g = sns.FacetGrid(tips, col=\"sex\", height=4)\n", - "g.map(quantile_plot, \"total_bill\");" + "g.map(quantile_plot, \"total_bill\")" ] }, { @@ -340,14 +320,14 @@ " plt.scatter(xr, yr, **kwargs)\n", " \n", "g = sns.FacetGrid(tips, col=\"smoker\", height=4)\n", - "g.map(qqplot, \"total_bill\", \"tip\");" + "g.map(qqplot, \"total_bill\", \"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Because ``plt.scatter`` accepts ``color`` and ``label`` keyword arguments and does the right thing with them, we can add a hue facet without any difficulty:" + "Because :func:`matplotlib.pyplot.scatter` accepts ``color`` and ``label`` keyword arguments and does the right thing with them, we can add a hue facet without any difficulty:" ] }, { @@ -358,33 +338,14 @@ "source": [ "g = sns.FacetGrid(tips, hue=\"time\", col=\"sex\", height=4)\n", "g.map(qqplot, \"total_bill\", \"tip\")\n", - "g.add_legend();" - ] - }, - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "This approach also lets us use additional aesthetics to distinguish the levels of the hue variable, along with keyword arguments that won't be dependent on the faceting variables:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "g = sns.FacetGrid(tips, hue=\"time\", col=\"sex\", height=4,\n", - " hue_kws={\"marker\": [\"s\", \"D\"]})\n", - "g.map(qqplot, \"total_bill\", \"tip\", s=40, edgecolor=\"w\")\n", - "g.add_legend();" + "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Sometimes, though, you'll want to map a function that doesn't work the way you expect with the ``color`` and ``label`` keyword arguments. In this case, you'll want to explicitly catch them and handle them in the logic of your custom function. For example, this approach will allow use to map ``plt.hexbin``, which otherwise does not play well with the :class:`FacetGrid` API:" + "Sometimes, though, you'll want to map a function that doesn't work the way you expect with the ``color`` and ``label`` keyword arguments. In this case, you'll want to explicitly catch them and handle them in the logic of your custom function. 
For example, this approach will allow use to map :func:`matplotlib.pyplot.hexbin`, which otherwise does not play well with the :class:`FacetGrid` API:" ] }, { @@ -426,7 +387,7 @@ "source": [ "iris = sns.load_dataset(\"iris\")\n", "g = sns.PairGrid(iris)\n", - "g.map(plt.scatter);" + "g.map(sns.scatterplot)" ] }, { @@ -443,8 +404,8 @@ "outputs": [], "source": [ "g = sns.PairGrid(iris)\n", - "g.map_diag(plt.hist)\n", - "g.map_offdiag(plt.scatter);" + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)" ] }, { @@ -461,9 +422,9 @@ "outputs": [], "source": [ "g = sns.PairGrid(iris, hue=\"species\")\n", - "g.map_diag(plt.hist)\n", - "g.map_offdiag(plt.scatter)\n", - "g.add_legend();" + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" ] }, { @@ -480,7 +441,7 @@ "outputs": [], "source": [ "g = sns.PairGrid(iris, vars=[\"sepal_length\", \"sepal_width\"], hue=\"species\")\n", - "g.map(plt.scatter);" + "g.map(sns.scatterplot)" ] }, { @@ -497,9 +458,9 @@ "outputs": [], "source": [ "g = sns.PairGrid(iris)\n", - "g.map_upper(plt.scatter)\n", + "g.map_upper(sns.scatterplot)\n", "g.map_lower(sns.kdeplot)\n", - "g.map_diag(sns.kdeplot, lw=3, legend=False);" + "g.map_diag(sns.kdeplot, lw=3, legend=False)" ] }, { @@ -517,7 +478,7 @@ "source": [ "g = sns.PairGrid(tips, y_vars=[\"tip\"], x_vars=[\"total_bill\", \"size\"], height=4)\n", "g.map(sns.regplot, color=\".3\")\n", - "g.set(ylim=(-1, 11), yticks=[0, 5, 10]);" + "g.set(ylim=(-1, 11), yticks=[0, 5, 10])" ] }, { @@ -535,7 +496,7 @@ "source": [ "g = sns.PairGrid(tips, hue=\"size\", palette=\"GnBu_d\")\n", "g.map(plt.scatter, s=50, edgecolor=\"white\")\n", - "g.add_legend();" + "g.add_legend()" ] }, { @@ -551,7 +512,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.pairplot(iris, hue=\"species\", height=2.5);" + "sns.pairplot(iris, hue=\"species\", height=2.5)" ] }, { @@ -583,9 +544,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -597,9 +558,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/categorical.ipynb b/doc/tutorial/categorical.ipynb index a02a67d8d9..d83445b3ea 100644 --- a/doc/tutorial/categorical.ipynb +++ b/doc/tutorial/categorical.ipynb @@ -56,7 +56,7 @@ "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", - "sns.set(style=\"ticks\", color_codes=True)" + "sns.set_theme(style=\"ticks\", color_codes=True)" ] }, { @@ -91,7 +91,7 @@ "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", - "sns.catplot(x=\"day\", y=\"total_bill\", data=tips);" + "sns.catplot(x=\"day\", y=\"total_bill\", data=tips)" ] }, { @@ -107,7 +107,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", jitter=False, data=tips);" + "sns.catplot(x=\"day\", y=\"total_bill\", jitter=False, data=tips)" ] }, { @@ -123,7 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", kind=\"swarm\", data=tips);" + "sns.catplot(x=\"day\", y=\"total_bill\", kind=\"swarm\", data=tips)" ] }, { @@ -139,7 +139,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"sex\", kind=\"swarm\", data=tips);" + 
"sns.catplot(x=\"day\", y=\"total_bill\", hue=\"sex\", kind=\"swarm\", data=tips)" ] }, { @@ -155,8 +155,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"size\", y=\"total_bill\", kind=\"swarm\",\n", - " data=tips.query(\"size != 3\"));" + "sns.catplot(x=\"size\", y=\"total_bill\", data=tips)" ] }, { @@ -172,7 +171,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"smoker\", y=\"tip\", order=[\"No\", \"Yes\"], data=tips);" + "sns.catplot(x=\"smoker\", y=\"tip\", order=[\"No\", \"Yes\"], data=tips)" ] }, { @@ -188,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"total_bill\", y=\"day\", hue=\"time\", kind=\"swarm\", data=tips);" + "sns.catplot(x=\"total_bill\", y=\"day\", hue=\"time\", kind=\"swarm\", data=tips)" ] }, { @@ -212,7 +211,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", kind=\"box\", data=tips);" + "sns.catplot(x=\"day\", y=\"total_bill\", kind=\"box\", data=tips)" ] }, { @@ -228,7 +227,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\", kind=\"box\", data=tips);" + "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\", kind=\"box\", data=tips)" ] }, { @@ -246,7 +245,7 @@ "source": [ "tips[\"weekend\"] = tips[\"day\"].isin([\"Sat\", \"Sun\"])\n", "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"weekend\",\n", - " kind=\"box\", dodge=False, data=tips);" + " kind=\"box\", dodge=False, data=tips)" ] }, { @@ -264,7 +263,7 @@ "source": [ "diamonds = sns.load_dataset(\"diamonds\")\n", "sns.catplot(x=\"color\", y=\"price\", kind=\"boxen\",\n", - " data=diamonds.sort_values(\"color\"));" + " data=diamonds.sort_values(\"color\"))" ] }, { @@ -284,7 +283,7 @@ "outputs": [], "source": [ "sns.catplot(x=\"total_bill\", y=\"day\", hue=\"sex\",\n", - " kind=\"violin\", data=tips);" + " kind=\"violin\", data=tips)" ] }, { @@ -302,7 +301,7 @@ "source": [ "sns.catplot(x=\"total_bill\", y=\"day\", hue=\"sex\",\n", " kind=\"violin\", bw=.15, cut=0,\n", - " data=tips);" + " data=tips)" ] }, { @@ -319,7 +318,7 @@ "outputs": [], "source": [ "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"sex\",\n", - " kind=\"violin\", split=True, data=tips);" + " kind=\"violin\", split=True, data=tips)" ] }, { @@ -337,7 +336,7 @@ "source": [ "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"sex\",\n", " kind=\"violin\", inner=\"stick\", split=True,\n", - " palette=\"pastel\", data=tips);" + " palette=\"pastel\", data=tips)" ] }, { @@ -354,7 +353,7 @@ "outputs": [], "source": [ "g = sns.catplot(x=\"day\", y=\"total_bill\", kind=\"violin\", inner=None, data=tips)\n", - "sns.swarmplot(x=\"day\", y=\"total_bill\", color=\"k\", size=3, data=tips, ax=g.ax);" + "sns.swarmplot(x=\"day\", y=\"total_bill\", color=\"k\", size=3, data=tips, ax=g.ax)" ] }, { @@ -379,7 +378,7 @@ "outputs": [], "source": [ "titanic = sns.load_dataset(\"titanic\")\n", - "sns.catplot(x=\"sex\", y=\"survived\", hue=\"class\", kind=\"bar\", data=titanic);" + "sns.catplot(x=\"sex\", y=\"survived\", hue=\"class\", kind=\"bar\", data=titanic)" ] }, { @@ -395,7 +394,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.catplot(x=\"deck\", kind=\"count\", palette=\"ch:.25\", data=titanic);" + "sns.catplot(x=\"deck\", kind=\"count\", palette=\"ch:.25\", data=titanic)" ] }, { @@ -413,7 +412,7 @@ "source": [ "sns.catplot(y=\"deck\", hue=\"class\", kind=\"count\",\n", " palette=\"pastel\", edgecolor=\".6\",\n", - " data=titanic);" + " data=titanic)" ] }, { @@ -432,7 +431,7 @@ "metadata": {}, "outputs": 
[], "source": [ - "sns.catplot(x=\"sex\", y=\"survived\", hue=\"class\", kind=\"point\", data=titanic);" + "sns.catplot(x=\"sex\", y=\"survived\", hue=\"class\", kind=\"point\", data=titanic)" ] }, { @@ -451,7 +450,7 @@ "sns.catplot(x=\"class\", y=\"survived\", hue=\"sex\",\n", " palette={\"male\": \"g\", \"female\": \"m\"},\n", " markers=[\"^\", \"o\"], linestyles=[\"-\", \"--\"],\n", - " kind=\"point\", data=titanic);" + " kind=\"point\", data=titanic)" ] }, { @@ -471,7 +470,7 @@ "outputs": [], "source": [ "iris = sns.load_dataset(\"iris\")\n", - "sns.catplot(data=iris, orient=\"h\", kind=\"box\");" + "sns.catplot(data=iris, orient=\"h\", kind=\"box\")" ] }, { @@ -487,7 +486,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.violinplot(x=iris.species, y=iris.sepal_length);" + "sns.violinplot(x=iris.species, y=iris.sepal_length)" ] }, { @@ -504,7 +503,7 @@ "outputs": [], "source": [ "f, ax = plt.subplots(figsize=(7, 3))\n", - "sns.countplot(y=\"deck\", data=titanic, color=\"c\");" + "sns.countplot(y=\"deck\", data=titanic, color=\"c\")" ] }, { @@ -526,8 +525,8 @@ "outputs": [], "source": [ "sns.catplot(x=\"day\", y=\"total_bill\", hue=\"smoker\",\n", - " col=\"time\", aspect=.6,\n", - " kind=\"swarm\", data=tips);" + " col=\"time\", aspect=.7,\n", + " kind=\"swarm\", data=tips)" ] }, { @@ -546,7 +545,7 @@ "g = sns.catplot(x=\"fare\", y=\"survived\", row=\"class\",\n", " kind=\"box\", orient=\"h\", height=1.5, aspect=4,\n", " data=titanic.query(\"fare > 0\"))\n", - "g.set(xscale=\"log\");" + "g.set(xscale=\"log\")" ] }, { @@ -562,9 +561,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -576,9 +575,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/color_palettes.ipynb b/doc/tutorial/color_palettes.ipynb index e1570c2069..616318db46 100644 --- a/doc/tutorial/color_palettes.ipynb +++ b/doc/tutorial/color_palettes.ipynb @@ -20,30 +20,246 @@ "\n", ".. raw:: html\n", "\n", - "
\n" + "
\n", + "\n", + "Seaborn makes it easy to use colors that are well-suited to the characteristics of your data and your visualization goals. This chapter discusses both the general principles that should guide your choices and the tools in seaborn that help you quickly find the best solution for a given application." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme(style=\"white\", rc={\"xtick.major.pad\": 1, \"ytick.major.pad\": 1})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "np.random.seed(sum(map(ord, \"palettes\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "# Add colormap display methods to matplotlib colormaps.\n", + "# These are forthcoming in matplotlib 3.4, but, the matplotlib display\n", + "# method includes the colormap name, which is redundant.\n", + "def _repr_png_(self):\n", + " \"\"\"Generate a PNG representation of the Colormap.\"\"\"\n", + " import io\n", + " from PIL import Image\n", + " import numpy as np\n", + " IMAGE_SIZE = (400, 50)\n", + " X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))\n", + " pixels = self(X, bytes=True)\n", + " png_bytes = io.BytesIO()\n", + " Image.fromarray(pixels).save(png_bytes, format='png')\n", + " return png_bytes.getvalue()\n", + " \n", + "def _repr_html_(self):\n", + " \"\"\"Generate an HTML representation of the Colormap.\"\"\"\n", + " import base64\n", + " png_bytes = self._repr_png_()\n", + " png_base64 = base64.b64encode(png_bytes).decode('ascii')\n", + " return ('')\n", + " \n", + "import matplotlib as mpl\n", + "mpl.colors.Colormap._repr_png_ = _repr_png_\n", + "mpl.colors.Colormap._repr_html_ = _repr_html_" ] }, { "cell_type": "raw", + "metadata": {}, + "source": [ + "General principles for using color in plots\n", + "-------------------------------------------\n", + "\n", + "Components of color\n", + "~~~~~~~~~~~~~~~~~~~\n", + "\n", + "Because of the way our eyes work, a particular color can be defined using three components. We usually program colors in a computer by specifying their RGB values, which set the intensity of the red, green, and blue channels in a display. But for analyzing the perceptual attributes of a color, it's better to think in terms of *hue*, *saturation*, and *luminance* channels.\n", + "\n", + "Hue is the component that distinguishes \"different colors\" in a non-technical sense. It's property of color that leads to first-order names like \"red\" and \"blue\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": { - "raw_mimetype": "text/restructuredtext" + "tags": [ + "hide-input" + ] }, + "outputs": [], "source": [ - "Color is more important than other aspects of figure style because color can reveal patterns in the data if used effectively or hide those patterns if used poorly. There are a number of great resources to learn about good techniques for using color in visualizations, I am partial to this `series of blog posts `_ from Rob Simmon and this `more technical paper `_. 
The matplotlib docs also now have a `nice tutorial `_ that illustrates some of the perceptual properties of the built in colormaps.\n", + "sns.husl_palette(8, s=.7)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Saturation (or chroma) is the *colorfulness*. Two colors with different hues will look more distinct when they have more saturation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "c = sns.color_palette(\"muted\")[0]\n", + "sns.blend_palette([sns.desaturate(c, 0), c], 8)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "And lightness corresponds to how much light is emitted (or reflected, for printed colors), ranging from black to white:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "sns.blend_palette([\".1\", c, \".95\"], 8)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Vary hue to distinguish categories\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", - "Seaborn makes it easy to select and use color palettes that are suited to the kind of data you are working with and the goals you have in visualizing it." + "When you want to represent multiple categories in a plot, you typically should vary the color of the elements. Consider this simple example: in which of these two plots is it easier to count the number of triangular points?" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "n = 45\n", + "rng = np.random.default_rng(200)\n", + "x = rng.uniform(0, 1, n * 2)\n", + "y = rng.uniform(0, 1, n * 2)\n", + "a = np.concatenate([np.zeros(n * 2 - 10), np.ones(10)])\n", + "\n", + "f, axs = plt.subplots(1, 2, figsize=(7, 3.5), sharey=True, sharex=True)\n", + "\n", + "sns.scatterplot(\n", + " x=x[::2], y=y[::2], style=a[::2], size=a[::2], legend=False,\n", + " markers=[\"o\", (3, 1, 1)], sizes=[70, 140], ax=axs[0],\n", + ")\n", + "\n", + "sns.scatterplot(\n", + " x=x[1::2], y=y[1::2], style=a[1::2], size=a[1::2], hue=a[1::2], legend=False,\n", + " markers=[\"o\", (3, 1, 1)], sizes=[70, 140], ax=axs[1],\n", + ")\n", + "\n", + "f.tight_layout(w_pad=2)" + ] + }, + { + "cell_type": "raw", "metadata": {}, + "source": [ + "In the plot on the right, the orange triangles \"pop out\", making it easy to distinguish them from the circles. This pop-out effect happens because our visual system prioritizes color differences.\n", + "\n", + "The blue and orange colors differ mostly in terms of their hue. Hue is useful for representing categories: most people can distinguish a moderate number of hues relatively easily, and points that have different hues but similar brightness or intensity seem equally important. It also makes plots easier to talk about. 
Consider this example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "b = np.tile(np.arange(10), n // 5)\n", + "\n", + "f, axs = plt.subplots(1, 2, figsize=(7, 3.5), sharey=True, sharex=True)\n", + "\n", + "sns.scatterplot(\n", + " x=x[::2], y=y[::2], hue=b[::2],\n", + " legend=False, palette=\"muted\", s=70, ax=axs[0],\n", + ")\n", + "\n", + "sns.scatterplot(\n", + " x=x[1::2], y=y[1::2], hue=b[1::2],\n", + " legend=False, palette=\"blend:.75,C0\", s=70, ax=axs[1],\n", + ")\n", + "\n", + "f.tight_layout(w_pad=2)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Most people would be able to quickly ascertain that there are five distinct categories in the plot on the left and, if asked to characterize the \"blue\" points, would be able to do so.\n", + "\n", + "With the plot on the right, where the points are all blue but vary in their luminance and saturation, it's harder to say how many unique categories are present. And how would we talk about a particular category? \"The fairly-but-not-too-blue points?\" What's more, the gray dots seem to fade into the background, de-emphasizing them relative to the more intense blue dots. If the categories are equally important, this is a poor representation.\n", + "\n", + "So as a general rule, use hue variation to represent categories. With that said, here are a few notes of caution. If you have more than a handful of colors in your plot, it can become difficult to keep in mind what each one means, unless there are pre-existing associations between the categories and the colors used to represent them. This makes your plot harder to interpret: rather than focusing on the data, a viewer will have to continually refer to the legend to make sense of what is shown. So you should strive not to make plots that are too complex. And be mindful that not everyone sees colors the same way. Varying both shape (or some other attribute) and color can help people with anomalous color vision understand your plots, and it can keep them (somewhat) interpretable if they are printed to black-and-white." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Vary luminance to represent numbers\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "On the other hand, hue variations are not well suited to representing numeric data. Consider this example, where we need colors to represent the counts in a bivariate histogram. On the left, we use a circular colormap, where gradual changes in the number of observations within each bin correspond to gradual changes in hue. 
On the right, we use a palette that uses brighter colors to represent bins with larger counts:" ] }, { @@ -51,13 +267,46 @@ "execution_count": null, "metadata": { "tags": [ - "hide" + "hide-input" ] }, "outputs": [], "source": [ - "%matplotlib inline\n", - "np.random.seed(sum(map(ord, \"palettes\")))" + "penguins = sns.load_dataset(\"penguins\")\n", + "\n", + "f, axs = plt.subplots(1, 2, figsize=(7, 4.25), sharey=True, sharex=True)\n", + "\n", + "sns.histplot(\n", + " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", + " binwidth=(3, .75), cmap=\"hls\", ax=axs[0],\n", + " cbar=True, cbar_kws=dict(orientation=\"horizontal\", pad=.1),\n", + ")\n", + "axs[0].set(xlabel=\"\", ylabel=\"\")\n", + "\n", + "\n", + "sns.histplot(\n", + " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", + " binwidth=(3, .75), cmap=\"flare_r\", ax=axs[1],\n", + " cbar=True, cbar_kws=dict(orientation=\"horizontal\", pad=.1),\n", + ")\n", + "axs[1].set(xlabel=\"\", ylabel=\"\")\n", + "\n", + "f.tight_layout(w_pad=3)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "With the hue-based palette, it's quite difficult to ascertain the shape of the bivariate distribution. In contrast, the luminance palette makes it much more clear that there are two prominent peaks.\n", + "\n", + "Varying luminance helps you see structure in data, and changes in luminance are more intuitively processed as changes in importance. But the plot on the right does not use a grayscale colormap. Its colorfulness makes it more interesting, and the subtle hue variation increases the perceptual distance between two values. As a result, small differences are slightly easier to resolve.\n", + "\n", + "These examples show that color palette choices are about more than aesthetics: the colors you choose can reveal patterns in your data if used effectively or hide them if used poorly. There is not one optimal palette, but there are palettes that are better or worse for particular datasets and visualization approaches.\n", + "\n", + "And aesthetics do matter: the more that people want to look at your figures, the greater the chance that they will learn something from them. This is true even when you are making plots for yourself. During exploratory data analysis, you may generate many similar figures. Varying the color palettes will add a sense of novelty, which keeps you engaged and prepared to notice interesting features of your data.\n", + "\n", + "So how can you choose color palettes that both represent your data well and look attractive?" ] }, { @@ -66,18 +315,20 @@ "raw_mimetype": "text/restructuredtext" }, "source": [ - "Building color palettes\n", - "-----------------------\n", + "Tools for choosing color palettes\n", + "---------------------------------\n", "\n", - "The most important function for working with discrete color palettes is :func:`color_palette`. This function provides an interface to many (though not all) of the possible ways you can generate colors in seaborn, and it's used internally by any function that has a ``palette`` argument (and in some cases for a ``color`` argument when multiple colors are needed).\n", + "The most important function for working with color palettes is, aptly, :func:`color_palette`. This function provides an interface to most of the possible ways that one can generate color palettes in seaborn. 
And it's used internally by any function that has a ``palette`` argument.\n", "\n", - ":func:`color_palette` will accept the name of any seaborn palette or matplotlib colormap (except ``jet``, which you should never use). It can also take a list of colors specified in any valid matplotlib format (RGB tuples, hex color codes, or HTML color names). The return value is always a list of RGB tuples.\n", + "The primary argument to :func:`color_palette` is usually a string: either the name of a specific palette or the name of a family and additional arguments to select a specific member. In the latter case, :func:`color_palette` will delegate to a more specific function, such as :func:`cubehelix_palette`. It's also possible to pass a list of colors specified any way that matplotlib accepts (an RGB tuple, a hex code, or a name in the X11 table). The return value is an object that wraps a list of RGB tuples with a few useful methods, such as conversion to hex codes and a rich HTML representation.\n", "\n", - "Finally, calling :func:`color_palette` with no arguments will return the current default color cycle.\n", + "Calling :func:`color_palette` with no arguments will return the current default color palette that matplotlib (and most seaborn functions) will use if colors are not otherwise specified. This default palette can be set with the corresponding :func:`set_palette` function, which calls :func:`color_palette` internally and accepts the same arguments.\n", "\n", - "A corresponding function, :func:`set_palette`, takes the same arguments and will set the default color cycle for all plots. You can also use :func:`color_palette` in a ``with`` statement to temporarily change the default palette (see :ref:`below `).\n", + "To motivate the different options that :func:`color_palette` provides, it will be useful to introduce a classification scheme for color palettes. Broadly, palettes fall into one of three categories:\n", "\n", - "It is generally not possible to know what kind of color palette or colormap is best for a set of data without knowing about the characteristics of the data. Following that, we'll break up the different ways to use :func:`color_palette` and other seaborn palette functions by the three general kinds of color palettes: *qualitative*, *sequential*, and *diverging*." + "- qualitative palettes, good for representing categorical data\n", + "- sequential palettes, good for representing numeric data\n", + "- diverging palettes, good for representing numeric data with a categorical boundary" ] }, { @@ -89,9 +340,7 @@ "Qualitative color palettes\n", "--------------------------\n", "\n", - "Qualitative (or categorical) palettes are best when you want to distinguish discrete chunks of data that do not have an inherent ordering.\n", - "\n", - "When importing seaborn, the default color cycle is changed to a set of ten colors that evoke the standard matplotlib color cycle while aiming to be a bit more pleasing to look at." + "Qualitative palettes are well-suited to representing categorical data because most of their variation is in the hue component. 
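A minimal sketch of the :func:`color_palette` interface just described (illustrative only, assuming seaborn 0.11; the palette strings are examples, with "ch:.25" being the cubehelix shorthand that appears elsewhere in this diff):

```python
import seaborn as sns

# A specific named palette...
deep = sns.color_palette("deep")

# ...or a family shorthand with arguments; "ch:.25" delegates to
# cubehelix_palette(), per the description above.
cubehelix = sns.color_palette("ch:.25")

# A list of colors in any format matplotlib accepts also works.
custom = sns.color_palette(["#4c72b0", "seagreen", (0.8, 0.3, 0.3)])

# The return value wraps RGB tuples with a few useful methods,
# e.g. conversion to hex codes.
print(custom.as_hex())

# set_palette() accepts the same arguments and sets the default.
sns.set_palette("colorblind")
```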
{
@@ -89,9 +340,7 @@
"metadata": {
"raw_mimetype": "text/restructuredtext"
},
"source": [
"Qualitative color palettes\n",
"--------------------------\n",
"\n",
- "Qualitative (or categorical) palettes are best when you want to distinguish discrete chunks of data that do not have an inherent ordering.\n",
- "\n",
- "When importing seaborn, the default color cycle is changed to a set of ten colors that evoke the standard matplotlib color cycle while aiming to be a bit more pleasing to look at."
+ "Qualitative palettes are well-suited to representing categorical data because most of their variation is in the hue component. The default color palette in seaborn is a qualitative palette with ten distinct hues:"
]
},
{
@@ -100,15 +349,30 @@
"metadata": {},
"outputs": [],
"source": [
- "current_palette = sns.color_palette()\n",
- "sns.palplot(current_palette)"
+ "sns.color_palette()"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "There are six variations of the default theme, called ``deep``, ``muted``, ``pastel``, ``bright``, ``dark``, and ``colorblind``."
+ "These colors have the same ordering as the default matplotlib color palette, ``\"tab10\"``, but they are a bit less intense. Compare:"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.color_palette(\"tab10\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Seaborn in fact has six variations of matplotlib's palette, called ``deep``, ``muted``, ``pastel``, ``bright``, ``dark``, and ``colorblind``. These span a range of average luminance and saturation values:"
]
},
{
@@ -121,7 +385,6 @@
},
"outputs": [],
"source": [
- "# TODO hide input here when merged with doc updating branch\n",
"f = plt.figure(figsize=(6, 6))\n",
"\n",
"ax_locs = dict(\n",
@@ -151,36 +414,28 @@
"\n",
"ax = f.add_axes([0, 0, 1, 1])\n",
"ax.set_axis_off()\n",
- "ax.arrow(.15, .05, .4, 0, width=.002, head_width=.015, color=\"k\")\n",
- "ax.arrow(.05, .15, 0, .4, width=.002, head_width=.015, color=\"k\");"
+ "ax.arrow(.15, .05, .4, 0, width=.002, head_width=.015, color=\".15\")\n",
+ "ax.arrow(.05, .15, 0, .4, width=.002, head_width=.015, color=\".15\")\n",
+ "ax.set(xlim=(0, 1), ylim=(0, 1))"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "Using circular color systems\n",
- "~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
- "\n",
- "When you have an arbitrary number of categories to distinguish without emphasizing any one, the easiest approach is to draw evenly-spaced colors in a circular color space (one where the hue changes while keeping the brightness and saturation constant). This is what most seaborn functions default to when they need to use more colors than are currently set in the default color cycle.\n",
- "\n",
- "The most common way to do this uses the ``hls`` color space, which is a simple transformation of RGB values."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sns.palplot(sns.color_palette(\"hls\", 8))"
+ "Many people find the moderated hues of the default ``\"deep\"`` palette to be aesthetically pleasing, but they are also less distinct. As a result, they may be more difficult to discriminate in some contexts, which is something to keep in mind when making publication graphics. `This comparison `_ can be helpful for estimating how the seaborn color palettes perform when simulating different forms of colorblindness."
]
},
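+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Any of these variants can be requested by name. For instance, the ``\"colorblind\"`` palette can be a sensible default when figures need to remain legible under color vision deficiencies:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.color_palette(\"colorblind\")"
+ ]
+ },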
{
"cell_type": "raw",
"metadata": {},
"source": [
- "There is also the :func:`hls_palette` function that lets you control the lightness and saturation of the colors."
+ "Using circular color systems\n",
+ "~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
+ "\n",
+ "When you have an arbitrary number of categories, the easiest approach to finding unique hues is to draw evenly-spaced colors in a circular color space (one where the hue changes while keeping the brightness and saturation constant). This is what most seaborn functions default to when they need to use more colors than are currently set in the default color cycle.\n",
+ "\n",
+ "The most common way to do this uses the ``hls`` color space, which is a simple transformation of RGB values. We saw this color palette before, as a counterexample for how to plot a histogram:"
]
},
{
@@ -189,7 +444,7 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.hls_palette(8, l=.3, s=.8))"
+ "sns.color_palette(\"hls\", 8)"
]
},
{
@@ -198,9 +453,7 @@
"raw_mimetype": "text/restructuredtext"
},
"source": [
- "However, because of the way the human visual system works, colors that are even \"intensity\" in terms of their RGB levels won't necessarily look equally intense. `We perceive `_ yellows and greens as relatively brighter and blues as relatively darker, which can be a problem when aiming for uniformity with the ``hls`` system.\n",
- "\n",
- "To remedy this, seaborn provides an interface to the `husl `_ system (since renamed to HSLuv), which also makes it easy to select evenly spaced hues while keeping the apparent brightness and saturation much more uniform."
+ "Because of the way the human visual system works, colors that have the same luminance and saturation in terms of their RGB values won't necessarily look equally intense. To remedy this, seaborn provides an interface to the `husl `_ system (since renamed to HSLuv), which achieves less intensity variation as you rotate around the color wheel:"
]
},
{
@@ -209,21 +462,19 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.color_palette(\"husl\", 8))"
+ "sns.color_palette(\"husl\", 8)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "There is similarly a function called :func:`husl_palette` that provides a more flexible interface to this system.\n",
+ "When seaborn needs a categorical palette with more colors than are available in the current default, it will use this approach.\n",
"\n",
"Using categorical Color Brewer palettes\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
- "Another source of visually pleasing categorical palettes comes from the `Color Brewer `_ tool (which also has sequential and diverging palettes, as we'll see below). These also exist as matplotlib colormaps, but they are not handled properly. In seaborn, when you ask for a qualitative Color Brewer palette, you'll always get the discrete colors, but this means that at a certain point they will begin to cycle.\n",
- "\n",
- "A nice feature of the Color Brewer website is that it provides some guidance on which palettes are color blind safe. There is a variety of `kinds `_ of color blindness, but the most common variant leads to difficulty distinguishing reds and greens. It's generally a good idea to avoid using red and green for plot elements that need to be discriminated based on color. [This comparison](https://gist.github.com/mwaskom/b35f6ebc2d4b340b4f64a4e28e778486) can be helpful to understand how the the seaborn color palettes perform for different type of colorblindess."
+ "Another source of visually pleasing categorical palettes comes from the `Color Brewer `_ tool (which also has sequential and diverging palettes, as we'll see below)."
] }, { @@ -232,7 +483,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"Paired\"))" + "sns.color_palette(\"Set2\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Be aware that the qualitative Color Brewer palettes have different lengths, and the default behavior of :func:`color_palette` is to give you the full list:" ] }, { @@ -241,16 +499,24 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"Set2\"))" + "sns.color_palette(\"Paired\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "To help you choose palettes from the Color Brewer library, there is the :func:`choose_colorbrewer_palette` function. This function, which must be used in a Jupyter notebook, will launch an interactive widget that lets you browse the various options and tweak their parameters.\n", + ".. _sequential_palettes:\n", + "\n", + "Sequential color palettes\n", + "-------------------------\n", "\n", - "Of course, you might just want to use a set of colors you particularly like together. Because :func:`color_palette` accepts a list of colors, this is easy to do." + "The second major class of color palettes is called \"sequential\". This kind of mapping is appropriate when data range from relatively low or uninteresting values to relatively high or interesting values (or vice versa). As we saw above, the primary dimension of variation in a sequential palette is luminance. Some seaborn functions will default to a sequential palette when you are mapping numeric data. (For historical reasons, both categorical and numeric mappings are specified with the ``hue`` parameter in functions like :func:`relplot` or :func:`displot`, even though numeric mappings use color palettes with relatively little hue variation).\n", + "\n", + "Perceptually uniform palettes\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "Because they are intended to represent numeric values, the best sequential palettes will be *perceptually uniform*, meaning that the relative discriminability of two colors is proportional to the difference between the corresponding data values. Seaborn includes four perceptually uniform sequential colormaps: ``\"rocket\"``, ``\"mako\"``, ``\"flare\"``, and ``\"crest\"``. The first two have a very wide luminance range and are well suited for applications such as heatmaps, where colors fill the space they are plotted into:" ] }, { @@ -259,38 +525,32 @@ "metadata": {}, "outputs": [], "source": [ - "flatui = [\"#9b59b6\", \"#3498db\", \"#95a5a6\", \"#e74c3c\", \"#34495e\", \"#2ecc71\"]\n", - "sns.palplot(sns.color_palette(flatui))" + "sns.color_palette(\"rocket\", as_cmap=True)" ] }, { - "cell_type": "raw", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - ".. _using_xkcd_palettes:\n", - " \n", - "Using named colors from the xkcd color survey\n", - "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", - "\n", - "A while back, `xkcd `_ ran a `crowdsourced effort `_ to name random RGB colors. 
This produced a set of `954 named colors `_, which you can now reference in seaborn using the ``xkcd_rgb`` dictionary:" + "sns.color_palette(\"mako\", as_cmap=True)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "raw", "metadata": {}, - "outputs": [], "source": [ - "plt.plot([0, 1], [0, 1], sns.xkcd_rgb[\"pale red\"], lw=3)\n", - "plt.plot([0, 1], [0, 2], sns.xkcd_rgb[\"medium green\"], lw=3)\n", - "plt.plot([0, 1], [0, 3], sns.xkcd_rgb[\"denim blue\"], lw=3);" + "Because the extreme values of these colormaps approach white, they are not well-suited for coloring elements such as lines or points: it will be difficult to discriminate important values against a white or gray background. The \"flare\" and \"crest\" colormaps are a better choice for such plots. They have a more restricted range of luminance variations, which they compensate for with a slightly more pronounced variation in hue. The default direction of the luminance ramp is also reversed, so that smaller values have lighter colors:" ] }, { - "cell_type": "raw", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "In addition to pulling out single colors from the ``xkcd_rgb`` dictionary, you can also pass a list of names to the :func:`xkcd_palette` function." + "sns.color_palette(\"flare\", as_cmap=True)" ] }, { @@ -299,26 +559,23 @@ "metadata": {}, "outputs": [], "source": [ - "colors = [\"windows blue\", \"amber\", \"greyish\", \"faded green\", \"dusty purple\"]\n", - "sns.palplot(sns.xkcd_palette(colors))" + "sns.color_palette(\"crest\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - ".. _sequential_palettes:\n", - "\n", - "Sequential color palettes\n", - "-------------------------\n", - "\n", - "The second major class of color palettes is called \"sequential\". This kind of color mapping is appropriate when data range from relatively low or uninteresting values to relatively high or interesting values. Although there are cases where you will want discrete colors in a sequential palette, it's more common to use them as a colormap in functions like :func:`kdeplot` and :func:`heatmap` (along with similar matplotlib functions).\n", - "\n", - "It's common to see colormaps like ``jet`` (or other rainbow palettes) used in this case, because the range of hues gives the impression of providing additional information about the data. However, colormaps with large hue shifts tend to introduce discontinuities that don't exist in the data, and our visual system isn't able to naturally map the rainbow to quantitative distinctions like \"high\" or \"low\". The result is that these visualizations end up being more like a puzzle, and they obscure patterns in the data rather than revealing them. The jet colormap is misleading because the brightest colors, yellow and cyan, are used for intermediate data values. This has the effect of emphasizing uninteresting (and arbitrary) values while deemphasizing the extremes.\n", - "\n", - "For sequential data, it's better to use palettes that have at most a relatively subtle shift in hue accompanied by a large shift in brightness and saturation. This approach will naturally draw the eye to the relatively important parts of the data.\n", - "\n", - "The Color Brewer library has a great set of these palettes. They're named after the dominant color (or colors) in the palette." 
+ "It is also possible to use the perceptually uniform colormaps provided by matplotlib, such as ``\"magma\"`` and ``\"viridis\"``:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"magma\", as_cmap=True)" ] }, { @@ -327,14 +584,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"Blues\"))" + "sns.color_palette(\"viridis\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Like in matplotlib, if you want the lightness ramp to be reversed, you can add a ``_r`` suffix to the palette name." + "As with the convention in matplotlib, every continuous colormap has a reversed version, which has the suffix ``\"_r\"``:" ] }, { @@ -343,14 +600,17 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"BuGn_r\"))" + "sns.color_palette(\"rocket_r\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Seaborn also adds a trick that allows you to create \"dark\" palettes, which do not have as wide a dynamic range. This can be useful if you want to map lines or points sequentially, as brighter-colored lines might otherwise be hard to distinguish." + "Discrete vs. continuous mapping\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "One thing to be aware of is that seaborn can generate discrete values from sequential colormaps and, when doing so, it will not use the most extreme values. Compare the discrete version of ``\"rocket\"`` against the continuous version shown above:" ] }, { @@ -359,14 +619,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"GnBu_d\"))" + "sns.color_palette(\"rocket\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Remember that you may want to use the :func:`choose_colorbrewer_palette` function to play with the various options, and you can set the ``as_cmap`` argument to ``True`` if you want the return value to be a colormap object that you can pass to seaborn or matplotlib functions." + "Interally, seaborn uses the discrete version for categorical data and the continuous version when in numeric mapping mode. Discrete sequential colormaps can be well-suited for visualizing categorical data with an intrinsic ordering, especially if there is some hue variation." ] }, { @@ -380,7 +640,7 @@ "Sequential \"cubehelix\" palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", - "The `cubehelix `_ color palette system makes sequential palettes with a linear increase or decrease in brightness and some variation in hue. This means that the information in your colormap will be preserved when converted to black and white (for printing) or when viewed by a colorblind individual.\n", + "The perceptually uniform colormaps are difficult to programmatically generate, because they are not based on the RGB color space. The `cubehelix `_ system offers an RGB-based compromise: it generates sequential palettes with a linear increase or decrease in brightness and some continuous variation in hue. While not perfectly perceptually uniform, the resulting colormaps have many good properties. 
{
@@ -380,7 +640,7 @@
"Sequential \"cubehelix\" palettes\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
- "The `cubehelix `_ color palette system makes sequential palettes with a linear increase or decrease in brightness and some variation in hue. This means that the information in your colormap will be preserved when converted to black and white (for printing) or when viewed by a colorblind individual.\n",
+ "The perceptually uniform colormaps are difficult to programmatically generate, because they are not based on the RGB color space. The `cubehelix `_ system offers an RGB-based compromise: it generates sequential palettes with a linear increase or decrease in brightness and some continuous variation in hue. While not perfectly perceptually uniform, the resulting colormaps have many good properties. Importantly, many aspects of the design process are parameterizable.\n",
"\n",
"Matplotlib has the default cubehelix version built into it:"
]
@@ -391,16 +651,14 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.color_palette(\"cubehelix\", 8))"
+ "sns.color_palette(\"cubehelix\", as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "Seaborn adds an interface to the cubehelix *system* so that you can make a variety of palettes that all have a well-behaved linear brightness ramp.\n",
- "\n",
- "The default palette returned by the seaborn :func:`cubehelix_palette` function is a bit different from the matplotlib default in that it does not rotate as far around the hue wheel or cover as wide a range of intensities. It also reverses the order so that more important values are darker:"
+ "The default palette returned by the seaborn :func:`cubehelix_palette` function is a bit different from the matplotlib default in that it does not rotate as far around the hue wheel or cover as wide a range of intensities. It also reverses the luminance ramp:"
]
},
{
@@ -409,14 +667,14 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.cubehelix_palette(8))"
+ "sns.cubehelix_palette(as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "Other arguments to :func:`cubehelix_palette` control how the palette looks. The two main things you'll change are the ``start`` (a value between 0 and 3) and ``rot``, or number of rotations (an arbitrary value, but probably within -1 and 1),"
+ "Other arguments to :func:`cubehelix_palette` control how the palette looks. The two main things you'll change are ``start`` (a value between 0 and 3) and ``rot``, or number of rotations (an arbitrary value, but usually between -1 and 1):"
]
},
{
@@ -425,14 +683,14 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.cubehelix_palette(8, start=.5, rot=-.75))"
+ "sns.cubehelix_palette(start=.5, rot=-.5, as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "You can also control how dark and light the endpoints are and even reverse the ramp:"
+ "The more you rotate, the more hue variation you will see:"
]
},
{
@@ -441,14 +699,14 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.cubehelix_palette(8, start=2, rot=0, dark=0, light=.95, reverse=True))"
+ "sns.cubehelix_palette(start=.5, rot=-.75, as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "By default you just get a list of colors, like any other seaborn palette, but you can also return the palette as a colormap object that can be passed to seaborn or matplotlib functions using ``as_cmap=True``."
+ "You can control both how dark and light the endpoints are and their order:"
]
},
{
@@ -457,16 +715,39 @@
"metadata": {},
"outputs": [],
"source": [
- "x, y = np.random.multivariate_normal([0, 0], [[1, -.5], [-.5, 1]], size=300).T\n",
- "cmap = sns.cubehelix_palette(light=1, as_cmap=True)\n",
- "sns.kdeplot(x, y, cmap=cmap, shade=True);"
+ "sns.cubehelix_palette(start=2, rot=0, dark=0, light=.95, reverse=True, as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "To help select good palettes or colormaps using this system, you can use the :func:`choose_cubehelix_palette` function in a notebook to launch an interactive app that will let you play with the different parameters. Pass ``as_cmap=True`` if you want the function to return a colormap (rather than a list) for use in function like ``hexbin``."
+ "The :func:`color_palette` accepts a string code, starting with ``\"ch:\"``, for generating an arbitrary cubehelix palette. You can passs the names of parameters in the string:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"ch:start=.2,rot=-.3\", as_cmap=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "And for compactness, each parameter can be specified with its first letter:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.color_palette(\"ch:s=-.2,r=.6\", as_cmap=True)" ] }, { @@ -476,7 +757,7 @@ "Custom sequential palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", - "For a simpler interface to custom sequential palettes, you can use :func:`light_palette` or :func:`dark_palette`, which are both seeded with a single color and produce a palette that ramps either from light or dark desaturated values to that color. These functions are also accompanied by the :func:`choose_light_palette` and :func:`choose_dark_palette` functions that launch interactive widgets to create these palettes." + "For a simpler interface to custom sequential palettes, you can use :func:`light_palette` or :func:`dark_palette`, which are both seeded with a single color and produce a palette that ramps either from light or dark desaturated values to that color:" ] }, { @@ -485,7 +766,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.light_palette(\"green\"))" + "sns.light_palette(\"seagreen\", as_cmap=True)" ] }, { @@ -494,14 +775,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.dark_palette(\"purple\"))" + "sns.dark_palette(\"#69d\", reverse=True, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "These palettes can also be reversed." + "As with cubehelix palettes, you can also specify light or dark palettes through :func:`color_palette` or anywhere ``palette`` is accepted:" ] }, { @@ -510,14 +791,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.light_palette(\"navy\", reverse=True))" + "sns.color_palette(\"light:b\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "They can also be used to create colormap objects rather than lists of colors." + "Reverse the colormap by adding ``\"_r\"``:" ] }, { @@ -526,15 +807,17 @@ "metadata": {}, "outputs": [], "source": [ - "pal = sns.dark_palette(\"palegreen\", as_cmap=True)\n", - "sns.kdeplot(x, y, cmap=pal);" + "sns.color_palette(\"dark:salmon_r\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "By default, the input can be any valid matplotlib color. Alternate interpretations are controlled by the ``input`` argument. Currently you can provide tuples in ``hls`` or ``husl`` space along with the default ``rgb``, and you can also seed the palette with any valid ``xkcd`` color." + "Sequential Color Brewer palettes\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "The Color Brewer library also has some good options for sequential palettes. 
]
},
{
@@ -543,23 +826,23 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.light_palette((210, 90, 60), input=\"husl\"))"
+ "sns.color_palette(\"Blues\", as_cmap=True)"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "raw",
"metadata": {},
- "outputs": [],
"source": [
- "sns.palplot(sns.dark_palette(\"muted purple\", input=\"xkcd\"))"
+ "Along with multi-hue options:"
]
},
{
- "cell_type": "raw",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
+ "outputs": [],
"source": [
- "Note that the default input space for the interactive palette widgets is ``husl``, which is different from the default for the function itself, but much more useful in this context."
+ "sns.color_palette(\"YlOrBr\", as_cmap=True)"
]
},
{
@@ -571,13 +854,12 @@
"Diverging color palettes\n",
"------------------------\n",
"\n",
- "The third class of color palettes is called \"diverging\". These are used for data where both large low and high values are interesting. There is also usually a well-defined midpoint in the data. For instance, if you are plotting changes in temperature from some baseline timepoint, it is best to use a diverging colormap to show areas with relative decreases and areas with relative increases.\n",
+ "The third class of color palettes is called \"diverging\". These are used for data where both large low and high values are interesting and span a midpoint value (often 0) that should be deemphasized. The rules for choosing good diverging palettes are similar to those for good sequential palettes, except now there should be two dominant hues in the colormap, one at (or near) each pole. It's also important that the starting values are of similar brightness and saturation.\n",
"\n",
- "The rules for choosing good diverging palettes are similar to good sequential palettes, except now you want to have two relatively subtle hue shifts from distinct starting hues that meet in an under-emphasized color at the midpoint. It's also important that the starting values are of similar brightness and saturation.\n",
- "\n",
- "It's also important to emphasize here that using red and green should be avoided, as a substantial population of potential viewers will be `unable to distinguish them `_.\n",
+ "Perceptually uniform diverging palettes\n",
+ "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"\n",
- "It should not surprise you that the Color Brewer library comes with a set of well-chosen diverging colormaps."
+ "Seaborn includes two perceptually uniform diverging palettes: ``\"vlag\"`` and ``\"icefire\"``. They both use blue and red at their poles, which many people intuitively process as \"cold\" and \"hot\":"
]
},
{
@@ -586,7 +868,7 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.color_palette(\"BrBG\", 7))"
+ "sns.color_palette(\"vlag\", as_cmap=True)"
]
},
{
@@ -595,14 +877,17 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.palplot(sns.color_palette(\"RdBu_r\", 7))"
+ "sns.color_palette(\"icefire\", as_cmap=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "Another good choice that is built into matplotlib is the ``coolwarm`` palette. Note that this colormap has less contrast between the middle values and the extremes."
+ "Custom diverging palettes\n",
+ "~~~~~~~~~~~~~~~~~~~~~~~~~\n",
+ "\n",
+ "You can also use the seaborn function :func:`diverging_palette` to create a custom colormap for diverging data. This function makes diverging palettes using the ``husl`` color system. 
You pass it two hues (in degrees) and, optionally, the lightness and saturation values for the extremes. Using ``husl`` means that the extreme values, and the resulting ramps to the midpoint, while not perfectly perceptually uniform, will be well-balanced:" ] }, { @@ -611,26 +896,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.color_palette(\"coolwarm\", 7))" + "sns.diverging_palette(220, 20, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Custom diverging palettes\n", - "~~~~~~~~~~~~~~~~~~~~~~~~~\n", - "\n", - "You can also use the seaborn function :func:`diverging_palette` to create a custom colormap for diverging data. (Naturally there is also a companion interactive widget, :func:`choose_diverging_palette`). This function makes diverging palettes using the ``husl`` color system. You pass it two hues (in degrees) and, optionally, the lightness and saturation values for the extremes. Using ``husl`` means that the extreme values, and the resulting ramps to the midpoint, will be well-balanced." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sns.palplot(sns.diverging_palette(220, 20, n=7))" + "This is convenient when you want to stray from the boring confines of cold-hot approaches:" ] }, { @@ -639,14 +912,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.diverging_palette(145, 280, s=85, l=25, n=7))" + "sns.diverging_palette(145, 300, s=60, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The ``sep`` argument controls the width of the separation between the two ramps in the middle region of the palette." + "It's also possible to make a palette where the midpoint is dark rather than light:" ] }, { @@ -655,14 +928,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.diverging_palette(10, 220, sep=80, n=7))" + "sns.diverging_palette(250, 30, l=65, center=\"dark\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "It's also possible to make a palette with the midpoint is dark rather than light." + "It's important to emphasize here that using red and green, while intuitive, `should be avoided `_.\n", + "\n", + "Other diverging palettes\n", + "~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "There are a few other good diverging palettes built into matplotlib, including Color Brewer palettes:" ] }, { @@ -671,19 +949,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.palplot(sns.diverging_palette(255, 133, l=60, n=7, center=\"dark\"))" + "sns.color_palette(\"Spectral\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - ".. _palette_contexts:\n", - "\n", - "Setting the default color palette\n", - "---------------------------------\n", - "\n", - "The :func:`color_palette` function has a companion called :func:`set_palette`. The relationship between them is similar to the pairs covered in the :ref:`aesthetics tutorial `. :func:`set_palette` accepts the same arguments as :func:`color_palette`, but it changes the default matplotlib parameters so that the palette is used for all plots." 
+ "And the ``coolwarm`` palette, which has less contrast between the middle values and the extremes:" ] }, { @@ -692,27 +965,25 @@ "metadata": {}, "outputs": [], "source": [ - "def sinplot(flip=1):\n", - " x = np.linspace(0, 14, 100)\n", - " for i in range(1, 7):\n", - " plt.plot(x, np.sin(x + i * .5) * (7 - i) * flip)" + "sns.color_palette(\"coolwarm\", as_cmap=True)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "raw", "metadata": {}, - "outputs": [], "source": [ - "sns.set_palette(\"husl\")\n", - "sinplot()" + "As you can see, there are many options for using color in your visualizations. Seaborn tries both to use good defaults and to offer a lot of flexibility.\n", + "\n", + "This discussion is only the beginning, and there are a number of good resources for learning more about techniques for using color in visualizations. One great example is this `series of blog posts `_ from the NASA Earth Observatory. The matplotlib docs also have a `nice tutorial `_ that illustrates some of the perceptual properties of their colormaps." ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The :func:`color_palette` function can also be used in a ``with`` statement to temporarily change the color palette." + ".. raw:: html\n", + "\n", + "
" ] }, { @@ -720,27 +991,15 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "with sns.color_palette(\"PuBuGn_d\"):\n", - " sinplot()" - ] - }, - { - "cell_type": "raw", - "metadata": {}, - "source": [ - ".. raw:: html\n", - "\n", - "
" - ] + "source": [] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -752,9 +1011,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/data_structure.ipynb b/doc/tutorial/data_structure.ipynb new file mode 100644 index 0000000000..69bd2b19e4 --- /dev/null +++ b/doc/tutorial/data_structure.ipynb @@ -0,0 +1,515 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + ".. _data_tutorial:\n", + "\n", + ".. currentmodule:: seaborn" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Data structures accepted by seaborn\n", + "===================================\n", + "\n", + ".. raw:: html\n", + "\n", + "
\n", + "\n", + "As a data visualization library, seaborn requires that you provide it with data. This chapter explains the various ways to accomplish that task. Seaborn supports several different dataset formats, and most functions accept data represented with objects from the `pandas `_ or `numpy `_ libraries as well as built-in Python types like lists and dictionaries. Understanding the usage patterns associated with these different options will help you quickly create useful visualizations for nearly any dataset.\n", + "\n", + ".. note::\n", + " As of current writing (v0.11.0), the full breadth of options covered here are supported by only a subset of the modules in seaborn (namely, the :ref:`relational ` and :ref:`distribution ` modules). The other modules offer much of the same flexibility, but have some exceptions (e.g., :func:`catplot` and :func:`lmplot` are limited to long-form data with named variables). The data-ingest code will be standardized over the next few release cycles, but until that point, be mindful of the specific documentation for each function if it is not doing what you expect with your dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "sns.set_theme()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Long-form vs. wide-form data\n", + "----------------------------\n", + "\n", + "Most plotting functions in seaborn are oriented towards *vectors* of data. When plotting ``x`` against ``y``, each variable should be a vector. Seaborn accepts data *sets* that have more than one vector organized in some tabular fashion. There is a fundamental distinction between \"long-form\" and \"wide-form\" data tables, and seaborn will treat each differently.\n", + "\n", + "Long-form data\n", + "~~~~~~~~~~~~~~\n", + "\n", + "A long-form data table has the following characteristics:\n", + "\n", + "- Each variable is a column\n", + "- Each observation is a row" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "As a simple example, consider the \"flights\" dataset, which records the number of airline passengers who flew in each month from 1949 to 1960. This dataset has three variables (*year*, *month*, and number of *passengers*):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights = sns.load_dataset(\"flights\")\n", + "flights.head()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "With long-form data, columns in the table are given roles in the plot by explicitly assigning them to one of the variables. For example, making a monthly plot of the number of passengers per year looks like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\", kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The advantage of long-form data is that it lends itself well to this explicit specification of the plot. It can accomodate datasets of arbitrary complexity, so long as the variables and observations can be clearly defined. 
+ "\n",
+ "Wide-form data\n",
+ "~~~~~~~~~~~~~~\n",
+ "\n",
+ "For simple datasets, it is often more intuitive to think about data the way it might be viewed in a spreadsheet, where the columns and rows contain *levels* of different variables. For example, we can convert the flights dataset into a wide-form organization by \"pivoting\" it so that each column has each month's time series over years:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flights_wide = flights.pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n",
+ "flights_wide.head()"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Here we have the same three variables, but they are organized differently. The variables in this dataset are linked to the *dimensions* of the table, rather than to named fields. Each observation is defined by both the value at a cell in the table and the coordinates of that cell with respect to the row and column indices."
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "With long-form data, we can access variables in the dataset by their name. That is not the case with wide-form data. Nevertheless, because there is a clear association between the dimensions of the table and the variables in the dataset, seaborn is able to assign those variables roles in the plot.\n",
+ "\n",
+ ".. note::\n",
+ "   Seaborn treats the argument to ``data`` as wide form when neither ``x`` nor ``y`` is assigned."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.relplot(data=flights_wide, kind=\"line\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "This plot looks very similar to the one before. Seaborn has assigned the index of the dataframe to ``x``, the values of the dataframe to ``y``, and it has drawn a separate line for each month. There is a notable difference between the two plots, however. When the dataset went through the \"pivot\" operation that converted it from long-form to wide-form, the information about what the values mean was lost. As a result, there is no y axis label. (The lines also have dashes here, because :func:`relplot` has mapped the column variable to both the ``hue`` and ``style`` semantic so that the plot is more accessible. We didn't do that in the long-form case, but we could have by setting ``style=\"month\"``).\n",
+ "\n",
+ "Thus far, we did much less typing while using wide-form data and made nearly the same plot. This seems easier! But a big advantage of long-form data is that, once you have the data in the correct format, you no longer need to think about its *structure*. You can design your plots by thinking only about the variables contained within it. For example, to draw lines that represent the monthly time series for each year, simply reassign the variables:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.relplot(data=flights, x=\"month\", y=\"passengers\", hue=\"year\", kind=\"line\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "To achieve the same remapping with the wide-form dataset, we would need to transpose the table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.relplot(data=flights_wide.transpose(), kind=\"line\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "(This example also illustrates another wrinkle, which is that seaborn currently considers the column variable in a wide-form dataset to be categorical regardless of its datatype, whereas, because the long-form variable is numeric, it is assigned a quantitative color palette and legend. This may change in the future).\n",
+ "\n",
+ "The absence of explicit variable assignments also means that each plot type needs to define a fixed mapping between the dimensions of the wide-form data and the roles in the plot. Because this natural mapping may vary across plot types, the results are less predictable when using wide-form data. For example, the :ref:`categorical ` plots assign the *column* dimension of the table to ``x`` and then aggregate across the rows (ignoring the index):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.catplot(data=flights_wide, kind=\"box\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "When using pandas to represent wide-form data, you are limited to just a few variables (no more than three). This is because seaborn does not make use of multi-index information, which is how pandas represents additional variables in a tabular format. The `xarray `_ project offers labeled N-dimensional array objects, which can be considered a generalization of wide-form data to higher dimensions. 
At present, seaborn does not directly support objects from ``xarray``, but they can be transformed into a long-form :class:`pandas.DataFrame` using the ``to_pandas`` method and then plotted in seaborn like any other long-form data set.\n", + "\n", + "In summary, we can think of long-form and wide-form datasets as looking something like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "f = plt.figure(figsize=(7, 5))\n", + "\n", + "gs = plt.GridSpec(\n", + " ncols=6, nrows=2, figure=f,\n", + " left=0, right=.35, bottom=0, top=.9,\n", + " height_ratios=(1, 20),\n", + " wspace=.1, hspace=.01\n", + ")\n", + "\n", + "colors = [c + (.5,) for c in sns.color_palette()]\n", + "\n", + "f.add_subplot(gs[0, :], facecolor=\".8\")\n", + "[\n", + " f.add_subplot(gs[1:, i], facecolor=colors[i])\n", + " for i in range(gs.ncols)\n", + "]\n", + "\n", + "gs = plt.GridSpec(\n", + " ncols=2, nrows=2, figure=f,\n", + " left=.4, right=1, bottom=.2, top=.8,\n", + " height_ratios=(1, 8), width_ratios=(1, 11),\n", + " wspace=.015, hspace=.02\n", + ")\n", + "\n", + "f.add_subplot(gs[0, 1:], facecolor=colors[2])\n", + "f.add_subplot(gs[1:, 0], facecolor=colors[1])\n", + "f.add_subplot(gs[1, 1], facecolor=colors[0])\n", + "\n", + "for ax in f.axes:\n", + " ax.set(xticks=[], yticks=[])\n", + "\n", + "f.text(.35 / 2, .91, \"Long-form\", ha=\"center\", va=\"bottom\", size=15)\n", + "f.text(.7, .81, \"Wide-form\", ha=\"center\", va=\"bottom\", size=15)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Messy data\n", + "~~~~~~~~~~\n", + "\n", + "Many datasets cannot be clearly interpreted using either long-form or wide-form rules. If datasets that are clearly long-form or wide-form are `\"tidy\" `_, we might say that these more ambiguous datasets are \"messy\". In a messy dataset, the variables are neither uniquely defined by the keys nor by the dimensions of the table. This often occurs with *repeated-measures* data, where it is natural to organize a table such that each row corresponds to the *unit* of data collection. Consider this simple dataset from a psychology experiment in which twenty subjects performed a memory task where they studied anagrams while their attention was either divided or focused:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "anagrams = sns.load_dataset(\"anagrams\")\n", + "anagrams" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The attention variable is *between-subjects*, but there is also a *within-subjects* variable: the number of possible solutions to the anagrams, which varied from 1 to 3. The dependent measure is a score of memory performance. These two variables (number and score) are jointly encoded across several columns. As a result, the whole dataset is neither clearly long-form nor clearly wide-form.\n", + "\n", + "How might we tell seaborn to plot the average score as a function of attention and number of solutions? We'd first need to coerce the data into one of our two structures. Let's transform it to a tidy long-form table, such that each variable is a column and each row is an observation. 
We can use the method :meth:`pandas.DataFrame.melt` to accomplish this task:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "anagrams_long = anagrams.melt(id_vars=[\"subidr\", \"attnr\"], var_name=\"solutions\", value_name=\"score\")\n",
+ "anagrams_long.head()"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Now we can make the plot that we want:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sns.catplot(data=anagrams_long, x=\"solutions\", y=\"score\", hue=\"attnr\", kind=\"point\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Further reading and take-home points\n",
+ "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
+ "\n",
+ "For a longer discussion about tabular data structures, you could read the `\"Tidy Data\" `_ paper by Hadley Wickham. Note that seaborn uses a slightly different set of concepts than are defined in the paper. While the paper associates tidiness with long-form structure, we have drawn a distinction between \"tidy wide-form\" data, where there is a clear mapping between variables in the dataset and the dimensions of the table, and \"messy data\", where no such mapping exists.\n",
+ "\n",
+ "The long-form structure has clear advantages. It allows you to create figures by explicitly assigning variables in the dataset to roles in the plot, and you can do so with more than three variables. When possible, try to represent your data with a long-form structure when embarking on serious analysis. Most of the examples in the seaborn documentation will use long-form data. But in cases where it is more natural to keep the dataset wide, remember that seaborn can remain useful."
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Options for visualizing long-form data\n",
+ "--------------------------------------\n",
+ "\n",
+ "While long-form data has a precise definition, seaborn is fairly flexible in terms of how it is actually organized across the data structures in memory. The examples in the rest of the documentation will typically use :class:`pandas.DataFrame` objects and reference variables in them by assigning names of their columns to the variables in the plot. But it is also possible to store vectors in a Python dictionary or a class that implements that interface:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flights_dict = flights.to_dict()\n",
+ "sns.relplot(data=flights_dict, x=\"year\", y=\"passengers\", hue=\"month\", kind=\"line\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Many pandas operations, such as the split-apply-combine operations of a group-by, will produce a dataframe where information has moved from the columns of the input dataframe to the index of the output. So long as the name is retained, you can still reference the data as normal:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flights_avg = flights.groupby(\"year\").mean()\n",
+ "sns.relplot(data=flights_avg, x=\"year\", y=\"passengers\", kind=\"line\")"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "Additionally, it's possible to pass vectors of data directly as arguments to ``x``, ``y``, and other plotting variables. 
If these vectors are pandas objects, the ``name`` attribute will be used to label the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "year = flights_avg.index\n", + "passengers = flights_avg[\"passengers\"]\n", + "sns.relplot(x=year, y=passengers, kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Numpy arrays and other objects that implement the Python sequence interface work too, but if they don't have names, the plot will not be as informative without further tweaking:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(x=year.to_numpy(), y=passengers.to_list(), kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Options for visualizing wide-form data\n", + "--------------------------------------\n", + "\n", + "The options for passing wide-form data are even more flexible. As with long-form data, pandas objects are preferable because the name (and, in some cases, index) information can be used. But in essence, any format that can be viewed as a single vector or a collection of vectors can be passed to ``data``, and a valid plot can usually be constructed.\n", + "\n", + "The example we saw above used a rectangular :class:`pandas.DataFrame`, which can be thought of as a collection of its columns. A dict or list of pandas objects will also work, but we'll lose the axis labels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights_wide_list = [col for _, col in flights_wide.items()]\n", + "sns.relplot(data=flights_wide_list, kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The vectors in a collection do not need to have the same length. If they have an ``index``, it will be used to align them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "two_series = [flights_wide.loc[:1955, \"Jan\"], flights_wide.loc[1952:, \"Aug\"]]\n", + "sns.relplot(data=two_series, kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Whereas an ordinal index will be used for numpy arrays or simple Python sequences:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "two_arrays = [s.to_numpy() for s in two_series]\n", + "sns.relplot(data=two_arrays, kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "But a dictionary of such vectors will at least use the keys:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "two_arrays_dict = {s.name: s.to_numpy() for s in two_series}\n", + "sns.relplot(data=two_arrays_dict, kind=\"line\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Rectangular numpy arrays are treated just like a dataframe without index information, so they are viewed as a collection of column vectors. Note that this is different from how numpy indexing operations work, where a single indexer will access a row. 
But it is consistent with how pandas would turn the array into a dataframe or how matplotlib would plot it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flights_array = flights_wide.to_numpy()\n", + "sns.relplot(data=flights_array, kind=\"line\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "# TODO once the categorical module is refactored, its single vectors will get special treatment\n", + "# (they'll look like collection of singletons, rather than a single collection). That should be noted." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-refactor (py38)", + "language": "python", + "name": "seaborn-refactor" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tutorial/distributions.ipynb b/doc/tutorial/distributions.ipynb index b69f0f398e..59452c5a13 100644 --- a/doc/tutorial/distributions.ipynb +++ b/doc/tutorial/distributions.ipynb @@ -13,19 +13,44 @@ "cell_type": "raw", "metadata": {}, "source": [ - "Visualizing the distribution of a dataset\n", - "=========================================\n", + "Visualizing distributions of data\n", + "==================================\n", "\n", ".. raw:: html\n", "\n", - "
\n" + "
\n", + "\n", + "An early step in any effort to analyze or model data should be to understand how the variables are distributed. Techniques for distribution visualization can provide quick answers to many important questions. What range do the observations cover? What is their central tendency? Are they heavily skewed in one direction? Is there evidence for bimodality? Are there significant outliers? Do the answers to these questions vary across subsets defined by other variables?\n", + "\n", + "The :ref:`distributions module ` contains several functions designed to answer questions such as these. The axes-level functions are :func:`histplot`, :func:`kdeplot`, :func:`ecdfplot`, and :func:`rugplot`. They are grouped together within the figure-level :func:`displot`, :func:`jointplot`, and :func:`pairplot` functions.\n", + "\n", + "There are several different approaches to visualizing a distribution, and each has its relative advantages and drawbacks. It is important to understand theses factors so that you can choose the best approach for your particular aim." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import seaborn as sns; sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "When dealing with a set of data, often the first thing you'll want to do is get a sense for how the variables are distributed. This chapter of the tutorial will give a brief introduction to some of the tools in seaborn for examining univariate and bivariate distributions. You may also want to look at the :ref:`categorical plots ` chapter for examples of functions that make it easy to compare the distribution of a variable across levels of other variables." + ".. _tutorial_hist:\n", + "\n", + "Plotting univariate histograms\n", + "------------------------------\n", + "\n", + "Perhaps the most common approach to visualizing a distribution is the *histogram*. This is the default approach in :func:`displot`, which uses the same underlying code as :func:`histplot`. A histogram is a bar plot where the axis representing the data variable is divided into a set of discrete bins and the count of observations falling within each bin is shown using the height of the corresponding bar:" ] }, { @@ -34,11 +59,20 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "from scipy import stats" + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.displot(penguins, x=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "This plot immediately affords a few insights about the ``flipper_length_mm`` variable. For instance, we can see that the most common flipper length is about 195 mm, but the distribution appears bimodal, so this one number does not represent the data well.\n", + "\n", + "Choosing the bin size\n", + "^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "The size of the bins is an important parameter, and using the wrong bin size can mislead by obscuring important features of the data or by creating apparent features out of random variability. By default, :func:`displot`/:func:`histplot` choose a default bin size based on the variance of the data and the number of observations. But you should not be over-reliant on such automatic approaches, because they depend on particular assumptions about the structure of your data. 
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "This plot immediately affords a few insights about the ``flipper_length_mm`` variable. For instance, we can see that the most common flipper length is about 195 mm, but the distribution appears bimodal, so this one number does not represent the data well.\n",
+ "\n",
+ "Choosing the bin size\n",
+ "^^^^^^^^^^^^^^^^^^^^^\n",
+ "\n",
+ "The size of the bins is an important parameter, and using the wrong bin size can mislead by obscuring important features of the data or by creating apparent features out of random variability. By default, :func:`displot`/:func:`histplot` choose a default bin size based on the variance of the data and the number of observations. But you should not be over-reliant on such automatic approaches, because they depend on particular assumptions about the structure of your data. It is always advisable to check that your impressions of the distribution are consistent across different bin sizes. To choose the size directly, set the ``binwidth`` parameter:"
]
},
{
@@ -47,31 +81,162 @@
"metadata": {},
"outputs": [],
"source": [
- "sns.set(color_codes=True)"
+ "sns.displot(penguins, x=\"flipper_length_mm\", binwidth=3)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "In other circumstances, it may make more sense to specify the *number* of bins, rather than their size:"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "tags": [
- "hide"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "%matplotlib inline\n",
- "np.random.seed(sum(map(ord, \"distributions\")))"
+ "sns.displot(penguins, x=\"flipper_length_mm\", bins=20)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "One example of a situation where defaults fail is when the variable takes a relatively small number of integer values. In that case, the default bin width may be too small, creating awkward gaps in the distribution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "tips = sns.load_dataset(\"tips\")\n",
+ "sns.displot(tips, x=\"size\")"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "One approach would be to specify the precise bin breaks by passing an array to ``bins``:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "sns.displot(tips, x=\"size\", bins=[1, 2, 3, 4, 5, 6, 7])"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "This can also be accomplished by setting ``discrete=True``, which chooses bin breaks that represent the unique values in a dataset with bars that are centered on their corresponding value:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "sns.displot(tips, x=\"size\", discrete=True)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
+ "It's also possible to visualize the distribution of a categorical variable using the logic of a histogram. Discrete bins are automatically set for categorical variables, but it may also be helpful to \"shrink\" the bars slightly to emphasize the categorical nature of the axis:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "sns.displot(tips, x=\"day\", shrink=.8)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
- "Plotting univariate distributions\n",
- "---------------------------------\n",
"\n",
- "The most convenient way to take a quick look at a univariate distribution in seaborn is the :func:`distplot` function. By default, this will draw a `histogram `_ and fit a `kernel density estimate `_ (KDE). "
+ "Conditioning on other variables\n",
+ "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
+ "\n",
+ "Once you understand the distribution of a variable, the next step is often to ask whether features of that distribution differ across other variables in the dataset. For example, what accounts for the bimodal distribution of flipper lengths that we saw above? :func:`displot` and :func:`histplot` provide support for conditional subsetting via the ``hue`` semantic. Assigning a variable to ``hue`` will draw a separate histogram for each of its unique values and distinguish them by color:"
]
},
Assigning a variable to ``hue`` will draw a separate histogram for each of its unique values and distinguish them by color:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, the different histograms are \"layered\" on top of each other and, in some cases, they may be difficult to distinguish. One option is to change the visual representation of the histogram from a bar plot to a \"step\" plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"step\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Alternatively, instead of layering each bar, they can be \"stacked\", or moved vertically. In this plot, the outline of the full histogram will match the plot with only a single variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The stacked histogram emphasizes the part-whole relationship between the variables, but it can obscure other features (for example, it is difficult to determine the mode of the Adelie distribution). Another option is to \"dodge\" the bars, which moves them horizontally and reduces their width. This ensures that there are no overlaps and that the bars remain comparable in terms of height. But it only works well when the categorical variable has a small number of levels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"sex\", multiple=\"dodge\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Because :func:`displot` is a figure-level function and is drawn onto a :class:`FacetGrid`, it is also possible to draw each individual distribution in a separate subplot by assigning the second variable to ``col`` or ``row`` rather than (or in addition to) ``hue``. This represents the distribution of each subset well, but it makes it more difficult to draw direct comparisons:" ] }, { @@ -80,20 +245,35 @@ "metadata": {}, "outputs": [], "source": [ - "x = np.random.normal(size=100)\n", - "sns.distplot(x);" + "sns.displot(penguins, x=\"flipper_length_mm\", col=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Histograms\n", - "^^^^^^^^^^\n", + "None of these approaches are perfect, and we will soon see some alternatives to a histogram that are better-suited to the task of comparison.\n", "\n", - "Histograms are likely familiar, and a ``hist`` function already exists in matplotlib. A histogram represents the distribution of data by forming bins along the range of the data and then drawing bars to show the number of observations that fall in each bin.\n", + "Normalized histogram statistics\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", - "To illustrate this, let's remove the density curve and add a rug plot, which draws a small vertical tick at each observation. 
You can make the rug plot itself with the :func:`rugplot` function, but it is also available in :func:`distplot`:" + "Before we do, another point to note is that, when the subsets have unequal numbers of observations, comparing their distributions in terms of counts may not be ideal. One solution is to *normalize* the counts using the ``stat`` parameter:" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"density\")" ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "By default, however, the normalization is applied to the entire distribution, so this simply rescales the height of the bars. By setting ``common_norm=False``, each subset will be normalized independently:" ] }, { @@ -102,14 +282,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.distplot(x, kde=False, rug=True);" + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"density\", common_norm=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "When drawing histograms, the main choice you have is the number of bins to use and where to place them. :func:`distplot` uses a simple rule to make a good guess for what the right number is by default, but trying more or fewer bins might reveal other features in the data:" + "Density normalization scales the bars so that their *areas* sum to 1. As a result, the density axis is not directly interpretable. Another option is to normalize the bars so that their *heights* sum to 1. This makes most sense when the variable is discrete, but it is an option for all histograms:" ] }, { @@ -118,17 +298,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.distplot(x, bins=20, kde=False, rug=True);" + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"probability\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ + ".. _tutorial_kde:\n", + "\n", "Kernel density estimation\n", - "^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "-------------------------\n", "\n", - "The kernel density estimate may be less familiar, but it can be a useful tool for plotting the shape of a distribution. Like the histogram, the KDE plots encode the density of observations on one axis with height along the other axis:" + "A histogram aims to approximate the underlying probability density function that generated the data by binning and counting observations. Kernel density estimation (KDE) presents a different solution to the same problem. Rather than using discrete bins, a KDE plot smooths the observations with a Gaussian kernel, producing a continuous density estimate:" ] }, { @@ -137,14 +319,17 @@ "metadata": {}, "outputs": [], "source": [ - "sns.distplot(x, hist=False, rug=True);" + "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Drawing a KDE is more computationally involved than drawing a histogram. What happens is that each observation is first replaced with a normal (Gaussian) curve centered at that value:" + "Choosing the smoothing bandwidth\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "Much like with the bin size in the histogram, the ability of the KDE to accurately represent the data depends on the choice of smoothing bandwidth. An over-smoothed estimate might erase meaningful features, but an under-smoothed estimate can obscure the true shape within random noise. 
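Under the hood, seaborn currently computes the density with :class:`scipy.stats.gaussian_kde`, which sizes the kernel with Scott's reference rule; ``bw_adjust`` simply scales that default factor. As a rough sketch of what the rule implies (an aside that assumes scipy is available):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import gaussian_kde\n", + "# Scott's rule for univariate data: n ** (-1 / 5), scaled by the data's spread\n", + "gaussian_kde(penguins[\"flipper_length_mm\"].dropna()).factor" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "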
The easiest way to check the robustness of the estimate is to adjust the default bandwidth:" ] }, { @@ -153,25 +338,100 @@ "metadata": {}, "outputs": [], "source": [ - "x = np.random.normal(0, 1, size=30)\n", - "bandwidth = 1.06 * x.std() * x.size ** (-1 / 5.)\n", - "support = np.linspace(-4, 4, 200)\n", - "\n", - "kernels = []\n", - "for x_i in x:\n", + "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\", bw_adjust=.25)" ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Note how the narrow bandwidth makes the bimodality much more apparent, but the curve is much less smooth. In contrast, a larger bandwidth obscures the bimodality almost completely:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\", bw_adjust=2)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Conditioning on other variables\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", - " kernel = stats.norm(x_i, bandwidth).pdf(support)\n", - " kernels.append(kernel)\n", - " plt.plot(support, kernel, color=\"r\")\n", + "As with histograms, if you assign a ``hue`` variable, a separate density estimate will be computed for each level of that variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "In many cases, the layered KDE is easier to interpret than the layered histogram, so it is often a good choice for the task of comparison. Many of the same options for resolving multiple distributions apply to the KDE as well, however:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Note how the stacked plot filled in the area between each curve by default. It is also possible to fill in the curves for single or layered densities, although the default alpha value (opacity) will be different, so that the individual densities are easier to resolve." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\", fill=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Kernel density estimation pitfalls\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", - "sns.rugplot(x, color=\".2\", linewidth=3);" + "KDE plots have many advantages. Important features of the data are easy to discern (central tendency, bimodality, skew), and they afford easy comparisons between subsets. But there are also situations where KDE poorly represents the underlying data. This is because the logic of KDE assumes that the underlying distribution is smooth and unbounded. One way this assumption can fail is when a variable reflects a quantity that is naturally bounded. 
If there are observations lying close to the bound (for example, small values of a variable that cannot be negative), the KDE curve may extend to unrealistic values:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(tips, x=\"total_bill\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Next, these curves are summed to compute the value of the density at each point in the support grid. The resulting curve is then normalized so that the area under it is equal to 1:" + "This can be partially avoided with the ``cut`` parameter, which specifies how far the curve should extend beyond the extreme datapoints. But this influences only where the curve is drawn; the density estimate will still smooth over the range where no data can exist, causing it to be artificially low at the extremes of the distribution:" ] }, { @@ -180,17 +440,14 @@ "metadata": {}, "outputs": [], "source": [ - "from scipy.integrate import trapz\n", - "density = np.sum(kernels, axis=0)\n", - "density /= trapz(density, support)\n", - "plt.plot(support, density);" + "sns.displot(tips, x=\"total_bill\", kind=\"kde\", cut=0)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "We can see that if we use the :func:`kdeplot` function in seaborn, we get the same curve. This function is used by :func:`distplot`, but it provides a more direct interface with easier access to other options when you just want the density estimate:" + "The KDE approach also fails for discrete data or when data are naturally continuous but specific values are over-represented. The important thing to keep in mind is that the KDE will *always show you a smooth curve*, even when the data themselves are not smooth. For example, consider this distribution of diamond weights:" ] }, { @@ -199,14 +456,15 @@ "metadata": {}, "outputs": [], "source": [ - "sns.kdeplot(x, shade=True);" + "diamonds = sns.load_dataset(\"diamonds\")\n", + "sns.displot(diamonds, x=\"carat\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The bandwidth (``bw``) parameter of the KDE controls how tightly the estimation is fit to the data, much like the bin size in a histogram. It corresponds to the width of the kernels we plotted above. The default behavior tries to guess a good value using a common reference rule, but it may be helpful to try larger or smaller values:" + "While the KDE suggests that there are peaks around specific values, the histogram reveals a much more jagged distribution:" ] }, { @@ -215,17 +473,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.kdeplot(x)\n", - "sns.kdeplot(x, bw=.2, label=\"bw: 0.2\")\n", - "sns.kdeplot(x, bw=2, label=\"bw: 2\")\n", - "plt.legend();" + "sns.displot(diamonds, x=\"carat\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "As you can see above, the nature of the Gaussian KDE process means that estimation extends past the largest and smallest values in the dataset. It's possible to control how far past the extreme values the curve is drawn with the ``cut`` parameter; however, this only influences how the curve is drawn and not how it is fit:" + "As a compromise, it is possible to combine these two approaches. 
While in histogram mode, :func:`displot` (as with :func:`histplot`) has the option of including the smoothed KDE curve (note ``kde=True``, not ``kind=\"kde\"``):" ] }, { @@ -234,18 +489,19 @@ "metadata": {}, "outputs": [], "source": [ - "sns.kdeplot(x, shade=True, cut=0)\n", - "sns.rugplot(x);" + "sns.displot(diamonds, x=\"carat\", kde=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Fitting parametric distributions\n", - "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + ".. _tutorial_ecdf:\n", + "\n", + "Empirical cumulative distributions\n", + "----------------------------------\n", "\n", - "You can also use :func:`distplot` to fit a parametric distribution to a dataset and visually evaluate how closely it corresponds to the observed data:" + "A third option for visualizing distributions computes the \"empirical cumulative distribution function\" (ECDF). This plot draws a monotonically-increasing curve through each datapoint such that the height of the curve reflects the proportion of observations with a smaller value:" ] }, { @@ -254,18 +510,14 @@ "metadata": {}, "outputs": [], "source": [ - "x = np.random.gamma(6, size=200)\n", - "sns.distplot(x, kde=False, fit=stats.gamma);" + "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"ecdf\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Plotting bivariate distributions\n", - "--------------------------------\n", - "\n", - "It can also be useful to visualize a bivariate distribution of two variables. The easiest way to do this in seaborn is to just use the :func:`jointplot` function, which creates a multi-panel figure that shows both the bivariate (or joint) relationship between two variables along with the univariate (or marginal) distribution of each on separate axes." + "The ECDF plot has two key advantages. Unlike the histogram or KDE, it directly represents each datapoint. That means there is no bin size or smoothing parameter to consider. Additionally, because the curve is monotonically increasing, it is well-suited for comparing multiple distributions:" ] }, { @@ -274,19 +526,24 @@ "metadata": {}, "outputs": [], "source": [ - "mean, cov = [0, 1], [(1, .5), (.5, 1)]\n", - "data = np.random.multivariate_normal(mean, cov, 200)\n", - "df = pd.DataFrame(data, columns=[\"x\", \"y\"])" + "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"ecdf\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Scatterplots\n", - "^^^^^^^^^^^^\n", + "The major downside to the ECDF plot is that it represents the shape of the distribution less intuitively than a histogram or density curve. Consider how the bimodality of flipper lengths is immediately apparent in the histogram, but to see it in the ECDF plot, you must look for varying slopes. Nevertheless, with practice, you can learn to answer all of the important questions about a distribution by examining the ECDF, and doing so can be a powerful approach." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Visualizing bivariate distributions\n", + "-----------------------------------\n", "\n", - "The most familiar way to visualize a bivariate distribution is a scatterplot, where each observation is shown with point at the *x* and *y* values. This is analogous to a rug plot on two dimensions. 
You can draw a scatterplot with the matplotlib ``plt.scatter`` function, and it is also the default kind of plot shown by the :func:`jointplot` function:" + "All of the examples so far have considered *univariate* distributions: distributions of a single variable, perhaps conditional on a second variable assigned to ``hue``. Assigning a second variable to ``y``, however, will plot a *bivariate* distribution:" ] }, { @@ -295,17 +552,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.jointplot(x=\"x\", y=\"y\", data=df);" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Hexbin plots\n", - "^^^^^^^^^^^^\n", - "\n", - "The bivariate analogue of a histogram is known as a \"hexbin\" plot, because it shows the counts of observations that fall within hexagonal bins. This plot works best with relatively large datasets. It's available through the matplotlib ``plt.hexbin`` function and as a style in :func:`jointplot`. It looks best with a white background:" + "A bivariate histogram bins the data within rectangles that tile the plot and then shows the count of observations within each rectangle with the fill color (analogous to a :func:`heatmap`). Similarly, a bivariate KDE plot smooths the (x, y) observations with a 2D Gaussian. The default representation then shows the *contours* of the 2D density:" ] }, { @@ -314,19 +568,30 @@ "metadata": {}, "outputs": [], "source": [ - "x, y = np.random.multivariate_normal(mean, cov, 1000).T\n", - "with sns.axes_style(\"white\"):\n", - " sns.jointplot(x=x, y=y, kind=\"hex\", color=\"k\");" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Kernel density estimation\n", - "^^^^^^^^^^^^^^^^^^^^^^^^^\n", - "\n", - "It is also possible to use the kernel density estimation procedure described above to visualize a bivariate distribution. In seaborn, this kind of plot is shown with a contour plot and is available as a style in :func:`jointplot`:" + "Assigning a ``hue`` variable will plot multiple heatmaps or contour sets using different colors. For bivariate histograms, this will only work well if there is minimal overlap between the conditional distributions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The contour approach of the bivariate KDE plot lends itself better to evaluating overlap, although a plot with too many contours can get busy:" ] }, { @@ -335,14 +600,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.jointplot(x=\"x\", y=\"y\", data=df, kind=\"kde\");" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "You can also draw a two-dimensional kernel density plot with the :func:`kdeplot` function. This allows you to draw this kind of plot onto a specific (and possibly already existing) matplotlib axes, whereas the :func:`jointplot` function manages its own figure:" + "Just as with univariate plots, the choice of bin size or smoothing bandwidth will determine how well the plot represents the underlying bivariate distribution. 
The same parameters apply, but they can be tuned for each variable by passing a pair of values:" ] }, { @@ -351,17 +616,14 @@ "metadata": {}, "outputs": [], "source": [ - "f, ax = plt.subplots(figsize=(6, 6))\n", - "sns.kdeplot(df.x, df.y, ax=ax)\n", - "sns.rugplot(df.x, color=\"g\", ax=ax)\n", - "sns.rugplot(df.y, vertical=True, ax=ax);" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", binwidth=(2, .5))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "If you wish to show the bivariate density more continuously, you can simply increase the number of contour levels:" + "To aid interpretation of the heatmap, add a colorbar to show the mapping between counts and color intensity:" ] }, { @@ -370,16 +632,14 @@ "metadata": {}, "outputs": [], "source": [ - "f, ax = plt.subplots(figsize=(6, 6))\n", - "cmap = sns.cubehelix_palette(as_cmap=True, dark=0, light=1, reverse=True)\n", - "sns.kdeplot(df.x, df.y, cmap=cmap, n_levels=60, shade=True);" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", binwidth=(2, .5), cbar=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "The :func:`jointplot` function uses a :class:`JointGrid` to manage the figure. For more flexibility, you may want to draw your figure by using :class:`JointGrid` directly. :func:`jointplot` returns the :class:`JointGrid` object after plotting, which you can use to add more layers or to tweak other aspects of the visualization:" + "The meaning of the bivariate density contours is less straightforward. Because the density is not directly interpretable, the contours are drawn at *iso-proportions* of the density, meaning that each curve shows a level set such that some proportion *p* of the density lies below it. The *p* values are evenly spaced, with the lowest level controlled by the ``thresh`` parameter and the number controlled by ``levels``:" ] }, { @@ -388,20 +648,71 @@ "metadata": {}, "outputs": [], "source": [ - "g = sns.jointplot(x=\"x\", y=\"y\", data=df, kind=\"kde\", color=\"m\")\n", - "g.plot_joint(plt.scatter, c=\"w\", s=30, linewidth=1, marker=\"+\")\n", - "g.ax_joint.collections[0].set_alpha(0)\n", - "g.set_axis_labels(\"$X$\", \"$Y$\");" + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\", thresh=.2, levels=4)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Visualizing pairwise relationships in a dataset\n", - "-----------------------------------------------\n", + "The ``levels`` parameter also accepts a list of values, for more control:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\", levels=[.01, .05, .1, .8])" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The bivariate histogram allows one or both variables to be discrete. 
Plotting one discrete and one continuous variable offers another way to compare conditional univariate distributions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(diamonds, x=\"price\", y=\"clarity\", log_scale=(True, False))" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "In contrast, plotting two discrete variables is an easy way to show the cross-tabulation of the observations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(diamonds, x=\"color\", y=\"clarity\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Distribution visualization in other settings\n", + "--------------------------------------------\n", + "\n", + "Several other figure-level plotting functions in seaborn make use of the :func:`histplot` and :func:`kdeplot` functions.\n", "\n", - "To plot multiple pairwise bivariate distributions in a dataset, you can use the :func:`pairplot` function. This creates a matrix of axes and shows the relationship for each pair of columns in a DataFrame. By default, it also draws the univariate distribution of each variable on the diagonal Axes:" + "\n", + "Plotting joint and marginal distributions\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "The first is :func:`jointplot`, which augments a bivariate relational or distribution plot with the marginal distributions of the two variables. By default, :func:`jointplot` represents the bivariate distribution using :func:`scatterplot` and the marginal distributions using :func:`histplot`:" ] }, { @@ -410,15 +721,91 @@ "metadata": {}, "outputs": [], "source": [ - "iris = sns.load_dataset(\"iris\")\n", - "sns.pairplot(iris);" + "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Specifying the ``hue`` parameter automatically changes the histograms to KDE plots to facilitate comparisons between multiple distributions." + "Similar to :func:`displot`, setting ``kind=\"kde\"`` in :func:`jointplot` will change both the joint and marginal plots to use :func:`kdeplot`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(\n", + " data=penguins,\n", + " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\",\n", + " kind=\"kde\"\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + ":func:`jointplot` is a convenient interface to the :class:`JointGrid` class, which offers more flexibility when used directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "g.plot_joint(sns.histplot)\n", + "g.plot_marginals(sns.boxplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "A less-obtrusive way to show marginal distributions uses a \"rug\" plot, which adds a small tick on the edge of the plot to represent each individual observation. 
This is built into :func:`displot`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(\n", + " penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", + " kind=\"kde\", rug=True\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "And the axes-level :func:`rugplot` function can be used to add rugs on the side of any other kind of plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.relplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", + "sns.rugplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Plotting many distributions\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "The :func:`pairplot` function offers a similar blend of joint and marginal distributions. Rather than focusing on a single relationship, however, :func:`pairplot` uses a \"small-multiple\" approach to visualize the univariate distribution of all variables in a dataset along with all of their pairwise relationships:" ] }, { @@ -427,14 +814,14 @@ "metadata": {}, "outputs": [], "source": [ - "sns.pairplot(iris, hue=\"species\");" + "sns.pairplot(penguins)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ - "Much like the relationship between :func:`jointplot` and :class:`JointGrid`, the :func:`pairplot` function is built on top of a :class:`PairGrid` object, which can be used directly for more flexibility:" + "As with :func:`jointplot`/:class:`JointGrid`, using the underlying :class:`PairGrid` directly will afford more flexibility with only a bit more typing:" ] }, { @@ -443,9 +830,10 @@ "metadata": {}, "outputs": [], "source": [ - "g = sns.PairGrid(iris)\n", - "g.map_diag(sns.kdeplot)\n", - "g.map_offdiag(sns.kdeplot, n_levels=6);" + "g = sns.PairGrid(penguins)\n", + "g.map_upper(sns.histplot)\n", + "g.map_lower(sns.kdeplot, fill=True)\n", + "g.map_diag(sns.histplot, kde=True)" ] }, { @@ -461,9 +849,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -475,9 +863,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/error_bars.ipynb b/doc/tutorial/error_bars.ipynb new file mode 100644 index 0000000000..c8db4600f4 --- /dev/null +++ b/doc/tutorial/error_bars.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + ".. _errorbar_tutorial:\n", + "\n", + ".. 
currentmodule:: seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme(style=\"darkgrid\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "np.random.seed(sum(map(ord, \"relational\")))" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Statistical estimation and error bars\n", + "=====================================\n", + "\n", + ".. raw:: html\n", + "\n", + "
\n", + "\n", + "Data visualization sometimes involves a step of aggregation or estimation, where multiple data points are reduced to a summary statistic such as the mean or median. When showing a summary statistic, it is usually appropriate to add *error bars*, which provide a visual cue about how well the summary represents the underlying data points.\n", + "\n", + "Several seaborn functions will automatically calculate both summary statistics and the error bars when \n", + "given a full dataset. This chapter explains how you can control what the error bars show and why you might choose each of the options that seaborn affords." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO example plot pointing out what the error bar is" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The error bars around an estimate of central tendency can show one of two things: either the range of uncertainty about the estimate or the spread of the underlying data around it. These measures are related: given the same sample size, estimates will be more uncertain when data has a broader spread. But uncertainty will decrease as sample sizes grow, whereas spread will not.\n", + "\n", + "In seaborn, there are two approaches for constructing each kind of error bar. One approach is parametric, using a formula that relies on assumptions about the shape of the distribution. The other approach is nonparametric, using only the data that you provide.\n", + "\n", + "Your choice is made with the `errorbar` parameter, which exists for each function that does estimation as part of plotting. This parameter accepts the name of the method to use and, optionally, a parameter that controls the size of the interval. The choices can be defined in a 2D taxonomy that depends on what is shown and how it is constructed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "f, axs = plt.subplots(2, 2, figsize=(6, 4), sharex=True, sharey=True)\n", + "\n", + "plt.setp(axs, xlim=(-1, 1), ylim=(-1, 1), xticks=[], yticks=[])\n", + "for ax, color in zip(axs.flat, [\"C0\", \"C1\", \"C3\", \"C2\"]):\n", + " ax.set_facecolor(mpl.colors.to_rgba(color, .25))\n", + "\n", + "kws = dict(x=0, y=.2, ha=\"center\", va=\"center\", size=20)\n", + "axs[0, 0].text(s=\"Standard deviation\", **kws)\n", + "axs[0, 1].text(s=\"Standard error\", **kws)\n", + "axs[1, 0].text(s=\"Percentile interval\", **kws)\n", + "axs[1, 1].text(s=\"Confidence interval\", **kws)\n", + "\n", + "kws = dict(x=0, y=-.2, ha=\"center\", va=\"center\", size=20, font=\"Courier New\")\n", + "axs[0, 0].text(s='errorbar=(\"sd\", scale)', **kws)\n", + "axs[0, 1].text(s='errorbar=(\"se\", scale)', **kws)\n", + "axs[1, 0].text(s='errorbar=(\"pi\", width)', **kws)\n", + "axs[1, 1].text(s='errorbar=(\"ci\", width)', **kws)\n", + "\n", + "kws = dict(size=16)\n", + "axs[1, 0].set_xlabel(\"Spread\", **kws)\n", + "axs[1, 1].set_xlabel(\"Uncertainty\", **kws)\n", + "axs[0, 0].set_ylabel(\"Parametric\", **kws)\n", + "axs[1, 0].set_ylabel(\"Nonparametric\", **kws)\n", + "\n", + "f.tight_layout()\n", + "f.subplots_adjust(hspace=.05, wspace=.05 * (4 / 6))" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "You will note that the size parameter is defined differently for the parametric and nonparametric approaches. 
For parametric error bars, it is a scalar factor that is multiplied by the statistic defining the error (standard error or standard deviation). For nonparametric error bars, it is a percentile width. This is explained further for each specific approach below." ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Measures of data spread\n", + "-----------------------\n", + "\n", + "Error bars that represent data spread present a compact display of the distribution, using three numbers where :func:`boxplot` would use five or more and :func:`violinplot` would use a complicated algorithm.\n", + "\n", + "Standard deviation error bars\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "Standard deviation error bars are the simplest to explain, because the standard deviation is a familiar statistic. It is the square root of the average squared distance from each data point to the sample mean. By default, `errorbar=\"sd\"` will draw error bars at +/- 1 sd around the estimate, but the range can be increased by passing a scaling size parameter. Note that, assuming normally-distributed data, ~68% of the data will lie within one standard deviation, ~95% will lie within two, and ~99.7% will lie within three:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Percentile interval error bars\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "Percentile intervals also represent the range where some amount of the data fall, but they do so by \n", + "computing those percentiles directly from your sample. By default, `errorbar=\"pi\"` will show a 95% interval, ranging from the 2.5 to the 97.5 percentiles. You can choose a different range by passing a size parameter, e.g., to show the interquartile range:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The standard deviation error bars will always be symmetrical around the estimate. This can be a problem when the data are skewed, especially if there are natural bounds (e.g., if the data represent a quantity that can only be positive). In some cases, standard deviation error bars may extend to \"impossible\" values. The nonparametric approach does not have this problem, because it can account for asymmetrical spread and will never extend beyond the range of the data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Measures of estimate uncertainty\n", + "--------------------------------\n", + "\n", + "If your data are a random sample from a larger population, then the mean (or other estimate) will be an imperfect measure of the true population average. Error bars that show estimate uncertainty try to represent the range of likely values for the true parameter.\n", + "\n", + "Standard error bars\n", + "~~~~~~~~~~~~~~~~~~~\n", + "\n", + "The standard error statistic is related to the standard deviation: in fact it is just the standard deviation divided by the square root of the sample size. 
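For example, computed directly on a synthetic sample (a sketch that reuses the numpy import from the setup cell above):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x = np.random.normal(0, 2, size=100)  # synthetic data for illustration\n", + "sd = x.std()\n", + "se = sd / np.sqrt(len(x))\n", + "sd, se" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "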
The default, with `errorbar=\"se\"`, draws an interval +/- 1 standard error from the mean, but you can draw a different interval by scaling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Confidence interval error bars\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "The nonparametric approach to representing uncertainty uses *bootstrapping*: a procedure where the dataset is randomly resampled with replacement a number of times, and the estimate is recalculated from each resample. This procedure creates a distribution of statistics approximating the distribution of values that you could have gotten for your estimate if you had a different sample.\n", + "\n", + "The confidence interval is constructed by taking a percentile interval of the *bootstrap distribution*. By default `errorbar=\"ci\"` draws a 95% confidence interval, but you can choose a smaller or larger one by setting a different width:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The seaborn terminology is somewhat specific, because a confidence interval in statistics can be parametric or nonparametric. To draw a parametric confidence interval, you scale the standard error, using a formula similar to the one mentioned above. For example, an approximate 95% confidence interval can be constructed by taking the mean +/- two standard errors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The nonparametric bootstrap has advantages similar to those of the percentile interval: it will naturally adapt to skewed and bounded data in a way that a standard error interval cannot. It is also more general. While the standard error formula is specific to the mean, error bars can be computed using the bootstrap for any estimator:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "On the other hand, the bootstrap procedure is much more computationally-intensive. For large datasets, bootstrap intervals can be expensive to compute. But because uncertainty decreases with sample size, it may be more informative in that case to use an error bar that represents data spread." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Custom error bars\n", + "-----------------\n", + "\n", + "If these recipes are not sufficient, it is also possible to pass a generic function to the `errorbar` parameter. This function should take a vector and produce a pair of values representing the minimum and maximum points of the interval:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Note that seaborn functions cannot currently draw error bars from values that have been calculated externally, although matplotlib functions can be used to add such error bars to seaborn plots."
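 + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "As a stopgap, you can aggregate the data yourself and draw the bars with :meth:`matplotlib.axes.Axes.errorbar` on top of a seaborn plot. A rough sketch (the aggregation choices here are hypothetical, not a seaborn recipe):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "agg = tips.groupby(\"day\")[\"total_bill\"].agg([\"mean\", \"sem\"])\n", + "ax = sns.barplot(x=agg.index, y=agg[\"mean\"], color=\"C0\", errorbar=None)\n", + "# the bars sit at positions 0..n-1, so align the error bars with np.arange\n", + "ax.errorbar(np.arange(len(agg)), agg[\"mean\"], yerr=agg[\"sem\"], fmt=\"none\", color=\".2\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Because seaborn is not aware of such manually-added error bars, you will need to redraw them if the underlying plot changes."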
+ ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Error bars on regression models\n", + "-------------------------------\n", + "\n", + "The preceding discussion has focused on error bars shown around parameter estimates for aggregate data. Error bars also arise in seaborn when estimating regression models to visualize relationships. Here, the error bars will be represented by a \"band\" around the regression line:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO demonstrate with figure" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Most of the same options apply. A regression line represents a conditional mean: the estimated average value of the *y* variable given a specific *x* value. So a standard error or confidence interval represents the uncertainty around the estimate of that conditional mean, whereas a standard deviation interval represents a prediction about the range of *y* you would see given that *x*. Because the regression model extrapolates beyond the data at hand, it is not possible to draw a percentile interval around it." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Are error bars enough?\n", + "----------------------\n", + "\n", + "You should always ask yourself whether it's best to use a plot that displays only a summary statistic and error bars. In many cases, it isn't.\n", + "\n", + "If you are interested in questions about summaries (such as whether the mean value differs between groups or increases over time), aggregation reduces the complexity of the plot and makes those inferences easier. But in doing so, it obscures valuable information about the underlying data points, such as the shape of the distributions and the presence of outliers.\n", + "\n", + "When analyzing your own data, don't be satisfied with summary statistics. Always look at the underlying distributions too. Sometimes, it can be helpful to combine both perspectives into the same figure. Many seaborn functions can help with this task, especially those discussed in the :ref:`categorical tutorial `." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tutorial/function_overview.ipynb b/doc/tutorial/function_overview.ipynb new file mode 100644 index 0000000000..ba5184830a --- /dev/null +++ b/doc/tutorial/function_overview.ipynb @@ -0,0 +1,525 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + ".. _function_tutorial:\n", + "\n", + ".. currentmodule:: seaborn" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Overview of seaborn plotting functions\n", + "======================================\n", + "\n", + ".. raw:: html\n", + "\n", + "
\n", + "\n", + "Most of your interactions with seaborn will happen through a set of plotting functions. Later chapters in the tutorial will explore the specific features offered by each function. This chapter will introduce, at a high-level, the different kinds of functions that you will encounter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from IPython.display import HTML\n", + "sns.set_theme()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Similar functions for similar tasks\n", + "-----------------------------------\n", + "\n", + "The seaborn namespace is flat; all of the functionality is accessible at the top level. But the code itself is hierarchically structured, with modules of functions that achieve similar visualization goals through different means. Most of the docs are structured around these modules: you'll encounter names like \"relational\", \"distributional\", and \"categorical\".\n", + "\n", + "For example, the :ref:`distributions module ` defines functions that specialize in representing the distribution of datapoints. This includes familiar methods like the histogram:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Along with similar, but perhaps less familiar, options such as kernel density estimation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.kdeplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Functions within a module share a lot of underlying code and offer similar features that may not be present in other components of the library (such as ``multiple=\"stack\"`` in the examples above). They are designed to facilitate switching between different visual representations as you explore a dataset, because different representations often have complementary strengths and weaknesses." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Figure-level vs. axes-level functions\n", + "-------------------------------------\n", + "\n", + "In addition to the different modules, there is a cross-cutting classification of seaborn functions as \"axes-level\" or \"figure-level\". The examples above are axes-level functions. They plot data onto a single :class:`matplotlib.pyplot.Axes` object, which is the return value of the function.\n", + "\n", + "In contrast, figure-level functions interface with matplotlib through a seaborn object, usually a :class:`FacetGrid`, that manages the figure. Each module has a single figure-level function, which offers a unitary interface to its various axes-level functions. 
The organization looks a bit like this:" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "from matplotlib.patches import FancyBboxPatch\n", + "\n", + "f, ax = plt.subplots(figsize=(7, 5))\n", + "f.subplots_adjust(0, 0, 1, 1)\n", + "ax.set_axis_off()\n", + "ax.set(xlim=(0, 1), ylim=(0, 1))\n", + "\n", + "\n", + "modules = \"relational\", \"distributions\", \"categorical\"\n", + "\n", + "pal = sns.color_palette(\"deep\")\n", + "colors = dict(relational=pal[0], distributions=pal[1], categorical=pal[2])\n", + "\n", + "pal = sns.color_palette(\"dark\")\n", + "text_colors = dict(relational=pal[0], distributions=pal[1], categorical=pal[2])\n", + "\n", + "\n", + "functions = dict(\n", + " relational=[\"scatterplot\", \"lineplot\"],\n", + " distributions=[\"histplot\", \"kdeplot\", \"ecdfplot\", \"rugplot\"],\n", + " categorical=[\"stripplot\", \"swarmplot\", \"boxplot\", \"violinplot\", \"pointplot\", \"barplot\"],\n", + ")\n", + "\n", + "pad = .06\n", + "\n", + "w = .2\n", + "h = .15\n", + "\n", + "xs = np.arange(0, 1, 1 / 3) + pad * 1.05\n", + "y = .7\n", + "\n", + "for x, mod in zip(xs, modules):\n", + " color = colors[mod] + (.2,)\n", + " text_color = text_colors[mod]\n", + " box = FancyBboxPatch((x, y), w, h, f\"round,pad={pad}\", color=\"white\")\n", + " ax.add_artist(box)\n", + " box = FancyBboxPatch((x, y), w, h, f\"round,pad={pad}\", linewidth=1, edgecolor=text_color, facecolor=color)\n", + " ax.add_artist(box)\n", + " ax.text(x + w / 2, y + h / 2, f\"{mod[:3]}plot\\n({mod})\", ha=\"center\", va=\"center\", size=22, color=text_color)\n", + "\n", + " for i, func in enumerate(functions[mod]):\n", + " x_i = x + w / 2\n", + " y_i = y - i * .1 - h / 2 - pad\n", + " box = FancyBboxPatch((x_i - w / 2, y_i - pad / 3), w, h / 4, f\"round,pad={pad / 3}\",\n", + " color=\"white\")\n", + " ax.add_artist(box)\n", + " box = FancyBboxPatch((x_i - w / 2, y_i - pad / 3), w, h / 4, f\"round,pad={pad / 3}\",\n", + " linewidth=1, edgecolor=text_color, facecolor=color)\n", + " ax.add_artist(box)\n", + " ax.text(x_i, y_i, func, ha=\"center\", va=\"center\", size=18, color=text_color)\n", + "\n", + " ax.plot([x_i, x_i], [y, y_i], zorder=-100, color=text_color, lw=1)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "For example, :func:`displot` is the figure-level function for the distributions module. Its default behavior is to draw a histogram, using the same code as :func:`histplot` behind the scenes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "To draw a kernel density plot instead, using the same code as :func:`kdeplot`, select it using the ``kind`` parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\", kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "You'll notice that the figure-level plots look mostly like their axes-level counterparts, but there are a few differences. Notably, the legend is placed outside the plot. 
They also have a slightly different shape (more on that shortly).\n", + "\n", + "The most useful feature offered by the figure-level functions is that they can easily create figures with multiple subplots. For example, instead of stacking the three distributions for each species of penguins in the same axes, we can \"facet\" them by plotting each distribution across the columns of the figure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", col=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The figure-level functions wrap their axes-level counterparts and pass the kind-specific keyword arguments (such as the bin size for a histogram) down to the underlying function. That means they are no less flexible, but there is a downside: the kind-specific parameters don't appear in the function signature or docstrings. Some of their features might be less discoverable, and you may need to look at two different pages of the documentation before understanding how to achieve a specific goal." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Axes-level functions make self-contained plots\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "The axes-level functions are written to act like drop-in replacements for matplotlib functions. While they add axis labels and legends automatically, they don't modify anything beyond the axes that they are drawn into. That means they can be composed into arbitrarily-complex matplotlib figures with predictable results.\n", + "\n", + "The axes-level functions call :func:`matplotlib.pyplot.gca` internally, which hooks into the matplotlib state-machine interface so that they draw their plots on the \"currently-active\" axes. But they additionally accept an ``ax=`` argument, which integrates with the object-oriented interface and lets you specify exactly where each plot should go:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f, axs = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw=dict(width_ratios=[4, 3]))\n", + "sns.scatterplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\", ax=axs[0])\n", + "sns.histplot(data=penguins, x=\"species\", hue=\"species\", shrink=.8, alpha=.8, legend=False, ax=axs[1])\n", + "f.tight_layout()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Figure-level functions own their figure\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "In contrast, figure-level functions cannot (easily) be composed with other plots. By design, they \"own\" their own figure, including its initialization, so there's no notion of using a figure-level function to draw a plot onto an existing axes. 
This constraint allows the figure-level functions to implement features such as putting the legend outside of the plot.\n", + "\n", + "Nevertheless, it is possible to go beyond what the figure-level functions offer by accessing the matplotlib axes on the object that they return and adding other elements to the plot that way:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tips = sns.load_dataset(\"tips\")\n", + "g = sns.relplot(data=tips, x=\"total_bill\", y=\"tip\")\n", + "g.ax.axline(xy1=(10, 2), slope=.2, color=\"b\", dashes=(5, 2))" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Customizing plots from a figure-level function\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "The figure-level functions return a :class:`FacetGrid` instance, which has a few methods for customizing attributes of the plot in a way that is \"smart\" about the subplot organization. For example, you can change the labels on the external axes using a single line of code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.relplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", col=\"sex\")\n", + "g.set_axis_labels(\"Flipper length (mm)\", \"Bill length (mm)\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "While convenient, this does add a bit of extra complexity, as you need to remember that this method is not part of the matplotlib API and exists only when using a figure-level function." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Specifying figure sizes\n", + "^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "To increase or decrease the size of a matplotlib plot, you set the width and height of the entire figure, either in the `global rcParams `_, while setting up the plot (e.g. with the ``figsize`` parameter of :func:`matplotlib.pyplot.subplots`), or by calling a method on the figure object (e.g. :meth:`matplotlib.figure.Figure.set_size_inches`). When using an axes-level function in seaborn, the same rules apply: the size of the plot is determined by the size of the figure it is part of and the axes layout in that figure.\n", + "\n", + "When using a figure-level function, there are several key differences. First, the functions themselves have parameters to control the figure size (although these are actually parameters of the underlying :class:`FacetGrid` that manages the figure). Second, these parameters, ``height`` and ``aspect``, parameterize the size slightly differently than the ``width``, ``height`` parameterization in matplotlib (using the seaborn parameters, ``width = height * aspect``). 
Most importantly, the parameters correspond to the size of each *subplot*, rather than the size of the overall figure.\n", + "\n", + "To illustrate the difference between these approaches, here is the default output of :func:`matplotlib.pyplot.subplots` with one subplot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f, ax = plt.subplots()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "A figure with multiple columns will have the same overall size, but the axes will be squeezed horizontally to fit in the space:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f, ax = plt.subplots(1, 2, sharey=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "In contrast, a plot created by a figure-level function will be square. To demonstrate that, let's set up an empty plot by using :class:`FacetGrid` directly. This happens behind the scenes in functions like :func:`relplot`, :func:`displot`, or :func:`catplot`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(penguins)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When additional columns are added, the figure itself will become wider, so that its subplots have the same size and shape:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(penguins, col=\"sex\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "And you can adjust the size and shape of each subplot without accounting for the total number of rows and columns in the figure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.FacetGrid(penguins, col=\"sex\", height=3.5, aspect=.75)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The upshot is that you can assign faceting variables without stopping to think about how you'll need to adjust the total figure size. A downside is that, when you do want to change the figure size, you'll need to remember that things work a bit differently than they do in matplotlib." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Relative merits of figure-level functions\n", + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "\n", + "Here is a summary of the pros and cons that we have discussed above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "HTML(\"\"\"\n", + "<table>\n", + "  <tr>\n", + "    <th>Advantages</th>\n", + "    <th>Drawbacks</th>\n", + "  </tr>\n", + "  <tr>\n", + "    <td>Easy faceting by data variables</td>\n", + "    <td>Many parameters not in function signature</td>\n", + "  </tr>\n", + "  <tr>\n", + "    <td>Legend outside of plot by default</td>\n", + "    <td>Cannot be part of a larger matplotlib figure</td>\n", + "  </tr>\n", + "  <tr>\n", + "    <td>Easy figure-level customization</td>\n", + "    <td>Different API from matplotlib</td>\n", + "  </tr>\n", + "  <tr>\n", + "    <td>Different figure size parameterization</td>\n", + "    <td>Different figure size parameterization</td>\n", + "  </tr>\n", + "</table>
\n", + "\"\"\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "On balance, the figure-level functions add some additional complexity that can make things more confusing for beginners, but their distinct features give them additional power. The tutorial documentaion mostly uses the figure-level functions, because they produce slightly cleaner plots, and we generally recommend their use for most applications. The one situation where they are not a good choice is when you need to make a complex, standalone figure that composes multiple different plot kinds. At this point, it's recommended to set up the figure using matplotlib directly and to fill in the individual components using axes-level functions." + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Combining multiple views on the data\n", + "------------------------------------\n", + "\n", + "Two important plotting functions in seaborn don't fit cleanly into the classification scheme discussed above. These functions, :func:`jointplot` and :func:`pairplot`, employ multiple kinds of plots from different modules to represent mulitple aspects of a dataset in a single figure. Both plots are figure-level functions and create figures with multiple subplots by default. But they use different objects to manage the figure: :class:`JointGrid` and :class:`PairGrid`, respectively.\n", + "\n", + ":func:`jointplot` plots the relationship or joint distribution of two variables while adding marginal axes that show the univariate distribution of each one separately:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + ":func:`pairplot` is similar — it combines joint and marginal views — but rather than focusing on a single relationship, it visualizes every pairwise combination of variables simultaneously:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(data=penguins, hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Behind the scenes, these functions are using axes-level functions that you have already met (:func:`scatterplot` and :func:`kdeplot`), and they also have a ``kind`` parameter that lets you quickly swap in a different representation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\", kind=\"hist\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tutorial/regression.ipynb b/doc/tutorial/regression.ipynb index 375948bbc7..9b6bf6ef4c 100644 --- a/doc/tutorial/regression.ipynb +++ b/doc/tutorial/regression.ipynb @@ -13,8 +13,8 @@ "cell_type": "raw", "metadata": {}, "source": [ - "Visualizing 
linear relationships\n", - "================================\n", + "Visualizing regression models\n", + "=============================\n", "\n", ".. raw:: html\n", "\n", @@ -47,7 +47,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.set(color_codes=True)" + "sns.set_theme(color_codes=True)" ] }, { @@ -201,9 +201,7 @@ }, { "cell_type": "raw", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "In the presence of these kind of higher-order relationships, :func:`lmplot` and :func:`regplot` can fit a polynomial regression model to explore simple kinds of nonlinear trends in the dataset:" ] @@ -517,9 +515,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -531,9 +529,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/tutorial/relational.ipynb b/doc/tutorial/relational.ipynb index 76caa9c17a..fd01a16c77 100644 --- a/doc/tutorial/relational.ipynb +++ b/doc/tutorial/relational.ipynb @@ -22,7 +22,7 @@ "\n", "Statistical analysis is a process of understanding how variables in a dataset relate to each other and how those relationships depend on other variables. Visualization can be a core component of this process because, when data are visualized properly, the human visual system can see trends and patterns that indicate a relationship.\n", "\n", - "We will discuss three seaborn functions in this tutorial. The one we will use most is :func:`relplot`. This is a :ref:`figure-level function ` for visualizing statistical relationships using two common approaches: scatter plots and line plots. :func:`relplot` combines a :class:`FacetGrid` with one of two axes-level functions:\n", + "We will discuss three seaborn functions in this tutorial. The one we will use most is :func:`relplot`. This is a :doc:`figure-level function ` for visualizing statistical relationships using two common approaches: scatter plots and line plots. 
:func:`relplot` combines a :class:`FacetGrid` with one of two axes-level functions:\n", "\n", "- :func:`scatterplot` (with ``kind=\"scatter\"``; the default)\n", "- :func:`lineplot` (with ``kind=\"line\"``)\n", @@ -40,7 +40,7 @@ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", - "sns.set(style=\"darkgrid\")" + "sns.set_theme(style=\"darkgrid\")" ] }, { @@ -277,7 +277,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.relplot(x=\"timepoint\", y=\"signal\", ci=None, kind=\"line\", data=fmri);" + "sns.relplot(x=\"timepoint\", y=\"signal\", errorbar=None, kind=\"line\", data=fmri);" ] }, { @@ -293,7 +293,7 @@ "metadata": {}, "outputs": [], "source": [ - "sns.relplot(x=\"timepoint\", y=\"signal\", kind=\"line\", ci=\"sd\", data=fmri);" + "sns.relplot(x=\"timepoint\", y=\"signal\", kind=\"line\", errorbar=\"sd\", data=fmri);" ] }, { @@ -404,9 +404,7 @@ }, { "cell_type": "raw", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ - "The default colormap and handling of the legend in :func:`lineplot` also depends on whether the hue semantic is categorical or numeric:" + "The default colormap and handling of the legend in :func:`lineplot` also depend on whether the hue semantic is categorical or numeric:" ] @@ -461,7 +459,8 @@ "sns.relplot(x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\",\n", " hue_norm=LogNorm(),\n", - " kind=\"line\", data=dots);" + " kind=\"line\",\n", + " data=dots.query(\"coherence > 0\"));" ] }, { @@ -616,9 +615,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3.6 (seaborn-dev)", + "display_name": "seaborn-py38-latest", "language": "python", - "name": "seaborn-dev" + "name": "seaborn-py38-latest" }, "language_info": { "codemirror_mode": { @@ -630,9 +629,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 97f3e3fbba..01cbc0bb57 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -2,13 +2,40 @@ .. currentmodule:: seaborn +.. role:: raw-html(raw) + :format: html + +.. role:: raw-latex(raw) + :format: latex + +.. |API| replace:: :raw-html:`API` :raw-latex:`{\small\sc [API]}` +.. |Defaults| replace:: :raw-html:`Defaults` :raw-latex:`{\small\sc [Defaults]}` +.. |Docs| replace:: :raw-html:`Docs` :raw-latex:`{\small\sc [Docs]}` +.. |Feature| replace:: :raw-html:`Feature` :raw-latex:`{\small\sc [Feature]}` +.. |Enhancement| replace:: :raw-html:`Enhancement` :raw-latex:`{\small\sc [Enhancement]}` +.. |Fix| replace:: :raw-html:`Fix` :raw-latex:`{\small\sc [Fix]}` + What's new in each version ========================== +This page contains information about what has changed in each new version of ``seaborn``. + .. raw:: html
+.. include:: releases/v0.12.0.txt + +.. include:: releases/v0.11.1.txt + +.. include:: releases/v0.11.0.txt + +.. include:: releases/v0.10.1.txt + +.. include:: releases/v0.10.0.txt + +.. include:: releases/v0.9.1.txt + .. include:: releases/v0.9.0.txt .. include:: releases/v0.8.1.txt diff --git a/examples/anscombes_quartet.py b/examples/anscombes_quartet.py index 1b8efb08bf..6ba7170c3f 100644 --- a/examples/anscombes_quartet.py +++ b/examples/anscombes_quartet.py @@ -5,7 +5,7 @@ _thumb: .4, .4 """ import seaborn as sns -sns.set(style="ticks") +sns.set_theme(style="ticks") # Load the example dataset for Anscombe's quartet df = sns.load_dataset("anscombe") diff --git a/examples/different_scatter_variables.py b/examples/different_scatter_variables.py index 10e8d564f9..710d005808 100644 --- a/examples/different_scatter_variables.py +++ b/examples/different_scatter_variables.py @@ -1,13 +1,13 @@ """ -Scatterplot with categorical and numerical semantics -==================================================== +Scatterplot with multiple semantics +=================================== _thumb: .45, .5 """ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Load the example diamonds dataset diamonds = sns.load_dataset("diamonds") diff --git a/examples/distplot_options.py b/examples/distplot_options.py deleted file mode 100644 index 147c8e11bc..0000000000 --- a/examples/distplot_options.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Distribution plot options -========================= - -""" -import numpy as np -import seaborn as sns -import matplotlib.pyplot as plt - -sns.set(style="white", palette="muted", color_codes=True) -rs = np.random.RandomState(10) - -# Set up the matplotlib figure -f, axes = plt.subplots(2, 2, figsize=(7, 7), sharex=True) -sns.despine(left=True) - -# Generate a random univariate dataset -d = rs.normal(size=100) - -# Plot a simple histogram with binsize determined automatically -sns.distplot(d, kde=False, color="b", ax=axes[0, 0]) - -# Plot a kernel density estimate and rug plot -sns.distplot(d, hist=False, rug=True, color="r", ax=axes[0, 1]) - -# Plot a filled kernel density estimate -sns.distplot(d, hist=False, color="g", kde_kws={"shade": True}, ax=axes[1, 0]) - -# Plot a histogram and kernel density estimate -sns.distplot(d, color="m", ax=axes[1, 1]) - -plt.setp(axes, yticks=[]) -plt.tight_layout() diff --git a/examples/errorband_lineplots.py b/examples/errorband_lineplots.py index 3656706abe..13a8ab3f85 100644 --- a/examples/errorband_lineplots.py +++ b/examples/errorband_lineplots.py @@ -6,7 +6,7 @@ """ import seaborn as sns -sns.set(style="darkgrid") +sns.set_theme(style="darkgrid") # Load an example dataset with long-form data fmri = sns.load_dataset("fmri") diff --git a/examples/faceted_histogram.py b/examples/faceted_histogram.py index 995dbe47a7..1c84b4ba10 100644 --- a/examples/faceted_histogram.py +++ b/examples/faceted_histogram.py @@ -2,14 +2,13 @@ Facetting histograms by subsets of data ======================================= -_thumb: .42, .57 +_thumb: .33, .57 """ -import numpy as np import seaborn as sns -import matplotlib.pyplot as plt -sns.set(style="darkgrid") -tips = sns.load_dataset("tips") -g = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True) -bins = np.linspace(0, 60, 13) -g.map(plt.hist, "total_bill", color="steelblue", bins=bins) +sns.set_theme(style="darkgrid") +df = sns.load_dataset("penguins") +sns.displot( + df, x="flipper_length_mm", col="species", row="sex", + binwidth=3, 
height=3, facet_kws=dict(margin_titles=True), +) diff --git a/examples/faceted_lineplot.py b/examples/faceted_lineplot.py index 3b87d452b5..4bb4cd61f2 100644 --- a/examples/faceted_lineplot.py +++ b/examples/faceted_lineplot.py @@ -2,22 +2,22 @@ Line plots on multiple facets ============================= -_thumb: .45, .42 +_thumb: .48, .42 """ import seaborn as sns -sns.set(style="ticks") +sns.set_theme(style="ticks") dots = sns.load_dataset("dots") -# Define a palette to ensure that colors will be -# shared across the facets -palette = dict(zip(dots.coherence.unique(), - sns.color_palette("rocket_r", 6))) +# Define the palette as a list to specify exact values +palette = sns.color_palette("rocket_r") # Plot the lines on two facets -sns.relplot(x="time", y="firing_rate", - hue="coherence", size="choice", col="align", - size_order=["T1", "T2"], palette=palette, - height=5, aspect=.75, facet_kws=dict(sharex=False), - kind="line", legend="full", data=dots) +sns.relplot( + data=dots, + x="time", y="firing_rate", + hue="coherence", size="choice", col="align", + kind="line", size_order=["T1", "T2"], palette=palette, + height=5, aspect=.75, facet_kws=dict(sharex=False), +) diff --git a/examples/grouped_barplot.py b/examples/grouped_barplot.py index 908dd28e38..e3142ef8bb 100644 --- a/examples/grouped_barplot.py +++ b/examples/grouped_barplot.py @@ -2,16 +2,19 @@ Grouped barplots ================ -_thumb: .45, .5 +_thumb: .36, .5 """ import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") -# Load the example Titanic dataset -titanic = sns.load_dataset("titanic") +penguins = sns.load_dataset("penguins") -# Draw a nested barplot to show survival for class and sex -g = sns.catplot(x="class", y="survived", hue="sex", data=titanic, - height=6, kind="bar", palette="muted") +# Draw a nested barplot by species and sex +g = sns.catplot( + data=penguins, kind="bar", + x="species", y="body_mass_g", hue="sex", + ci="sd", palette="dark", alpha=.6, height=6 +) g.despine(left=True) -g.set_ylabels("survival probability") +g.set_axis_labels("", "Body mass (g)") +g.legend.set_title("") diff --git a/examples/grouped_boxplot.py b/examples/grouped_boxplot.py index eb4fb604c6..d10a9bbd84 100644 --- a/examples/grouped_boxplot.py +++ b/examples/grouped_boxplot.py @@ -6,7 +6,7 @@ """ import seaborn as sns -sns.set(style="ticks", palette="pastel") +sns.set_theme(style="ticks", palette="pastel") # Load the example tips dataset tips = sns.load_dataset("tips") diff --git a/examples/grouped_violinplots.py b/examples/grouped_violinplots.py index 5c86172d7b..788885863c 100644 --- a/examples/grouped_violinplots.py +++ b/examples/grouped_violinplots.py @@ -2,17 +2,16 @@ Grouped violinplots with split violins ====================================== -_thumb: .43, .47 +_thumb: .44, .47 """ import seaborn as sns -sns.set(style="whitegrid", palette="pastel", color_codes=True) +sns.set_theme(style="whitegrid") # Load the example tips dataset tips = sns.load_dataset("tips") # Draw a nested violinplot and split the violins for easier comparison -sns.violinplot(x="day", y="total_bill", hue="smoker", - split=True, inner="quart", - palette={"Yes": "y", "No": "b"}, - data=tips) +sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker", + split=True, inner="quart", linewidth=1, + palette={"Yes": "b", "No": ".85"}) sns.despine(left=True) diff --git a/examples/heat_scatter.py b/examples/heat_scatter.py new file mode 100644 index 0000000000..228e91c402 --- /dev/null +++ b/examples/heat_scatter.py @@ -0,0 +1,41 @@ 
+""" +Scatterplot heatmap +------------------- + +_thumb: .5, .5 + +""" +import seaborn as sns +sns.set_theme(style="whitegrid") + +# Load the brain networks dataset, select subset, and collapse the multi-index +df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) + +used_networks = [1, 5, 6, 7, 8, 12, 13, 17] +used_columns = (df.columns + .get_level_values("network") + .astype(int) + .isin(used_networks)) +df = df.loc[:, used_columns] + +df.columns = df.columns.map("-".join) + +# Compute a correlation matrix and convert to long-form +corr_mat = df.corr().stack().reset_index(name="correlation") + +# Draw each cell as a scatter point with varying size and color +g = sns.relplot( + data=corr_mat, + x="level_0", y="level_1", hue="correlation", size="correlation", + palette="vlag", hue_norm=(-1, 1), edgecolor=".7", + height=10, sizes=(50, 250), size_norm=(-.2, .8), +) + +# Tweak the figure to finalize +g.set(xlabel="", ylabel="", aspect="equal") +g.despine(left=True, bottom=True) +g.ax.margins(.02) +for label in g.ax.get_xticklabels(): + label.set_rotation(90) +for artist in g.legend.legendHandles: + artist.set_edgecolor(".7") diff --git a/examples/hexbin_marginals.py b/examples/hexbin_marginals.py index 5dff7885cc..e59b65fe69 100644 --- a/examples/hexbin_marginals.py +++ b/examples/hexbin_marginals.py @@ -6,10 +6,10 @@ """ import numpy as np import seaborn as sns -sns.set(style="ticks") +sns.set_theme(style="ticks") rs = np.random.RandomState(11) x = rs.gamma(2, size=1000) y = -.5 * x + rs.normal(size=1000) -sns.jointplot(x, y, kind="hex", color="#4CB391") +sns.jointplot(x=x, y=y, kind="hex", color="#4CB391") diff --git a/examples/histogram_stacked.py b/examples/histogram_stacked.py new file mode 100644 index 0000000000..9efd80406c --- /dev/null +++ b/examples/histogram_stacked.py @@ -0,0 +1,29 @@ +""" +Stacked histogram on a log scale +================================ + +_thumb: .5, .45 + +""" +import seaborn as sns +import matplotlib as mpl +import matplotlib.pyplot as plt + +sns.set_theme(style="ticks") + +diamonds = sns.load_dataset("diamonds") + +f, ax = plt.subplots(figsize=(7, 5)) +sns.despine(f) + +sns.histplot( + diamonds, + x="price", hue="cut", + multiple="stack", + palette="light:m_r", + edgecolor=".3", + linewidth=.5, + log_scale=True, +) +ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) +ax.set_xticks([500, 1000, 2000, 5000, 10000]) diff --git a/examples/horizontal_boxplot.py b/examples/horizontal_boxplot.py index a29b914aa6..48e4991fac 100644 --- a/examples/horizontal_boxplot.py +++ b/examples/horizontal_boxplot.py @@ -7,7 +7,7 @@ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="ticks") +sns.set_theme(style="ticks") # Initialize the figure with a logarithmic x axis f, ax = plt.subplots(figsize=(7, 6)) @@ -18,11 +18,11 @@ # Plot the orbital period with horizontal boxes sns.boxplot(x="distance", y="method", data=planets, - whis="range", palette="vlag") + whis=[0, 100], width=.6, palette="vlag") # Add in points to show each observation -sns.swarmplot(x="distance", y="method", data=planets, - size=2, color=".3", linewidth=0) +sns.stripplot(x="distance", y="method", data=planets, + size=4, color=".3", linewidth=0) # Tweak the visual presentation ax.xaxis.grid(True) diff --git a/examples/jitter_stripplot.py b/examples/jitter_stripplot.py index c2cc0af620..e5abf78ce3 100644 --- a/examples/jitter_stripplot.py +++ b/examples/jitter_stripplot.py @@ -7,7 +7,7 @@ import seaborn as sns import matplotlib.pyplot as plt 
-sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") iris = sns.load_dataset("iris") # "Melt" the dataset to "long-form" or "tidy" representation @@ -21,12 +21,15 @@ sns.stripplot(x="value", y="measurement", hue="species", data=iris, dodge=True, alpha=.25, zorder=1) -# Show the conditional means +# Show the conditional means, aligning each pointplot in the +# center of the strips by adjusting the width allotted to each +# category (.8 by default) by the number of hue levels sns.pointplot(x="value", y="measurement", hue="species", - data=iris, dodge=.532, join=False, palette="dark", + data=iris, dodge=.8 - .8 / 3, + join=False, palette="dark", markers="d", scale=.75, ci=None) -# Improve the legend +# Improve the legend handles, labels = ax.get_legend_handles_labels() ax.legend(handles[3:], labels[3:], title="species", handletextpad=0, columnspacing=1, diff --git a/examples/joint_histogram.py b/examples/joint_histogram.py new file mode 100644 index 0000000000..5ad350a687 --- /dev/null +++ b/examples/joint_histogram.py @@ -0,0 +1,26 @@ +""" +Joint and marginal histograms +============================= + +_thumb: .52, .505 + +""" +import seaborn as sns +sns.set_theme(style="ticks") + +# Load the planets dataset and initialize the figure +planets = sns.load_dataset("planets") +g = sns.JointGrid(data=planets, x="year", y="distance", marginal_ticks=True) + +# Set a log scaling on the y axis +g.ax_joint.set(yscale="log") + +# Create an inset legend for the histogram colorbar +cax = g.fig.add_axes([.15, .55, .02, .2]) + +# Add the joint and marginal histogram plots +g.plot_joint( + sns.histplot, discrete=(True, False), + cmap="light:#03012d", pmax=.8, cbar=True, cbar_ax=cax +) +g.plot_marginals(sns.histplot, element="step", color="#03012d") diff --git a/examples/joint_kde.py b/examples/joint_kde.py index 70edebc156..2358228ba5 100644 --- a/examples/joint_kde.py +++ b/examples/joint_kde.py @@ -4,18 +4,15 @@ _thumb: .6, .4 """ -import numpy as np -import pandas as pd import seaborn as sns -sns.set(style="white") +sns.set_theme(style="ticks") -# Generate a random correlated bivariate dataset -rs = np.random.RandomState(5) -mean = [0, 0] -cov = [(1, .5), (.5, 1)] -x1, x2 = rs.multivariate_normal(mean, cov, 500).T -x1 = pd.Series(x1, name="$X_1$") -x2 = pd.Series(x2, name="$X_2$") +# Load the penguins dataset +penguins = sns.load_dataset("penguins") # Show the joint distribution using kernel density estimation -g = sns.jointplot(x1, x2, kind="kde", height=7, space=0) +g = sns.jointplot( + data=penguins, + x="bill_length_mm", y="bill_depth_mm", hue="species", + kind="kde", +) diff --git a/examples/kde_ridgeplot.py b/examples/kde_ridgeplot.py index 75e3193799..cd30e99547 100644 --- a/examples/kde_ridgeplot.py +++ b/examples/kde_ridgeplot.py @@ -8,7 +8,7 @@ import pandas as pd import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) +sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) # Create the data rs = np.random.RandomState(1979) @@ -23,8 +23,10 @@ g = sns.FacetGrid(df, row="g", hue="g", aspect=15, height=.5, palette=pal) # Draw the densities in a few steps -g.map(sns.kdeplot, "x", clip_on=False, shade=True, alpha=1, lw=1.5, bw=.2) -g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw=.2) +g.map(sns.kdeplot, "x", + bw_adjust=.5, clip_on=False, + fill=True, alpha=1, linewidth=1.5) +g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw_adjust=.5) g.map(plt.axhline, y=0, lw=2, clip_on=False) diff --git 
a/examples/large_distributions.py b/examples/large_distributions.py index 6133a4b323..6dbfe63aae 100644 --- a/examples/large_distributions.py +++ b/examples/large_distributions.py @@ -4,7 +4,7 @@ """ import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") diamonds = sns.load_dataset("diamonds") clarity_ranking = ["I1", "SI2", "SI1", "VS2", "VS1", "VVS2", "VVS1", "IF"] diff --git a/examples/layered_bivariate_plot.py b/examples/layered_bivariate_plot.py new file mode 100644 index 0000000000..40c63e35f4 --- /dev/null +++ b/examples/layered_bivariate_plot.py @@ -0,0 +1,23 @@ +""" +Bivariate plot with multiple elements +===================================== + + +""" +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt +sns.set_theme(style="dark") + +# Simulate data from a bivariate Gaussian +n = 10000 +mean = [0, 0] +cov = [(2, .4), (.4, .2)] +rng = np.random.RandomState(0) +x, y = rng.multivariate_normal(mean, cov, n).T + +# Draw a combo histogram and scatterplot with density contours +f, ax = plt.subplots(figsize=(6, 6)) +sns.scatterplot(x=x, y=y, s=5, color=".15") +sns.histplot(x=x, y=y, bins=50, pthresh=.1, cmap="mako") +sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1) diff --git a/examples/logistic_regression.py b/examples/logistic_regression.py index de6aafae38..8e636956e0 100644 --- a/examples/logistic_regression.py +++ b/examples/logistic_regression.py @@ -5,7 +5,7 @@ _thumb: .58, .5 """ import seaborn as sns -sns.set(style="darkgrid") +sns.set_theme(style="darkgrid") # Load the example Titanic dataset df = sns.load_dataset("titanic") @@ -15,5 +15,5 @@ # Show the survival probability as a function of age and sex g = sns.lmplot(x="age", y="survived", col="sex", hue="sex", data=df, - palette=pal, y_jitter=.02, logistic=True) + palette=pal, y_jitter=.02, logistic=True, truncate=False) g.set(xlim=(0, 80), ylim=(-.05, 1.05)) diff --git a/examples/many_facets.py b/examples/many_facets.py index 1375a3849f..4c03ac8904 100644 --- a/examples/many_facets.py +++ b/examples/many_facets.py @@ -10,7 +10,7 @@ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="ticks") +sns.set_theme(style="ticks") # Create a dataset with many short random walks rs = np.random.RandomState(4) diff --git a/examples/many_pairwise_correlations.py b/examples/many_pairwise_correlations.py index 4d93f8ace7..2ae2315412 100644 --- a/examples/many_pairwise_correlations.py +++ b/examples/many_pairwise_correlations.py @@ -10,7 +10,7 @@ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="white") +sns.set_theme(style="white") # Generate a large random dataset rs = np.random.RandomState(33) @@ -21,13 +21,13 @@ corr = d.corr() # Generate a mask for the upper triangle -mask = np.triu(np.ones_like(corr, dtype=np.bool)) +mask = np.triu(np.ones_like(corr, dtype=bool)) # Set up the matplotlib figure f, ax = plt.subplots(figsize=(11, 9)) # Generate a custom diverging colormap -cmap = sns.diverging_palette(220, 10, as_cmap=True) +cmap = sns.diverging_palette(230, 20, as_cmap=True) # Draw the heatmap with the mask and correct aspect ratio sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0, diff --git a/examples/marginal_ticks.py b/examples/marginal_ticks.py index e8389ac2a2..e6a8d61518 100644 --- a/examples/marginal_ticks.py +++ b/examples/marginal_ticks.py @@ -2,20 +2,14 @@ Scatterplot with marginal ticks =============================== -_thumb: .68, .32 +_thumb: .66, .34 """ -import numpy as np import seaborn as sns -import matplotlib.pyplot 
as plt -sns.set(style="white", color_codes=True) - -# Generate a random bivariate dataset -rs = np.random.RandomState(9) -mean = [0, 0] -cov = [(1, 0), (0, 2)] -x, y = rs.multivariate_normal(mean, cov, 100).T +sns.set_theme(style="white", color_codes=True) +mpg = sns.load_dataset("mpg") # Use JointGrid directly to draw a custom plot -grid = sns.JointGrid(x, y, space=0, height=6, ratio=50) -grid.plot_joint(plt.scatter, color="g") -grid.plot_marginals(sns.rugplot, height=1, color="g") +g = sns.JointGrid(data=mpg, x="mpg", y="acceleration", space=0, ratio=17) +g.plot_joint(sns.scatterplot, size=mpg["horsepower"], sizes=(30, 120), + color="g", alpha=.6, legend=False) +g.plot_marginals(sns.rugplot, height=1, color="g", alpha=.6) diff --git a/examples/multiple_bivariate_kde.py b/examples/multiple_bivariate_kde.py new file mode 100644 index 0000000000..217c8afa5d --- /dev/null +++ b/examples/multiple_bivariate_kde.py @@ -0,0 +1,24 @@ +""" +Multiple bivariate KDE plots +============================ + +_thumb: .6, .45 +""" +import seaborn as sns +import matplotlib.pyplot as plt + +sns.set_theme(style="darkgrid") +iris = sns.load_dataset("iris") + +# Set up the figure +f, ax = plt.subplots(figsize=(8, 8)) +ax.set_aspect("equal") + +# Draw a contour plot to represent each bivariate density +sns.kdeplot( + data=iris.query("species != 'versicolor'"), + x="sepal_width", + y="sepal_length", + hue="species", + thresh=.1, +) diff --git a/examples/multiple_conditional_kde.py b/examples/multiple_conditional_kde.py new file mode 100644 index 0000000000..0c6dfb1079 --- /dev/null +++ b/examples/multiple_conditional_kde.py @@ -0,0 +1,20 @@ +""" +Conditional kernel density estimate +=================================== + +_thumb: .4, .5 +""" +import seaborn as sns +sns.set_theme(style="whitegrid") + +# Load the diamonds dataset +diamonds = sns.load_dataset("diamonds") + +# Plot the distribution of clarity ratings, conditional on carat +sns.displot( + data=diamonds, + x="carat", hue="cut", + kind="kde", height=6, + multiple="fill", clip=(0, None), + palette="ch:rot=-.25,hue=1,light=.75", +) diff --git a/examples/multiple_ecdf.py b/examples/multiple_ecdf.py new file mode 100644 index 0000000000..4ae904590b --- /dev/null +++ b/examples/multiple_ecdf.py @@ -0,0 +1,17 @@ +""" +Facetted ECDF plots +=================== + +_thumb: .30, .49 +""" +import seaborn as sns +sns.set_theme(style="ticks") +mpg = sns.load_dataset("mpg") + +colors = (250, 70, 50), (350, 70, 50) +cmap = sns.blend_palette(colors, input="husl", as_cmap=True) +sns.displot( + mpg, + x="displacement", col="origin", hue="model_year", + kind="ecdf", aspect=.75, linewidth=2, palette=cmap, +) diff --git a/examples/multiple_joint_kde.py b/examples/multiple_joint_kde.py deleted file mode 100644 index 521144d5ec..0000000000 --- a/examples/multiple_joint_kde.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Multiple bivariate KDE plots -============================ - -_thumb: .6, .45 -""" -import seaborn as sns -import matplotlib.pyplot as plt - -sns.set(style="darkgrid") -iris = sns.load_dataset("iris") - -# Subset the iris dataset by species -setosa = iris.query("species == 'setosa'") -virginica = iris.query("species == 'virginica'") - -# Set up the figure -f, ax = plt.subplots(figsize=(8, 8)) -ax.set_aspect("equal") - -# Draw the two density plots -ax = sns.kdeplot(setosa.sepal_width, setosa.sepal_length, - cmap="Reds", shade=True, shade_lowest=False) -ax = sns.kdeplot(virginica.sepal_width, virginica.sepal_length, - cmap="Blues", shade=True, shade_lowest=False) - -# 
Add labels to the plot -red = sns.color_palette("Reds")[-2] -blue = sns.color_palette("Blues")[-2] -ax.text(2.5, 8.2, "virginica", size=16, color=blue) -ax.text(3.8, 4.5, "setosa", size=16, color=red) diff --git a/examples/multiple_regression.py b/examples/multiple_regression.py index 93c0f291ad..cb72204fca 100644 --- a/examples/multiple_regression.py +++ b/examples/multiple_regression.py @@ -5,14 +5,17 @@ _thumb: .45, .45 """ import seaborn as sns -sns.set() +sns.set_theme() -# Load the iris dataset -iris = sns.load_dataset("iris") +# Load the penguins dataset +penguins = sns.load_dataset("penguins") -# Plot sepal width as a function of sepal_length across days -g = sns.lmplot(x="sepal_length", y="sepal_width", hue="species", - truncate=True, height=5, data=iris) +# Plot bill depth as a function of bill length across species +g = sns.lmplot( + data=penguins, + x="bill_length_mm", y="bill_depth_mm", hue="species", + height=5 +) # Use more informative axis labels than are provided by default -g.set_axis_labels("Sepal length (mm)", "Sepal width (mm)") +g.set_axis_labels("Snoot length (mm)", "Snoot depth (mm)") diff --git a/examples/pair_grid_with_kde.py b/examples/pair_grid_with_kde.py index 7b3d27bfce..3ad74429cc 100644 --- a/examples/pair_grid_with_kde.py +++ b/examples/pair_grid_with_kde.py @@ -5,11 +5,11 @@ _thumb: .5, .5 """ import seaborn as sns -sns.set(style="white") +sns.set_theme(style="white") -df = sns.load_dataset("iris") +df = sns.load_dataset("penguins") g = sns.PairGrid(df, diag_sharey=False) +g.map_upper(sns.scatterplot, s=15) g.map_lower(sns.kdeplot) -g.map_upper(sns.scatterplot) -g.map_diag(sns.kdeplot, lw=3) +g.map_diag(sns.kdeplot, lw=2) diff --git a/examples/paired_pointplots.py b/examples/paired_pointplots.py index e8f111610d..0ab79655b8 100644 --- a/examples/paired_pointplots.py +++ b/examples/paired_pointplots.py @@ -4,7 +4,7 @@ """ import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Load the example Titanic dataset titanic = sns.load_dataset("titanic") diff --git a/examples/pairgrid_dotplot.py b/examples/pairgrid_dotplot.py index c4c8b6ac1b..8509812d4a 100644 --- a/examples/pairgrid_dotplot.py +++ b/examples/pairgrid_dotplot.py @@ -5,7 +5,7 @@ _thumb: .3, .3 """ import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Load the dataset crashes = sns.load_dataset("car_crashes") @@ -16,8 +16,8 @@ height=10, aspect=.25) # Draw a dot plot using the stripplot function -g.map(sns.stripplot, size=10, orient="h", - palette="ch:s=1,r=-.1,h=1_r", linewidth=1, edgecolor="w") +g.map(sns.stripplot, size=10, orient="h", jitter=False, + palette="flare_r", linewidth=1, edgecolor="w") # Use the same x axis limits on all columns and add better labels g.set(xlim=(0, 25), xlabel="Crashes", ylabel="") diff --git a/examples/color_palettes.py b/examples/palette_choices.py similarity index 95% rename from examples/color_palettes.py rename to examples/palette_choices.py index 364f458b8b..141a024854 100644 --- a/examples/color_palettes.py +++ b/examples/palette_choices.py @@ -6,7 +6,7 @@ import numpy as np import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="white", context="talk") +sns.set_theme(style="white", context="talk") rs = np.random.RandomState(8) # Set up the matplotlib figure diff --git a/examples/cubehelix_palette.py b/examples/palette_generation.py similarity index 64% rename from examples/cubehelix_palette.py rename to examples/palette_generation.py index 753f20d47f..82ef6842f5 100644 --- a/examples/cubehelix_palette.py +++ b/examples/palette_generation.py @@ 
-8,7 +8,7 @@ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="dark") +sns.set_theme(style="white") rs = np.random.RandomState(50) # Set up the matplotlib figure @@ -21,8 +21,15 @@ cmap = sns.cubehelix_palette(start=s, light=1, as_cmap=True) # Generate and plot a random bivariate dataset - x, y = rs.randn(2, 50) - sns.kdeplot(x, y, cmap=cmap, shade=True, cut=5, ax=ax) - ax.set(xlim=(-3, 3), ylim=(-3, 3)) + x, y = rs.normal(size=(2, 50)) + sns.kdeplot( + x=x, y=y, + cmap=cmap, fill=True, + clip=(-5, 5), cut=10, + thresh=0, levels=15, + ax=ax, + ) + ax.set_axis_off() -f.tight_layout() +ax.set(xlim=(-3.5, 3.5), ylim=(-3.5, 3.5)) +f.subplots_adjust(0, 0, 1, 1, .08, .08) diff --git a/examples/horizontal_barplot.py b/examples/part_whole_bars.py similarity index 96% rename from examples/horizontal_barplot.py rename to examples/part_whole_bars.py index 0e022f8b54..d64ffbf74d 100644 --- a/examples/horizontal_barplot.py +++ b/examples/part_whole_bars.py @@ -5,7 +5,7 @@ """ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Initialize the matplotlib figure f, ax = plt.subplots(figsize=(6, 15)) diff --git a/examples/pointplot_anova.py b/examples/pointplot_anova.py index 5289aee463..fda1b387d8 100644 --- a/examples/pointplot_anova.py +++ b/examples/pointplot_anova.py @@ -5,7 +5,7 @@ _thumb: .42, .5 """ import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Load the example exercise dataset df = sns.load_dataset("exercise") diff --git a/examples/facet_projections.py b/examples/radial_facets.py similarity index 97% rename from examples/facet_projections.py rename to examples/radial_facets.py index 37771e4057..3d95717239 100644 --- a/examples/facet_projections.py +++ b/examples/radial_facets.py @@ -9,7 +9,7 @@ import pandas as pd import seaborn as sns -sns.set() +sns.set_theme() -# Generate an example radial datast +# Generate an example radial dataset r = np.linspace(0, 10, num=100) diff --git a/examples/regression_marginals.py b/examples/regression_marginals.py index 2f0f5e1c9c..a64bb91c25 100644 --- a/examples/regression_marginals.py +++ b/examples/regression_marginals.py @@ -5,8 +5,10 @@ _thumb: .65, .65 """ import seaborn as sns -sns.set(style="darkgrid") +sns.set_theme(style="darkgrid") tips = sns.load_dataset("tips") -g = sns.jointplot("total_bill", "tip", data=tips, kind="reg", - xlim=(0, 60), ylim=(0, 12), color="m", height=7) +g = sns.jointplot(x="total_bill", y="tip", data=tips, + kind="reg", truncate=False, + xlim=(0, 60), ylim=(0, 12), + color="m", height=7) diff --git a/examples/residplot.py b/examples/residplot.py index 1a557947ac..cfd9518df1 100644 --- a/examples/residplot.py +++ b/examples/residplot.py @@ -5,7 +5,7 @@ """ import numpy as np import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Make an example dataset with y ~ x rs = np.random.RandomState(7) @@ -13,4 +13,4 @@ y = 2 + 1.5 * x + rs.normal(0, 2, 75) # Plot the residuals after fitting a linear model -sns.residplot(x, y, lowess=True, color="g") +sns.residplot(x=x, y=y, lowess=True, color="g") diff --git a/examples/scatter_bubbles.py b/examples/scatter_bubbles.py index 8ee59d9d63..4b9c6577c1 100644 --- a/examples/scatter_bubbles.py +++ b/examples/scatter_bubbles.py @@ -6,7 +6,7 @@ """ import seaborn as sns -sns.set(style="white") +sns.set_theme(style="white") # Load the example mpg dataset mpg = sns.load_dataset("mpg") diff --git a/examples/scatterplot_categorical.py b/examples/scatterplot_categorical.py index 
79be715e18..cc22a611e1 100644 --- a/examples/scatterplot_categorical.py +++ b/examples/scatterplot_categorical.py @@ -2,17 +2,15 @@ Scatterplot with categorical variables ====================================== +_thumb: .45, .45 + """ -import pandas as pd import seaborn as sns -sns.set(style="whitegrid", palette="muted") - -# Load the example iris dataset -iris = sns.load_dataset("iris") +sns.set_theme(style="whitegrid", palette="muted") -# "Melt" the dataset to "long-form" or "tidy" representation -iris = pd.melt(iris, "species", var_name="measurement") +# Load the penguins dataset +df = sns.load_dataset("penguins") # Draw a categorical scatterplot to show each observation -sns.swarmplot(x="measurement", y="value", hue="species", - palette=["r", "c", "y"], data=iris) +ax = sns.swarmplot(data=df, x="body_mass_g", y="sex", hue="species") +ax.set(ylabel="") diff --git a/examples/scatterplot_matrix.py b/examples/scatterplot_matrix.py index 399072f701..61829ea1b1 100644 --- a/examples/scatterplot_matrix.py +++ b/examples/scatterplot_matrix.py @@ -2,10 +2,10 @@ Scatterplot Matrix ================== -_thumb: .5, .43 +_thumb: .3, .2 """ import seaborn as sns -sns.set(style="ticks") +sns.set_theme(style="ticks") -df = sns.load_dataset("iris") +df = sns.load_dataset("penguins") sns.pairplot(df, hue="species") diff --git a/examples/scatterplot_sizes.py b/examples/scatterplot_sizes.py index e2bb6444bf..492f6a496a 100644 --- a/examples/scatterplot_sizes.py +++ b/examples/scatterplot_sizes.py @@ -2,18 +2,23 @@ Scatterplot with continuous hues and sizes ========================================== -_thumb: .45, .45 +_thumb: .51, .44 """ - import seaborn as sns -sns.set() +sns.set_theme(style="whitegrid") # Load the example planets dataset planets = sns.load_dataset("planets") cmap = sns.cubehelix_palette(rot=-.2, as_cmap=True) -ax = sns.scatterplot(x="distance", y="orbital_period", - hue="year", size="mass", - palette=cmap, sizes=(10, 200), - data=planets) +g = sns.relplot( + data=planets, + x="distance", y="orbital_period", + hue="year", size="mass", + palette=cmap, sizes=(10, 200), +) +g.set(xscale="log", yscale="log") +g.ax.xaxis.grid(True, "minor", linewidth=.25) +g.ax.yaxis.grid(True, "minor", linewidth=.25) +g.despine(left=True, bottom=True) diff --git a/examples/simple_violinplots.py b/examples/simple_violinplots.py index d04a9d5225..2d53ccbaf8 100644 --- a/examples/simple_violinplots.py +++ b/examples/simple_violinplots.py @@ -6,16 +6,13 @@ import numpy as np import seaborn as sns -sns.set() +sns.set_theme() # Create a random dataset across several variables -rs = np.random.RandomState(0) +rs = np.random.default_rng(0) n, p = 40, 8 d = rs.normal(0, 2, (n, p)) d += np.log(np.arange(1, p + 1)) * -5 + 10 -# Use cubehelix to get a custom sequential palette -pal = sns.cubehelix_palette(p, rot=-.5, dark=.3) - # Show each distribution with both violins and points -sns.violinplot(data=d, palette=pal, inner="points") +sns.violinplot(data=d, palette="light:g", inner="points", orient="h") diff --git a/examples/smooth_bivariate_kde.py b/examples/smooth_bivariate_kde.py new file mode 100644 index 0000000000..a654e399f1 --- /dev/null +++ b/examples/smooth_bivariate_kde.py @@ -0,0 +1,16 @@ +""" +Smooth kernel density with marginal histograms +============================================== + +_thumb: .48, .41 +""" +import seaborn as sns +sns.set_theme(style="white") + +df = sns.load_dataset("penguins") + +g = sns.JointGrid(data=df, x="body_mass_g", y="bill_depth_mm", space=0) +g.plot_joint(sns.kdeplot, + 
fill=True, clip=((2200, 6800), (10, 25)), + thresh=0, levels=100, cmap="rocket") +g.plot_marginals(sns.histplot, color="#03051A", alpha=1, bins=25) diff --git a/examples/heatmap_annotation.py b/examples/spreadsheet_heatmap.py similarity index 96% rename from examples/heatmap_annotation.py rename to examples/spreadsheet_heatmap.py index 84a586ecf7..6eeecd0d05 100644 --- a/examples/heatmap_annotation.py +++ b/examples/spreadsheet_heatmap.py @@ -5,7 +5,7 @@ """ import matplotlib.pyplot as plt import seaborn as sns -sns.set() +sns.set_theme() # Load the example flights dataset and convert to long-form flights_long = sns.load_dataset("flights") diff --git a/examples/structured_heatmap.py b/examples/structured_heatmap.py index 7986eae77d..af675caa34 100644 --- a/examples/structured_heatmap.py +++ b/examples/structured_heatmap.py @@ -2,11 +2,11 @@ Discovering structure in heatmap data ===================================== -_thumb: .4, .25 +_thumb: .3, .25 """ import pandas as pd import seaborn as sns -sns.set() +sns.set_theme() # Load the brain networks example dataset df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) @@ -27,6 +27,10 @@ network_colors = pd.Series(networks, index=df.columns).map(network_lut) # Draw the full plot -sns.clustermap(df.corr(), center=0, cmap="vlag", - row_colors=network_colors, col_colors=network_colors, - linewidths=.75, figsize=(13, 13)) +g = sns.clustermap(df.corr(), center=0, cmap="vlag", + row_colors=network_colors, col_colors=network_colors, + dendrogram_ratio=(.1, .2), + cbar_pos=(.02, .32, .03, .2), + linewidths=.75, figsize=(12, 13)) + +g.ax_row_dendrogram.remove() diff --git a/examples/three_variable_histogram.py b/examples/three_variable_histogram.py new file mode 100644 index 0000000000..d642b7d40a --- /dev/null +++ b/examples/three_variable_histogram.py @@ -0,0 +1,15 @@ +""" +Trivariate histogram with two categorical variables +=================================================== + +_thumb: .32, .55 + +""" +import seaborn as sns +sns.set_theme(style="dark") + +diamonds = sns.load_dataset("diamonds") +sns.displot( + data=diamonds, x="price", y="color", col="clarity", + log_scale=(True, False), col_wrap=4, height=4, aspect=.7, +) diff --git a/examples/timeseries_facets.py b/examples/timeseries_facets.py new file mode 100644 index 0000000000..b757c93c12 --- /dev/null +++ b/examples/timeseries_facets.py @@ -0,0 +1,39 @@ +""" +Small multiple time series +-------------------------- + +_thumb: .42, .58 + +""" +import seaborn as sns + +sns.set_theme(style="dark") +flights = sns.load_dataset("flights") + +# Plot each year's time series in its own facet +g = sns.relplot( + data=flights, + x="month", y="passengers", col="year", hue="year", + kind="line", palette="crest", linewidth=4, zorder=5, + col_wrap=3, height=2, aspect=1.5, legend=False, +) + +# Iterate over each subplot to customize further +for year, ax in g.axes_dict.items(): + + # Add the title as an annotation within the plot + ax.text(.8, .85, year, transform=ax.transAxes, fontweight="bold") + + # Plot every year's time series in the background + sns.lineplot( + data=flights, x="month", y="passengers", units="year", + estimator=None, color=".7", linewidth=1, ax=ax, + ) + +# Reduce the frequency of the x axis ticks +ax.set_xticks(ax.get_xticks()[::2]) + +# Tweak the supporting aspects of the plot +g.set_titles("") +g.set_axis_labels("", "Passengers") +g.tight_layout() diff --git a/examples/wide_data_lineplot.py b/examples/wide_data_lineplot.py index 580217eba0..cfdf24960a 100644 --- 
a/examples/wide_data_lineplot.py +++ b/examples/wide_data_lineplot.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd import seaborn as sns -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") rs = np.random.RandomState(365) values = rs.randn(365, 4).cumsum(axis=0) diff --git a/examples/wide_form_violinplot.py b/examples/wide_form_violinplot.py index 3ec9705ff8..77a90193f7 100644 --- a/examples/wide_form_violinplot.py +++ b/examples/wide_form_violinplot.py @@ -6,7 +6,7 @@ """ import seaborn as sns import matplotlib.pyplot as plt -sns.set(style="whitegrid") +sns.set_theme(style="whitegrid") # Load the example dataset of brain network correlations df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) diff --git a/licences/SIX_LICENSE b/licences/SIX_LICENSE deleted file mode 100644 index d76e024263..0000000000 --- a/licences/SIX_LICENSE +++ /dev/null @@ -1,18 +0,0 @@ -Copyright (c) 2010-2014 Benjamin Peterson - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000..7e61aea2ad --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = +; Warnings raised from within patsy imports + ignore:Using or importing the ABCs:DeprecationWarning +junit_family=xunit1 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 8fed041814..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy -scipy -matplotlib -pandas diff --git a/seaborn/__init__.py b/seaborn/__init__.py index edae840a77..a2a60bb316 100644 --- a/seaborn/__init__.py +++ b/seaborn/__init__.py @@ -1,21 +1,21 @@ +# Import seaborn objects +from .rcmod import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 +from .palettes import * # noqa: F401,F403 +from .relational import * # noqa: F401,F403 +from .regression import * # noqa: F401,F403 +from .categorical import * # noqa: F401,F403 +from .distributions import * # noqa: F401,F403 +from .matrix import * # noqa: F401,F403 +from .miscplot import * # noqa: F401,F403 +from .axisgrid import * # noqa: F401,F403 +from .widgets import * # noqa: F401,F403 +from .colors import xkcd_rgb, crayons # noqa: F401 +from . 
import cm # noqa: F401 + # Capture the original matplotlib rcParams import matplotlib as mpl _orig_rc_params = mpl.rcParams.copy() -# Import seaborn objects -from .rcmod import * -from .utils import * -from .palettes import * -from .relational import * -from .regression import * -from .categorical import * -from .distributions import * -from .timeseries import * -from .matrix import * -from .miscplot import * -from .axisgrid import * -from .widgets import * -from .colors import xkcd_rgb, crayons -from . import cm - -__version__ = "0.9.1.dev0" +# Define the seaborn version +__version__ = "0.12.0.dev0" diff --git a/seaborn/_core.py b/seaborn/_core.py new file mode 100644 index 0000000000..c901998f22 --- /dev/null +++ b/seaborn/_core.py @@ -0,0 +1,1739 @@ +import warnings +import itertools +from copy import copy +from functools import partial +from collections import UserString +from collections.abc import Iterable, Sequence, Mapping +from numbers import Number +from datetime import datetime +from distutils.version import LooseVersion + +import numpy as np +import pandas as pd +import matplotlib as mpl + +from ._decorators import ( + share_init_params_with_map, +) +from .palettes import ( + QUAL_PALETTES, + color_palette, +) +from .utils import ( + _check_argument, + get_color_cycle, + remove_na, +) + + +class SemanticMapping: + """Base class for mapping data values to plot attributes.""" + + # -- Default attributes that all SemanticMapping subclasses must set + + # Whether the mapping is numeric, categorical, or datetime + map_type = None + + # Ordered list of unique values in the input data + levels = None + + # A mapping from the data values to corresponding plot attributes + lookup_table = None + + def __init__(self, plotter): + + # TODO Putting this here so we can continue to use a lot of the + # logic that's built into the library, but the idea of this class + # is to move towards semantic mappings that are agnostic about the + # kind of plot they're going to be used to draw. + # Fully achieving that is going to take some thinking. + self.plotter = plotter + + def map(cls, plotter, *args, **kwargs): + # This method is assigned the __init__ docstring + method_name = "_{}_map".format(cls.__name__[:-7].lower()) + setattr(plotter, method_name, cls(plotter, *args, **kwargs)) + return plotter + + def _lookup_single(self, key): + """Apply the mapping to a single data value.""" + return self.lookup_table[key] + + def __call__(self, key, *args, **kwargs): + """Get the attribute(s) values for the data key.""" + if isinstance(key, (list, np.ndarray, pd.Series)): + return [self._lookup_single(k, *args, **kwargs) for k in key] + else: + return self._lookup_single(key, *args, **kwargs) + + +@share_init_params_with_map +class HueMapping(SemanticMapping): + """Mapping that sets artist colors according to data values.""" + # A specification of the colors that should appear in the plot + palette = None + + # An object that normalizes data values to [0, 1] range for color mapping + norm = None + + # A continuous colormap object for interpolating in a numeric context + cmap = None + + def __init__( + self, plotter, palette=None, order=None, norm=None, + ): + """Map the levels of the `hue` variable to distinct colors. 
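+ + The type of mapping (categorical, numeric, or datetime) is inferred + from the data and the `palette` and `norm` arguments (see + `infer_map_type` below); datetime data currently fall back to the + categorical handling.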
+ + Parameters + ---------- + # TODO add generic parameters + + """ + super().__init__(plotter) + + data = plotter.plot_data.get("hue", pd.Series(dtype=float)) + + if data.notna().any(): + + map_type = self.infer_map_type( + palette, norm, plotter.input_format, plotter.var_types["hue"] + ) + + # Our goal is to end up with a dictionary mapping every unique + # value in `data` to a color. We will also keep track of the + # metadata about this mapping we will need for, e.g., a legend + + # --- Option 1: numeric mapping with a matplotlib colormap + + if map_type == "numeric": + + data = pd.to_numeric(data) + levels, lookup_table, norm, cmap = self.numeric_mapping( + data, palette, norm, + ) + + # --- Option 2: categorical mapping using seaborn palette + + elif map_type == "categorical": + + cmap = norm = None + levels, lookup_table = self.categorical_mapping( + data, palette, order, + ) + + # --- Option 3: datetime mapping + + else: + # TODO this needs actual implementation + cmap = norm = None + levels, lookup_table = self.categorical_mapping( + # Casting data to list to handle differences in the way + # pandas and numpy represent datetime64 data + list(data), palette, order, + ) + + self.map_type = map_type + self.lookup_table = lookup_table + self.palette = palette + self.levels = levels + self.norm = norm + self.cmap = cmap + + def _lookup_single(self, key): + """Get the color for a single value, using colormap to interpolate.""" + try: + # Use a value that's in the original data vector + value = self.lookup_table[key] + except KeyError: + # Use the colormap to interpolate between existing datapoints + # (e.g. in the context of making a continuous legend) + try: + normed = self.norm(key) + except TypeError as err: + if np.isnan(key): + value = (0, 0, 0, 0) + else: + raise err + else: + if np.ma.is_masked(normed): + normed = np.nan + value = self.cmap(normed) + return value + + def infer_map_type(self, palette, norm, input_format, var_type): + """Determine how to implement the mapping.""" + if palette in QUAL_PALETTES: + map_type = "categorical" + elif norm is not None: + map_type = "numeric" + elif isinstance(palette, (dict, list)): + map_type = "categorical" + elif input_format == "wide": + map_type = "categorical" + else: + map_type = var_type + + return map_type + + def categorical_mapping(self, data, palette, order): + """Determine colors when the hue mapping is categorical.""" + # -- Identify the order and name of the levels + + levels = categorical_order(data, order) + n_colors = len(levels) + + # -- Identify the set of colors to use + + if isinstance(palette, dict): + + missing = set(levels) - set(palette) + if any(missing): + err = "The palette dictionary is missing keys: {}" + raise ValueError(err.format(missing)) + + lookup_table = palette + + else: + + if palette is None: + if n_colors <= len(get_color_cycle()): + colors = color_palette(None, n_colors) + else: + colors = color_palette("husl", n_colors) + elif isinstance(palette, list): + if len(palette) != n_colors: + err = "The palette list has the wrong number of colors." + raise ValueError(err) + colors = palette + else: + colors = color_palette(palette, n_colors) + + lookup_table = dict(zip(levels, colors)) + + return levels, lookup_table + + def numeric_mapping(self, data, palette, norm): + """Determine colors when the hue variable is quantitative.""" + if isinstance(palette, dict): + + # The presence of a norm object overrides a dictionary of hues + # in specifying a numeric mapping, so we need to process it here. 
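+ # The sorted dict keys become the levels, and the dict values + # supply the colors for a ListedColormap, in key order.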
+ levels = list(sorted(palette)) + colors = [palette[k] for k in sorted(palette)] + cmap = mpl.colors.ListedColormap(colors) + lookup_table = palette.copy() + + else: + + # The levels are the sorted unique values in the data + levels = list(np.sort(remove_na(data.unique()))) + + # --- Sort out the colormap to use from the palette argument + + # Default numeric palette is our default cubehelix palette + # TODO do we want to do something complicated to ensure contrast? + palette = "ch:" if palette is None else palette + + if isinstance(palette, mpl.colors.Colormap): + cmap = palette + else: + cmap = color_palette(palette, as_cmap=True) + + # Now sort out the data normalization + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``hue_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + if not norm.scaled(): + norm(np.asarray(data.dropna())) + + lookup_table = dict(zip(levels, cmap(norm(levels)))) + + return levels, lookup_table, norm, cmap + + +@share_init_params_with_map +class SizeMapping(SemanticMapping): + """Mapping that sets artist sizes according to data values.""" + # An object that normalizes data values to [0, 1] range + norm = None + + def __init__( + self, plotter, sizes=None, order=None, norm=None, + ): + """Map the levels of the `size` variable to distinct values. + + Parameters + ---------- + # TODO add generic parameters + + """ + super().__init__(plotter) + + data = plotter.plot_data.get("size", pd.Series(dtype=float)) + + if data.notna().any(): + + map_type = self.infer_map_type( + norm, sizes, plotter.var_types["size"] + ) + + # --- Option 1: numeric mapping + + if map_type == "numeric": + + levels, lookup_table, norm = self.numeric_mapping( + data, sizes, norm, + ) + + # --- Option 2: categorical mapping + + elif map_type == "categorical": + + levels, lookup_table = self.categorical_mapping( + data, sizes, order, + ) + + # --- Option 3: datetime mapping + + # TODO this needs an actual implementation + else: + + levels, lookup_table = self.categorical_mapping( + # Casting data to list to handle differences in the way + # pandas and numpy represent datetime64 data + list(data), sizes, order, + ) + + self.map_type = map_type + self.levels = levels + self.norm = norm + self.sizes = sizes + self.lookup_table = lookup_table + + def infer_map_type(self, norm, sizes, var_type): + + if norm is not None: + map_type = "numeric" + elif isinstance(sizes, (dict, list)): + map_type = "categorical" + else: + map_type = var_type + + return map_type + + def _lookup_single(self, key): + + try: + value = self.lookup_table[key] + except KeyError: + normed = self.norm(key) + if np.ma.is_masked(normed): + normed = np.nan + size_values = self.lookup_table.values() + size_range = min(size_values), max(size_values) + value = size_range[0] + normed * np.ptp(size_range) + return value + + def categorical_mapping(self, data, sizes, order): + + levels = categorical_order(data, order) + + if isinstance(sizes, dict): + + # Dict inputs map existing data values to the size attribute + missing = set(levels) - set(sizes) + if any(missing): + err = f"Missing sizes for the following levels: {missing}" + raise ValueError(err) + lookup_table = sizes.copy() + + elif isinstance(sizes, list): + + # List inputs give size values in the same order as the levels + if len(sizes) != len(levels): + err = "The `sizes` list has the wrong number of values." 
+                    raise ValueError(err)
+
+            lookup_table = dict(zip(levels, sizes))
+
+        else:
+
+            if isinstance(sizes, tuple):
+
+                # Tuple input sets the min, max size values
+                if len(sizes) != 2:
+                    err = "A `sizes` tuple must have only 2 values"
+                    raise ValueError(err)
+
+            elif sizes is not None:
+
+                err = f"Value for `sizes` not understood: {sizes}"
+                raise ValueError(err)
+
+            else:
+
+                # Otherwise, we need to get the min, max size values from
+                # the plotter object we are attached to.
+
+                # TODO this is going to cause us trouble later, because we
+                # want to restructure things so that the plotter is generic
+                # across the visual representation of the data. But at this
+                # point, we don't know the visual representation. Likely we
+                # want to change the logic of this Mapping so that it gives
+                # points on a normalized range that then gets unnormalized
+                # when we know what we're drawing. But given the way the
+                # package works now, this way is cleanest.
+                sizes = self.plotter._default_size_range
+
+            # For categorical sizes, use regularly-spaced linear steps
+            # between the minimum and maximum sizes. Then reverse the
+            # ramp so that the largest value is used for the first entry
+            # in size_order, etc. This is because "ordered" categoricals
+            # are often thought to go in decreasing priority.
+            sizes = np.linspace(*sizes, len(levels))[::-1]
+            lookup_table = dict(zip(levels, sizes))
+
+        return levels, lookup_table
+
+    def numeric_mapping(self, data, sizes, norm):
+
+        if isinstance(sizes, dict):
+            # The presence of a norm object overrides a dictionary of sizes
+            # in specifying a numeric mapping, so we need to process the
+            # dictionary here
+            levels = list(np.sort(list(sizes)))
+            size_values = sizes.values()
+            size_range = min(size_values), max(size_values)
+
+        else:
+
+            # The levels here will be the unique values in the data
+            levels = list(np.sort(remove_na(data.unique())))
+
+            if isinstance(sizes, tuple):
+
+                # For numeric inputs, the size can be parametrized by
+                # the minimum and maximum artist values to map to. The
+                # norm object that gets set up next specifies how to
+                # do the mapping.
+
+                if len(sizes) != 2:
+                    err = "A `sizes` tuple must have only 2 values"
+                    raise ValueError(err)
+
+                size_range = sizes
+
+            elif sizes is not None:
+
+                err = f"Value for `sizes` not understood: {sizes}"
+                raise ValueError(err)
+
+            else:
+
+                # When not provided, we get the size range from the plotter
+                # object we are attached to. See the note in the categorical
+                # method about how this is suboptimal for future development.
+                size_range = self.plotter._default_size_range
+
+        # Now that we know the minimum and maximum sizes that will get drawn,
+        # we need to map the data values that we have into that range. We will
+        # use a matplotlib Normalize class, which is typically used for numeric
+        # color mapping but works fine here too. It takes data values and maps
+        # them into a [0, 1] interval, potentially nonlinearly.
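To make the normalization step concrete: the map from data values to artist sizes is just an autoscaled `Normalize` followed by linear interpolation into the size range. A standalone sketch of the same arithmetic (made-up data and size range, not the internal code path):

```python
import numpy as np
import matplotlib as mpl

data = np.array([1.0, 10.0, 25.0, 50.0])
lo, hi = 2, 10                  # hypothetical min/max artist sizes

norm = mpl.colors.Normalize()   # linear map between data extremes
norm.clip = True                # keep outputs inside [0, 1]
sizes_scaled = norm(data)       # first call autoscales vmin/vmax from data

sizes = lo + sizes_scaled * (hi - lo)
# sizes[0] == 2.0 for the smallest datum, sizes[-1] == 10.0 for the largest
```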
+ + if norm is None: + # Default is a linear function between the min and max data values + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + # It is also possible to give different limits in data space + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = f"Value for size `norm` parameter not understood: {norm}" + raise ValueError(err) + else: + # If provided with Normalize object, copy it so we can modify + norm = copy(norm) + + # Set the mapping so all output values are in [0, 1] + norm.clip = True + + # If the input range is not set, use the full range of the data + if not norm.scaled(): + norm(levels) + + # Map from data values to [0, 1] range + sizes_scaled = norm(levels) + + # Now map from the scaled range into the artist units + if isinstance(sizes, dict): + lookup_table = sizes + else: + lo, hi = size_range + sizes = lo + sizes_scaled * (hi - lo) + lookup_table = dict(zip(levels, sizes)) + + return levels, lookup_table, norm + + +@share_init_params_with_map +class StyleMapping(SemanticMapping): + """Mapping that sets artist style according to data values.""" + + # Style mapping is always treated as categorical + map_type = "categorical" + + def __init__( + self, plotter, markers=None, dashes=None, order=None, + ): + """Map the levels of the `style` variable to distinct values. + + Parameters + ---------- + # TODO add generic parameters + + """ + super().__init__(plotter) + + data = plotter.plot_data.get("style", pd.Series(dtype=float)) + + if data.notna().any(): + + # Cast to list to handle numpy/pandas datetime quirks + if variable_type(data) == "datetime": + data = list(data) + + # Find ordered unique values + levels = categorical_order(data, order) + + markers = self._map_attributes( + markers, levels, unique_markers(len(levels)), "markers", + ) + dashes = self._map_attributes( + dashes, levels, unique_dashes(len(levels)), "dashes", + ) + + # Build the paths matplotlib will use to draw the markers + paths = {} + filled_markers = [] + for k, m in markers.items(): + if not isinstance(m, mpl.markers.MarkerStyle): + m = mpl.markers.MarkerStyle(m) + paths[k] = m.get_path().transformed(m.get_transform()) + filled_markers.append(m.is_filled()) + + # Mixture of filled and unfilled markers will show line art markers + # in the edge color, which defaults to white. This can be handled, + # but there would be additional complexity with specifying the + # weight of the line art markers without overwhelming the filled + # ones with the edges. So for now, we will disallow mixtures. 
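The filled/line-art distinction that motivates the check below can be probed directly with matplotlib; a minimal sketch:

```python
from matplotlib.markers import MarkerStyle

for m in ["o", "X", "+", "x", (4, 0, 45)]:
    style = MarkerStyle(m)
    path = style.get_path().transformed(style.get_transform())
    print(m, "filled:", style.is_filled(), "vertices:", len(path.vertices))
# "+" and "x" are line-art (unfilled); mixing them with filled markers
# is what the ValueError below disallows.
```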
+        if any(filled_markers) and not all(filled_markers):
+            err = "Filled and line art markers cannot be mixed"
+            raise ValueError(err)
+
+        lookup_table = {}
+        for key in levels:
+            lookup_table[key] = {}
+            if markers:
+                lookup_table[key]["marker"] = markers[key]
+                lookup_table[key]["path"] = paths[key]
+            if dashes:
+                lookup_table[key]["dashes"] = dashes[key]
+
+        self.levels = levels
+        self.lookup_table = lookup_table
+
+    def _lookup_single(self, key, attr=None):
+        """Get attribute(s) for a given data point."""
+        if attr is None:
+            value = self.lookup_table[key]
+        else:
+            value = self.lookup_table[key][attr]
+        return value
+
+    def _map_attributes(self, arg, levels, defaults, attr):
+        """Handle the specification for a given style attribute."""
+        if arg is True:
+            lookup_table = dict(zip(levels, defaults))
+        elif isinstance(arg, dict):
+            missing = set(levels) - set(arg)
+            if missing:
+                err = f"These `{attr}` levels are missing values: {missing}"
+                raise ValueError(err)
+            lookup_table = arg
+        elif isinstance(arg, Sequence):
+            if len(levels) != len(arg):
+                err = f"The `{attr}` argument has the wrong number of values"
+                raise ValueError(err)
+            lookup_table = dict(zip(levels, arg))
+        elif arg:
+            err = f"This `{attr}` argument was not understood: {arg}"
+            raise ValueError(err)
+        else:
+            lookup_table = {}
+
+        return lookup_table
+
+
+# =========================================================================== #
+
+
+class VectorPlotter:
+    """Base class for objects underlying *plot functions."""
+
+    _semantic_mappings = {
+        "hue": HueMapping,
+        "size": SizeMapping,
+        "style": StyleMapping,
+    }
+
+    # TODO units is another example of a non-mapping "semantic"
+    # we need a general name for this and separate handling
+    semantics = "x", "y", "hue", "size", "style", "units"
+    wide_structure = {
+        "x": "@index", "y": "@values", "hue": "@columns", "style": "@columns",
+    }
+    flat_structure = {"x": "@index", "y": "@values"}
+
+    _default_size_range = 1, 2  # Unused but needed in tests, ugh
+
+    def __init__(self, data=None, variables={}):
+
+        self._var_levels = {}
+        # var_ordered is relevant only for categorical axis variables, and may
+        # be better handled by an internal axis information object that tracks
+        # such information and is set up by the scale_* methods. The analogous
+        # information for numeric axes would be information about log scales.
+        self._var_ordered = {"x": False, "y": False}  # alt., used DefaultDict
+        self.assign_variables(data, variables)
+
+        for var, cls in self._semantic_mappings.items():
+
+            # Create the mapping function
+            map_func = partial(cls.map, plotter=self)
+            setattr(self, f"map_{var}", map_func)
+
+            # Call the mapping function to initialize with default values
+            getattr(self, f"map_{var}")()
+
+    @classmethod
+    def get_semantics(cls, kwargs, semantics=None):
+        """Subset a dictionary of arguments with known semantic variables."""
+        # TODO this should be get_variables since we have included x and y
+        if semantics is None:
+            semantics = cls.semantics
+        variables = {}
+        for key, val in kwargs.items():
+            if key in semantics and val is not None:
+                variables[key] = val
+        return variables
+
+    @property
+    def has_xy_data(self):
+        """Return True if at least one of x or y is defined."""
+        return bool({"x", "y"} & set(self.variables))
+
+    @property
+    def var_levels(self):
+        """Property interface to ordered list of variable levels.
+
+        Each time it's accessed, it updates the var_levels dictionary with the
+        list of levels in the current semantic mappers.
But it also allows the + dictionary to persist, so it can be used to set levels by a key. This is + used to track the list of col/row levels using an attached FacetGrid + object, but it's kind of messy and ideally fixed by improving the + faceting logic so it interfaces better with the modern approach to + tracking plot variables. + + """ + for var in self.variables: + try: + map_obj = getattr(self, f"_{var}_map") + self._var_levels[var] = map_obj.levels + except AttributeError: + pass + return self._var_levels + + def assign_variables(self, data=None, variables={}): + """Define plot variables, optionally using lookup from `data`.""" + x = variables.get("x", None) + y = variables.get("y", None) + + if x is None and y is None: + self.input_format = "wide" + plot_data, variables = self._assign_variables_wideform( + data, **variables, + ) + else: + self.input_format = "long" + plot_data, variables = self._assign_variables_longform( + data, **variables, + ) + + self.plot_data = plot_data + self.variables = variables + self.var_types = { + v: variable_type( + plot_data[v], + boolean_type="numeric" if v in "xy" else "categorical" + ) + for v in variables + } + + # XXX does this make sense here? + for axis in "xy": + if axis not in variables: + continue + self.var_levels[axis] = categorical_order(self.plot_data[axis]) + + return self + + def _assign_variables_wideform(self, data=None, **kwargs): + """Define plot variables given wide-form data. + + Parameters + ---------- + data : flat vector or collection of vectors + Data can be a vector or mapping that is coerceable to a Series + or a sequence- or mapping-based collection of such vectors, or a + rectangular numpy array, or a Pandas DataFrame. + kwargs : variable -> data mappings + Behavior with keyword arguments is currently undefined. + + Returns + ------- + plot_data : :class:`pandas.DataFrame` + Long-form data object mapping seaborn variables (x, y, hue, ...) + to data vectors. + variables : dict + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). 
+ + """ + # Raise if semantic or other variables are assigned in wide-form mode + assigned = [k for k, v in kwargs.items() if v is not None] + if any(assigned): + s = "s" if len(assigned) > 1 else "" + err = f"The following variable{s} cannot be assigned with wide-form data: " + err += ", ".join(f"`{v}`" for v in assigned) + raise ValueError(err) + + # Determine if the data object actually has any data in it + empty = data is None or not len(data) + + # Then, determine if we have "flat" data (a single vector) + if isinstance(data, dict): + values = data.values() + else: + values = np.atleast_1d(np.asarray(data, dtype=object)) + flat = not any( + isinstance(v, Iterable) and not isinstance(v, (str, bytes)) + for v in values + ) + + if empty: + + # Make an object with the structure of plot_data, but empty + plot_data = pd.DataFrame() + variables = {} + + elif flat: + + # Handle flat data by converting to pandas Series and using the + # index and/or values to define x and/or y + # (Could be accomplished with a more general to_series() interface) + flat_data = pd.Series(data).copy() + names = { + "@values": flat_data.name, + "@index": flat_data.index.name + } + + plot_data = {} + variables = {} + + for var in ["x", "y"]: + if var in self.flat_structure: + attr = self.flat_structure[var] + plot_data[var] = getattr(flat_data, attr[1:]) + variables[var] = names[self.flat_structure[var]] + + plot_data = pd.DataFrame(plot_data) + + else: + + # Otherwise assume we have some collection of vectors. + + # Handle Python sequences such that entries end up in the columns, + # not in the rows, of the intermediate wide DataFrame. + # One way to accomplish this is to convert to a dict of Series. + if isinstance(data, Sequence): + data_dict = {} + for i, var in enumerate(data): + key = getattr(var, "name", i) + # TODO is there a safer/more generic way to ensure Series? + # sort of like np.asarray, but for pandas? 
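The dict-of-Series conversion in the next few lines is what makes list-of-vectors input land in columns (and tolerates ragged lengths); in isolation, roughly:

```python
import pandas as pd

data = [[1, 2, 3], [4, 5]]  # list of vectors, possibly ragged

data_dict = {}
for i, var in enumerate(data):
    key = getattr(var, "name", i)   # Series keep their name; lists get a position
    data_dict[key] = pd.Series(var)

wide = pd.DataFrame(data_dict)
print(wide.shape)  # (3, 2): each vector is a column, NaN pads the short one
```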
+                    data_dict[key] = pd.Series(var)
+
+                data = data_dict
+
+            # Pandas requires that dict values either be Series objects
+            # or all have the same length, but we want to allow "ragged" inputs
+            if isinstance(data, Mapping):
+                data = {key: pd.Series(val) for key, val in data.items()}
+
+            # Otherwise, delegate to the pandas DataFrame constructor
+            # This is where we'd prefer to use a general interface that says
+            # "give me this data as a pandas DataFrame", so we can accept
+            # DataFrame objects from other libraries
+            wide_data = pd.DataFrame(data, copy=True)
+
+            # At this point we should reduce the dataframe to numeric cols
+            numeric_cols = [
+                k for k, v in wide_data.items() if variable_type(v) == "numeric"
+            ]
+            wide_data = wide_data[numeric_cols]
+
+            # Now melt the data to long form
+            melt_kws = {"var_name": "@columns", "value_name": "@values"}
+            use_index = "@index" in self.wide_structure.values()
+            if use_index:
+                melt_kws["id_vars"] = "@index"
+                try:
+                    orig_categories = wide_data.columns.categories
+                    orig_ordered = wide_data.columns.ordered
+                    wide_data.columns = wide_data.columns.add_categories("@index")
+                except AttributeError:
+                    category_columns = False
+                else:
+                    category_columns = True
+                wide_data["@index"] = wide_data.index.to_series()
+
+            plot_data = wide_data.melt(**melt_kws)
+
+            if use_index and category_columns:
+                plot_data["@columns"] = pd.Categorical(plot_data["@columns"],
+                                                       orig_categories,
+                                                       orig_ordered)
+
+            # Assign names corresponding to plot semantics
+            for var, attr in self.wide_structure.items():
+                plot_data[var] = plot_data[attr]
+
+            # Define the variable names
+            variables = {}
+            for var, attr in self.wide_structure.items():
+                obj = getattr(wide_data, attr[1:])
+                variables[var] = getattr(obj, "name", None)
+
+            # Remove redundant columns from plot_data
+            plot_data = plot_data[list(variables)]
+
+        return plot_data, variables
+
+    def _assign_variables_longform(self, data=None, **kwargs):
+        """Define plot variables given long-form data and/or vector inputs.
+
+        Parameters
+        ----------
+        data : dict-like collection of vectors
+            Input data where variable names map to vector values.
+        kwargs : variable -> data mappings
+            Keys are seaborn variables (x, y, hue, ...) and values are vectors
+            in any format that can construct a :class:`pandas.DataFrame` or
+            names of columns or index levels in ``data``.
+
+        Returns
+        -------
+        plot_data : :class:`pandas.DataFrame`
+            Long-form data object mapping seaborn variables (x, y, hue, ...)
+            to data vectors.
+        variables : dict
+            Keys are defined seaborn variables; values are names inferred from
+            the inputs (or None when no name can be determined).
+
+        Raises
+        ------
+        ValueError
+            When variables are strings that don't appear in ``data``.
+
+        """
+        plot_data = {}
+        variables = {}
+
+        # Data is optional; all variables can be defined as vectors
+        if data is None:
+            data = {}
+
+        # TODO should we try a data.to_dict() or similar here to more
+        # generally accept objects with that interface?
+        # Note that dict(df) also works for pandas, and gives us what we
+        # want, whereas DataFrame.to_dict() gives a nested dict instead of
+        # a dict of series.
+
+        # Variables can also be extracted from the index attribute
+        # TODO is this the most general way to enable it?
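The index-lookup mechanism described here relies on `Index.to_frame`, which makes index levels addressable like columns; a quick sketch:

```python
import pandas as pd

df = pd.DataFrame(
    {"y": [1.0, 2.0, 3.0]},
    index=pd.Index(["a", "b", "c"], name="label"),
)

index = df.index.to_frame()
print("label" in index)      # True: the name can now be looked up like a column
print(list(index["label"]))  # ['a', 'b', 'c']
```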
+        # There is no index.to_dict on multiindex, unfortunately
+        try:
+            index = data.index.to_frame()
+        except AttributeError:
+            index = {}
+
+        # The caller will determine the order of variables in plot_data
+        for key, val in kwargs.items():
+
+            # First try to treat the argument as a key for the data collection.
+            # But be flexible about what can be used as a key.
+            # Usually it will be a string, but allow numbers or tuples too when
+            # taking from the main data object. Only allow strings to reference
+            # fields in the index, because otherwise there is too much ambiguity.
+            try:
+                val_as_data_key = (
+                    val in data
+                    or (isinstance(val, (str, bytes)) and val in index)
+                )
+            except (KeyError, TypeError):
+                val_as_data_key = False
+
+            if val_as_data_key:
+
+                # We know that __getitem__ will work
+
+                if val in data:
+                    plot_data[key] = data[val]
+                elif val in index:
+                    plot_data[key] = index[val]
+                variables[key] = val
+
+            elif isinstance(val, (str, bytes)):
+
+                # This looks like a column name but we don't know what it means!
+
+                err = f"Could not interpret value `{val}` for parameter `{key}`"
+                raise ValueError(err)
+
+            else:
+
+                # Otherwise, assume the value is itself data
+
+                # Raise when data object is present and a vector can't be matched
+                if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
+                    if np.ndim(val) and len(data) != len(val):
+                        val_cls = val.__class__.__name__
+                        err = (
+                            f"Length of {val_cls} vectors must match length of `data`"
+                            f" when both are used, but `data` has length {len(data)}"
+                            f" and the vector passed to `{key}` has length {len(val)}."
+                        )
+                        raise ValueError(err)
+
+                plot_data[key] = val
+
+                # Try to infer the name of the variable
+                variables[key] = getattr(val, "name", None)
+
+        # Construct a tidy plot DataFrame. This will convert a number of
+        # types automatically, aligning on index in case of pandas objects
+        plot_data = pd.DataFrame(plot_data)
+
+        # Reduce the variables dictionary to fields with valid data
+        variables = {
+            var: name
+            for var, name in variables.items()
+            if plot_data[var].notnull().any()
+        }
+
+        return plot_data, variables
+
+    def iter_data(
+        self, grouping_vars=None, *,
+        reverse=False, from_comp_data=False,
+        by_facet=True, allow_empty=False, dropna=True,
+    ):
+        """Generator for getting subsets of data defined by semantic variables.
+
+        Also injects "col" and "row" into grouping semantics.
+
+        Parameters
+        ----------
+        grouping_vars : string or list of strings
+            Semantic variables that define the subsets of data.
+        reverse : bool
+            If True, reverse the order of iteration.
+        from_comp_data : bool
+            If True, use self.comp_data rather than self.plot_data
+        by_facet : bool
+            If True, add faceting variables to the set of grouping variables.
+        allow_empty : bool
+            If True, yield an empty dataframe when no observations exist for
+            combinations of grouping variables.
+        dropna : bool
+            If True, remove rows with missing data.
+
+        Yields
+        ------
+        sub_vars : dict
+            Keys are semantic names, values are the level of that semantic.
+        sub_data : :class:`pandas.DataFrame`
+            Subset of ``plot_data`` for this combination of semantic values.
+
+        """
+        # TODO should this default to using all (non x/y?) semantics?
+        # or define grouping vars somewhere?
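The iteration implemented below amounts to a groupby over the semantic variables, keyed by the product of their known levels; a simplified standalone version:

```python
import itertools
import pandas as pd

data = pd.DataFrame({
    "x": [1, 2, 3, 4],
    "hue": ["a", "a", "b", "b"],
    "col": ["u", "v", "u", "v"],
})
grouping_vars = ["hue", "col"]
levels = {"hue": ["a", "b"], "col": ["u", "v"]}

grouped = data.groupby(grouping_vars, sort=False)
for key in itertools.product(*(levels[v] for v in grouping_vars)):
    try:
        subset = grouped.get_group(key)
    except KeyError:
        continue  # no observations for this combination of levels
    print(dict(zip(grouping_vars, key)), len(subset))
```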
+        if grouping_vars is None:
+            grouping_vars = []
+        elif isinstance(grouping_vars, str):
+            grouping_vars = [grouping_vars]
+        elif isinstance(grouping_vars, tuple):
+            grouping_vars = list(grouping_vars)
+
+        # Always insert faceting variables
+        if by_facet:
+            facet_vars = {"col", "row"}
+            grouping_vars.extend(
+                facet_vars & set(self.variables) - set(grouping_vars)
+            )
+
+        # Reduce to the semantics used in this plot
+        grouping_vars = [
+            var for var in grouping_vars if var in self.variables
+        ]
+
+        if from_comp_data:
+            data = self.comp_data
+        else:
+            data = self.plot_data
+
+        if dropna:
+            data = data.dropna()
+
+        levels = self.var_levels.copy()
+        if from_comp_data:
+            for axis in {"x", "y"} & set(grouping_vars):
+                if self.var_types[axis] == "categorical":
+                    if self._var_ordered[axis]:
+                        # If the axis is ordered, then the axes in a possible
+                        # facet grid are by definition "shared", or there is a
+                        # single axis with a unique cat -> idx mapping.
+                        # So we can just take the first converter object.
+                        converter = self.converters[axis].iloc[0]
+                        levels[axis] = converter.convert_units(levels[axis])
+                    else:
+                        # Otherwise, the mappings may not be unique, but we can
+                        # use the unique set of index values in comp_data.
+                        levels[axis] = np.sort(data[axis].unique())
+                elif self.var_types[axis] == "datetime":
+                    levels[axis] = mpl.dates.date2num(levels[axis])
+                elif self.var_types[axis] == "numeric" and self._log_scaled(axis):
+                    levels[axis] = np.log10(levels[axis])
+
+        if grouping_vars:
+
+            grouped_data = data.groupby(
+                grouping_vars, sort=False, as_index=False
+            )
+
+            grouping_keys = []
+            for var in grouping_vars:
+                grouping_keys.append(levels.get(var, []))
+
+            iter_keys = itertools.product(*grouping_keys)
+            if reverse:
+                iter_keys = reversed(list(iter_keys))
+
+            for key in iter_keys:
+
+                # Pandas fails with singleton tuple inputs
+                pd_key = key[0] if len(key) == 1 else key
+
+                try:
+                    data_subset = grouped_data.get_group(pd_key)
+                except KeyError:
+                    # XXX we are adding this to allow backwards compatibility
+                    # with the empty artists that old categorical plots would
+                    # add (before 0.12), which we may decide to break, in which
+                    # case this option could be removed
+                    data_subset = data.loc[[]]
+
+                if data_subset.empty and not allow_empty:
+                    continue
+
+                sub_vars = dict(zip(grouping_vars, key))
+
+                yield sub_vars, data_subset.copy()
+
+        else:
+
+            yield {}, data.copy()
+
+    @property
+    def comp_data(self):
+        """Dataframe with numeric x and y, after unit conversion and log scaling."""
+        if not hasattr(self, "ax"):
+            # Probably a good idea, but will need a bunch of tests updated
+            # Most of these tests should just use the external interface
+            # Then this can be re-enabled.
+            # raise AttributeError("No Axes attached to plotter")
+            return self.plot_data
+
+        if not hasattr(self, "_comp_data"):
+
+            comp_data = (
+                self.plot_data
+                .copy(deep=False)
+                .drop(["x", "y"], axis=1, errors="ignore")
+            )
+            for var in "yx":
+                if var not in self.variables:
+                    continue
+
+                comp_col = pd.Series(index=self.plot_data.index, dtype=float, name=var)
+                grouped = self.plot_data[var].groupby(self.converters[var], sort=False)
+                for converter, orig in grouped:
+                    with pd.option_context('mode.use_inf_as_null', True):
+                        orig = orig.dropna()
+                    comp = pd.to_numeric(converter.convert_units(orig))
+                    if converter.get_scale() == "log":
+                        comp = np.log10(comp)
+                    comp_col.loc[orig.index] = comp
+
+                comp_data.insert(0, var, comp_col)
+
+            self._comp_data = comp_data
+
+        return self._comp_data
+
+    def _get_axes(self, sub_vars):
+        """Return an Axes object based on existence of row/col variables."""
+        row = sub_vars.get("row", None)
+        col = sub_vars.get("col", None)
+        if row is not None and col is not None:
+            return self.facets.axes_dict[(row, col)]
+        elif row is not None:
+            return self.facets.axes_dict[row]
+        elif col is not None:
+            return self.facets.axes_dict[col]
+        elif self.ax is None:
+            return self.facets.ax
+        else:
+            return self.ax
+
+    def _attach(
+        self,
+        obj,
+        allowed_types=None,
+        log_scale=None,
+    ):
+        """Associate the plotter with an Axes manager and initialize its units.
+
+        Parameters
+        ----------
+        obj : :class:`matplotlib.axes.Axes` or :class:`FacetGrid`
+            Structural object that we will eventually plot onto.
+        allowed_types : str or list of str
+            If provided, raise when either the x or y variable does not have
+            one of the declared seaborn types.
+        log_scale : bool, number, or pair of bools or numbers
+            If not False, set the axes to use log scaling, with the given
+            base or defaulting to 10. If a tuple, interpreted as separate
+            arguments for the x and y axes.
+
+        """
+        from .axisgrid import FacetGrid
+        if isinstance(obj, FacetGrid):
+            self.ax = None
+            self.facets = obj
+            ax_list = obj.axes.flatten()
+            if obj.col_names is not None:
+                self.var_levels["col"] = obj.col_names
+            if obj.row_names is not None:
+                self.var_levels["row"] = obj.row_names
+        else:
+            self.ax = obj
+            self.facets = None
+            ax_list = [obj]
+
+        # Identify which "axis" variables we have defined
+        axis_variables = set("xy").intersection(self.variables)
+
+        # -- Verify the types of our x and y variables here.
+        # This doesn't really make complete sense being here, but it's a fine
+        # place for it, given the current system.
+        # (Note that for some plots, there might be more complicated restrictions)
+        # e.g. the categorical plots have their own check that is specific to the
+        # non-categorical axis.
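The converter bookkeeping in `_attach` (below) leans on matplotlib's unit framework: once an axis has seen string data, it can map categorical values to numeric positions. Roughly:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Seed the axis with string data; matplotlib installs its category converter
ax.xaxis.update_units(["a", "b", "c"])

# The axis object can now convert category values to positions
print(ax.xaxis.convert_units(["b", "c"]))  # [1. 2.]
```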
+ if allowed_types is None: + allowed_types = ["numeric", "datetime", "categorical"] + elif isinstance(allowed_types, str): + allowed_types = [allowed_types] + + for var in axis_variables: + var_type = self.var_types[var] + if var_type not in allowed_types: + err = ( + f"The {var} variable is {var_type}, but one of " + f"{allowed_types} is required" + ) + raise TypeError(err) + + # -- Get axis objects for each row in plot_data for type conversions and scaling + + facet_dim = {"x": "col", "y": "row"} + + self.converters = {} + for var in axis_variables: + other_var = {"x": "y", "y": "x"}[var] + + converter = pd.Series(index=self.plot_data.index, name=var, dtype=object) + share_state = getattr(self.facets, f"_share{var}", True) + + # Simplest cases are that we have a single axes, all axes are shared, + # or sharing is only on the orthogonal facet dimension. In these cases, + # all datapoints get converted the same way, so use the first axis + if share_state is True or share_state == facet_dim[other_var]: + converter.loc[:] = getattr(ax_list[0], f"{var}axis") + + else: + + # Next simplest case is when no axes are shared, and we can + # use the axis objects within each facet + if share_state is False: + for axes_vars, axes_data in self.iter_data(): + ax = self._get_axes(axes_vars) + converter.loc[axes_data.index] = getattr(ax, f"{var}axis") + + # In the more complicated case, the axes are shared within each + # "file" of the facetgrid. In that case, we need to subset the data + # for that file and assign it the first axis in the slice of the grid + else: + + names = getattr(self.facets, f"{share_state}_names") + for i, level in enumerate(names): + idx = (i, 0) if share_state == "row" else (0, i) + axis = getattr(self.facets.axes[idx], f"{var}axis") + converter.loc[self.plot_data[share_state] == level] = axis + + # Store the converter vector, which we use elsewhere (e.g comp_data) + self.converters[var] = converter + + # Now actually update the matplotlib objects to do the conversion we want + grouped = self.plot_data[var].groupby(self.converters[var], sort=False) + for converter, seed_data in grouped: + if self.var_types[var] == "categorical": + if self._var_ordered[var]: + order = self.var_levels[var] + else: + order = None + seed_data = categorical_order(seed_data, order) + converter.update_units(seed_data) + + # -- Set numerical axis scales + + # First unpack the log_scale argument + if log_scale is None: + scalex = scaley = False + else: + # Allow single value or x, y tuple + try: + scalex, scaley = log_scale + except TypeError: + scalex = log_scale if "x" in self.variables else False + scaley = log_scale if "y" in self.variables else False + + # Now use it + for axis, scale in zip("xy", (scalex, scaley)): + if scale: + for ax in ax_list: + set_scale = getattr(ax, f"set_{axis}scale") + if scale is True: + set_scale("log") + else: + if LooseVersion(mpl.__version__) >= "3.3": + set_scale("log", base=scale) + else: + set_scale("log", **{f"base{axis}": scale}) + + # For categorical y, we want the "first" level to be at the top of the axis + if self.var_types.get("y", None) == "categorical": + for ax in ax_list: + try: + ax.yaxis.set_inverted(True) + except AttributeError: # mpl < 3.1 + if not ax.yaxis_inverted(): + ax.invert_yaxis() + + # TODO -- Add axes labels + + def _log_scaled(self, axis): + """Return True if specified axis is log scaled on all attached axes.""" + if not hasattr(self, "ax"): + return False + + if self.ax is None: + axes_list = self.facets.axes.flatten() + else: + 
axes_list = [self.ax]
+
+        log_scaled = []
+        for ax in axes_list:
+            data_axis = getattr(ax, f"{axis}axis")
+            log_scaled.append(data_axis.get_scale() == "log")
+
+        if any(log_scaled) and not all(log_scaled):
+            raise RuntimeError("Axis scaling is not consistent")
+
+        return any(log_scaled)
+
+    def _add_axis_labels(self, ax, default_x="", default_y=""):
+        """Add axis labels if not present, set visibility to match ticklabels."""
+        # TODO ax could default to None and use attached axes if present
+        # but what to do about the case of facets? Currently using FacetGrid's
+        # set_axis_labels method, which doesn't add labels to the interior even
+        # when the axes are not shared. Maybe that makes sense?
+        if not ax.get_xlabel():
+            x_visible = any(t.get_visible() for t in ax.get_xticklabels())
+            ax.set_xlabel(self.variables.get("x", default_x), visible=x_visible)
+        if not ax.get_ylabel():
+            y_visible = any(t.get_visible() for t in ax.get_yticklabels())
+            ax.set_ylabel(self.variables.get("y", default_y), visible=y_visible)
+
+    # XXX If the scale_* methods are going to modify the plot_data structure, they
+    # can't be called twice. That means that if they are called twice, they should
+    # raise. Alternatively, we could store an original version of plot_data and each
+    # time they are called they operate on the store, not the current state.
+
+    def scale_native(self, axis, *args, **kwargs):
+
+        # Default, defer to matplotlib
+
+        raise NotImplementedError
+
+    def scale_numeric(self, axis, *args, **kwargs):
+
+        # Feels needed for completeness, but what should it do?
+        # Perhaps handle log scaling? Set the ticker/formatter/limits?
+
+        raise NotImplementedError
+
+    def scale_datetime(self, axis, *args, **kwargs):
+
+        # Use pd.to_datetime to convert strings or numbers to datetime objects
+        # Note, use day-resolution for numeric->datetime to match matplotlib
+
+        raise NotImplementedError
+
+    def scale_categorical(self, axis, order=None, formatter=None):
+        """
+        Enforce categorical (fixed-scale) rules for the data on given axis.
+
+        Parameters
+        ----------
+        axis : "x" or "y"
+            Axis of the plot to operate on.
+        order : list
+            Order that unique values should appear in.
+        formatter : callable
+            Function mapping values to a string representation.
+
+        Returns
+        -------
+        self
+
+        """
+        # This method both modifies the internal representation of the data
+        # (converting it to string) and sets some attributes on self. It might be
+        # a good idea to have a separate object attached to self that contains the
+        # information in those attributes (i.e. whether to enforce variable order
+        # across facets, the order to use) similar to the SemanticMapping objects
+        # we have for semantic variables. That object could also hold the converter
+        # objects that get used, if we can decouple those from an existing axis
+        # (cf. https://github.com/matplotlib/matplotlib/issues/19229).
+        # There are some interactions with faceting information that would need
+        # to be thought through, since the converters to use depend on facets.
+        # If we go that route, these methods could become "borrowed" methods similar
+        # to what happens with the alternate semantic mapper constructors, although
+        # that approach is kind of fussy and confusing.
+
+        # TODO this method could also set the grid state? Since we like to have no
+        # grid on the categorical axis by default. Again, a case where we'll need to
+        # store information until we use it, so best to have a way to collect the
+        # attributes that this method sets.
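The numeric-sort step in the method below exists because string sorting and numeric sorting disagree; sorting first and casting afterwards preserves the intended level order:

```python
import pandas as pd

s = pd.Series([10, 2, 1])

print(sorted(s.astype(str)))              # ['1', '10', '2']  (lexical order)
print(list(s.sort_values().astype(str)))  # ['1', '2', '10']  (intended order)
```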
+
+        # TODO if we are going to set visual properties of the axes with these methods,
+        # then we could do the steps currently in CategoricalPlotter._adjust_cat_axis
+
+        # TODO another, and distinct idea, is to expose a cut= param here
+
+        _check_argument("axis", ["x", "y"], axis)
+
+        # Categorical plots can be "univariate" in which case they get an anonymous
+        # category label on the opposite axis.
+        if axis not in self.variables:
+            self.variables[axis] = None
+            self.var_types[axis] = "categorical"
+            self.plot_data[axis] = ""
+
+        # If the "categorical" variable has a numeric type, sort the rows so that
+        # the default result from categorical_order has those values sorted after
+        # they have been coerced to strings. The reason for this is so that later
+        # we can get facet-wise orders that are correct.
+        # XXX Should this also sort datetimes?
+        # It feels more consistent, but technically will be a default change
+        # If so, should also change categorical_order to behave that way
+        if self.var_types[axis] == "numeric":
+            self.plot_data = self.plot_data.sort_values(axis, kind="mergesort")
+
+        # Now get a reference to the categorical data vector
+        cat_data = self.plot_data[axis]
+
+        # Get the initial categorical order, which we do before string
+        # conversion to respect the original types of the order list.
+        # Track whether the order is given explicitly so that we can know
+        # whether or not to use the order constructed here downstream
+        self._var_ordered[axis] = order is not None or cat_data.dtype.name == "category"
+        order = pd.Index(categorical_order(cat_data, order))
+
+        # Then convert data to strings. This is because in matplotlib,
+        # "categorical" data really mean "string" data, so doing this ensures
+        # that artists will be drawn on the categorical axis with a fixed scale.
+        # TODO implement formatter here; check that it returns strings?
+        if formatter is not None:
+            cat_data = cat_data.map(formatter)
+            order = order.map(formatter)
+        else:
+            cat_data = cat_data.astype(str)
+            order = order.astype(str)
+
+        # Update the levels list with the type-converted order variable
+        self.var_levels[axis] = order
+
+        # Now ensure that seaborn will use categorical rules internally
+        self.var_types[axis] = "categorical"
+
+        # Put the string-typed categorical vector back into the plot_data structure
+        self.plot_data[axis] = cat_data
+
+        return self
+
+
+class VariableType(UserString):
+    """
+    Prevent comparisons elsewhere in the library from using the wrong name.
+
+    Errors are simple assertions because users should not be able to trigger
+    them. If that changes, they should be more verbose.
+
+    """
+    allowed = "numeric", "datetime", "categorical"
+
+    def __init__(self, data):
+        assert data in self.allowed, data
+        super().__init__(data)
+
+    def __eq__(self, other):
+        assert other in self.allowed, other
+        return self.data == other
+
+
+def variable_type(vector, boolean_type="numeric"):
+    """
+    Determine whether a vector contains numeric, categorical, or datetime data.
+
+    This function differs from the pandas typing API in two ways:
+
+    - Python sequences or object-typed PyData objects are considered numeric if
+      all of their entries are numeric.
+    - String or mixed-type data are considered categorical even if not
+      explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.
+
+    Parameters
+    ----------
+    vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
+        Input data to test.
+    boolean_type : 'numeric' or 'categorical'
+        Type to use for vectors containing only 0s and 1s (and NAs).
+
+    Returns
+    -------
+    var_type : 'numeric', 'categorical', or 'datetime'
+        Name identifying the type of data in the vector.
+    """
+
+    # If a categorical dtype is set, infer categorical
+    if pd.api.types.is_categorical_dtype(vector):
+        return VariableType("categorical")
+
+    # Special-case all-na data, which is always "numeric"
+    if pd.isna(vector).all():
+        return VariableType("numeric")
+
+    # Special-case binary/boolean data, allow caller to determine
+    # This triggers a numpy warning when vector has strings/objects
+    # https://github.com/numpy/numpy/issues/6784
+    # Because we reduce with .all(), we are agnostic about whether the
+    # comparison returns a scalar or vector, so we will ignore the warning.
+    # It triggers a separate DeprecationWarning when the vector has datetimes:
+    # https://github.com/numpy/numpy/issues/13548
+    # This is considered a bug by numpy and will likely go away.
+    with warnings.catch_warnings():
+        warnings.simplefilter(
+            action='ignore', category=(FutureWarning, DeprecationWarning)
+        )
+        if np.isin(vector, [0, 1, np.nan]).all():
+            return VariableType(boolean_type)
+
+    # Defer to positive pandas tests
+    if pd.api.types.is_numeric_dtype(vector):
+        return VariableType("numeric")
+
+    if pd.api.types.is_datetime64_dtype(vector):
+        return VariableType("datetime")
+
+    # --- If we get to here, we need to check the entries
+
+    # Check for a collection where everything is a number
+
+    def all_numeric(x):
+        for x_i in x:
+            if not isinstance(x_i, Number):
+                return False
+        return True
+
+    if all_numeric(vector):
+        return VariableType("numeric")
+
+    # Check for a collection where everything is a datetime
+
+    def all_datetime(x):
+        for x_i in x:
+            if not isinstance(x_i, (datetime, np.datetime64)):
+                return False
+        return True
+
+    if all_datetime(vector):
+        return VariableType("datetime")
+
+    # Otherwise, our final fallback is to consider things categorical
+
+    return VariableType("categorical")
+
+
+def infer_orient(x=None, y=None, orient=None, require_numeric=True):
+    """Determine how the plot should be oriented based on the data.
+
+    For historical reasons, the convention is to call a plot "horizontally"
+    or "vertically" oriented based on the axis representing its dependent
+    variable. Practically, this is used when determining the axis for
+    numerical aggregation.
+
+    Parameters
+    ----------
+    x, y : Vector data or None
+        Positional data vectors for the plot.
+    orient : string or None
+        Specified orientation, which must start with "v" or "h" if not None.
+    require_numeric : bool
+        If set, raise when the implied dependent variable is not numeric.
+
+    Returns
+    -------
+    orient : "v" or "h"
+
+    Raises
+    ------
+    ValueError: When `orient` is not None and does not start with "h" or "v"
+    TypeError: When dependent variable is not numeric, with `require_numeric`
+
+    """
+
+    x_type = None if x is None else variable_type(x)
+    y_type = None if y is None else variable_type(y)
+
+    nonnumeric_dv_error = "{} orientation requires numeric `{}` variable."
+    single_var_warning = "{} orientation ignored with only `{}` specified."
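For reference, the behavior of `variable_type` as defined above (a quick check; the import path assumes these helpers live in `seaborn._core`):

```python
import numpy as np
import pandas as pd
from seaborn._core import variable_type

print(variable_type(pd.Series([1.5, 2.5])))    # numeric
print(variable_type(pd.Series(["a", "b"])))    # categorical
print(variable_type(pd.to_datetime(pd.Series(["2021-01-01", "2021-01-02"]))))  # datetime
print(variable_type(pd.Series([0, 1, 0])))     # numeric (default boolean_type)
print(variable_type(pd.Series([0, 1, 0]), boolean_type="categorical"))  # categorical
```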
+
+    if x is None:
+        if str(orient).startswith("h"):
+            warnings.warn(single_var_warning.format("Horizontal", "y"))
+        if require_numeric and y_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
+        return "v"
+
+    elif y is None:
+        if str(orient).startswith("v"):
+            warnings.warn(single_var_warning.format("Vertical", "x"))
+        if require_numeric and x_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
+        return "h"
+
+    elif str(orient).startswith("v"):
+        if require_numeric and y_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
+        return "v"
+
+    elif str(orient).startswith("h"):
+        if require_numeric and x_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
+        return "h"
+
+    elif orient is not None:
+        err = (
+            "`orient` must start with 'v' or 'h' or be None, "
+            f"but `{repr(orient)}` was passed."
+        )
+        raise ValueError(err)
+
+    elif x_type != "categorical" and y_type == "categorical":
+        return "h"
+
+    elif x_type != "numeric" and y_type == "numeric":
+        return "v"
+
+    elif x_type == "numeric" and y_type != "numeric":
+        return "h"
+
+    elif require_numeric and "numeric" not in (x_type, y_type):
+        err = "Neither the `x` nor `y` variable appears to be numeric."
+        raise TypeError(err)
+
+    else:
+        return "v"
+
+
+def unique_dashes(n):
+    """Build an arbitrarily long list of unique dash styles for lines.
+
+    Parameters
+    ----------
+    n : int
+        Number of unique dash specs to generate.
+
+    Returns
+    -------
+    dashes : list of strings or tuples
+        Valid arguments for the ``dashes`` parameter on
+        :class:`matplotlib.lines.Line2D`. The first spec is a solid
+        line (``""``), the remainder are sequences of long and short
+        dashes.
+
+    """
+    # Start with dash specs that are well distinguishable
+    dashes = [
+        "",
+        (4, 1.5),
+        (1, 1),
+        (3, 1.25, 1.5, 1.25),
+        (5, 1, 1, 1),
+    ]
+
+    # Now programmatically build as many as we need
+    p = 3
+    while len(dashes) < n:
+
+        # Take combinations of long and short dashes
+        a = itertools.combinations_with_replacement([3, 1.25], p)
+        b = itertools.combinations_with_replacement([4, 1], p)
+
+        # Interleave the combinations, reversing one of the streams
+        segment_list = itertools.chain(*zip(
+            list(a)[1:-1][::-1],
+            list(b)[1:-1]
+        ))
+
+        # Now insert the gaps
+        for segments in segment_list:
+            gap = min(segments)
+            spec = tuple(itertools.chain(*((seg, gap) for seg in segments)))
+            dashes.append(spec)
+
+        p += 1
+
+    return dashes[:n]
+
+
+def unique_markers(n):
+    """Build an arbitrarily long list of unique marker styles for points.
+
+    Parameters
+    ----------
+    n : int
+        Number of unique marker specs to generate.
+
+    Returns
+    -------
+    markers : list of string or tuples
+        Values for defining :class:`matplotlib.markers.MarkerStyle` objects.
+        All markers will be filled.
+
+    """
+    # Start with marker specs that are well distinguishable
+    markers = [
+        "o",
+        "X",
+        (4, 0, 45),
+        "P",
+        (4, 0, 0),
+        (4, 1, 0),
+        "^",
+        (4, 1, 45),
+        "v",
+    ]
+
+    # Now generate more from regular polygons of increasing order
+    s = 5
+    while len(markers) < n:
+        a = 360 / (s + 1) / 2
+        markers.extend([
+            (s + 1, 1, a),
+            (s + 1, 0, a),
+            (s, 1, 0),
+            (s, 0, 0),
+        ])
+        s += 1
+
+    # Convert to MarkerStyle object, using only exactly what we need
+    # markers = [mpl.markers.MarkerStyle(m) for m in markers[:n]]
+
+    return markers[:n]
+
+
+def categorical_order(vector, order=None):
+    """Return a list of unique data values.
+
+    Determine an ordered list of levels in ``values``.
+
+    Parameters
+    ----------
+    vector : list, array, Categorical, or Series
+        Vector of "categorical" values
+    order : list-like, optional
+        Desired order of category levels to override the order determined
+        from the ``values`` object.
+
+    Returns
+    -------
+    order : list
+        Ordered list of category levels not including null values.
+
+    """
+    if order is None:
+        if hasattr(vector, "categories"):
+            order = vector.categories
+        else:
+            try:
+                order = vector.cat.categories
+            except (TypeError, AttributeError):
+
+                try:
+                    order = vector.unique()
+                except AttributeError:
+                    order = pd.unique(vector)
+
+                if variable_type(vector) == "numeric":
+                    order = np.sort(order)
+
+        order = filter(pd.notnull, order)
+    return list(order)
diff --git a/seaborn/_decorators.py b/seaborn/_decorators.py
new file mode 100644
index 0000000000..d1c24b870b
--- /dev/null
+++ b/seaborn/_decorators.py
@@ -0,0 +1,62 @@
+from inspect import signature, Parameter
+from functools import wraps
+import warnings
+
+
+# This function was adapted from scikit-learn
+# github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/validation.py
+def _deprecate_positional_args(f):
+    """Decorator for methods that issue warnings for positional arguments.
+
+    Using the keyword-only argument syntax in pep 3102, arguments after the
+    * will issue a warning when passed as a positional argument.
+
+    Parameters
+    ----------
+    f : function
+        function to check arguments on
+
+    """
+    sig = signature(f)
+    kwonly_args = []
+    all_args = []
+
+    for name, param in sig.parameters.items():
+        if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
+            all_args.append(name)
+        elif param.kind == Parameter.KEYWORD_ONLY:
+            kwonly_args.append(name)
+
+    @wraps(f)
+    def inner_f(*args, **kwargs):
+        extra_args = len(args) - len(all_args)
+        if extra_args > 0:
+            plural = "s" if extra_args > 1 else ""
+            article = "" if plural else "a "
+            warnings.warn(
+                "Pass the following variable{} as {}keyword arg{}: {}. "
+                "From version 0.12, the only valid positional argument "
+                "will be `data`, and passing other arguments without an "
+                "explicit keyword will result in an error or misinterpretation."
+ .format(plural, article, plural, + ", ".join(kwonly_args[:extra_args])), + FutureWarning + ) + kwargs.update({k: arg for k, arg in zip(sig.parameters, args)}) + return f(**kwargs) + return inner_f + + +def share_init_params_with_map(cls): + """Make cls.map a classmethod with same signature as cls.__init__.""" + map_sig = signature(cls.map) + init_sig = signature(cls.__init__) + + new = [v for k, v in init_sig.parameters.items() if k != "self"] + new.insert(0, map_sig.parameters["cls"]) + cls.map.__signature__ = map_sig.replace(parameters=new) + cls.map.__doc__ = cls.__init__.__doc__ + + cls.map = classmethod(cls.map) + + return cls diff --git a/seaborn/_docstrings.py b/seaborn/_docstrings.py new file mode 100644 index 0000000000..2ab210b6ff --- /dev/null +++ b/seaborn/_docstrings.py @@ -0,0 +1,198 @@ +import re +import pydoc +from .external.docscrape import NumpyDocString + + +class DocstringComponents: + + regexp = re.compile(r"\n((\n|.)+)\n\s*", re.MULTILINE) + + def __init__(self, comp_dict, strip_whitespace=True): + """Read entries from a dict, optionally stripping outer whitespace.""" + if strip_whitespace: + entries = {} + for key, val in comp_dict.items(): + m = re.match(self.regexp, val) + if m is None: + entries[key] = val + else: + entries[key] = m.group(1) + else: + entries = comp_dict.copy() + + self.entries = entries + + def __getattr__(self, attr): + """Provide dot access to entries for clean raw docstrings.""" + if attr in self.entries: + return self.entries[attr] + else: + try: + return self.__getattribute__(attr) + except AttributeError as err: + # If Python is run with -OO, it will strip docstrings and our lookup + # from self.entries will fail. We check for __debug__, which is actually + # set to False by -O (it is True for normal execution). + # But we only want to see an error when building the docs; + # not something users should see, so this slight inconsistency is fine. + if __debug__: + raise err + else: + pass + + @classmethod + def from_nested_components(cls, **kwargs): + """Add multiple sub-sets of components.""" + return cls(kwargs, strip_whitespace=False) + + @classmethod + def from_function_params(cls, func): + """Use the numpydoc parser to extract components from existing func.""" + params = NumpyDocString(pydoc.getdoc(func))["Parameters"] + comp_dict = {} + for p in params: + name = p.name + type = p.type + desc = "\n ".join(p.desc) + comp_dict[name] = f"{name} : {type}\n {desc}" + + return cls(comp_dict) + + +# TODO is "vector" the best term here? We mean to imply 1D data with a variety +# of types? + +# TODO now that we can parse numpydoc style strings, do we need to define dicts +# of docstring components, or just write out a docstring? + + +_core_params = dict( + data=""" +data : :class:`pandas.DataFrame`, :class:`numpy.ndarray`, mapping, or sequence + Input data structure. Either a long-form collection of vectors that can be + assigned to named variables or a wide-form dataset that will be internally + reshaped. + """, # TODO add link to user guide narrative when exists + xy=""" +x, y : vectors or keys in ``data`` + Variables that specify positions on the x and y axes. + """, + hue=""" +hue : vector or key in ``data`` + Semantic variable that is mapped to determine the color of plot elements. + """, + palette=""" +palette : string, list, dict, or :class:`matplotlib.colors.Colormap` + Method for choosing the colors to use when mapping the ``hue`` semantic. + String values are passed to :func:`color_palette`. 
List or dict values
+    imply categorical mapping, while a colormap object implies numeric mapping.
+    """,  # noqa: E501
+    hue_order="""
+hue_order : vector of strings
+    Specify the order of processing and plotting for categorical levels of the
+    ``hue`` semantic.
+    """,
+    hue_norm="""
+hue_norm : tuple or :class:`matplotlib.colors.Normalize`
+    Either a pair of values that set the normalization range in data units
+    or an object that will map from data units into a [0, 1] interval. Usage
+    implies numeric mapping.
+    """,
+    color="""
+color : :mod:`matplotlib color <matplotlib.colors>`
+    Single color specification for when hue mapping is not used. Otherwise, the
+    plot will try to hook into the matplotlib property cycle.
+    """,
+    ax="""
+ax : :class:`matplotlib.axes.Axes`
+    Pre-existing axes for the plot. Otherwise, call :func:`matplotlib.pyplot.gca`
+    internally.
+    """,  # noqa: E501
+)
+
+
+_core_returns = dict(
+    ax="""
+:class:`matplotlib.axes.Axes`
+    The matplotlib axes containing the plot.
+    """,
+    facetgrid="""
+:class:`FacetGrid`
+    An object managing one or more subplots that correspond to conditional data
+    subsets with convenient methods for batch-setting of axes attributes.
+    """,
+    jointgrid="""
+:class:`JointGrid`
+    An object managing multiple subplots that correspond to joint and marginal axes
+    for plotting a bivariate relationship or distribution.
+    """,
+    pairgrid="""
+:class:`PairGrid`
+    An object managing multiple subplots that correspond to joint and marginal axes
+    for pairwise combinations of multiple variables in a dataset.
+    """,
+)
+
+
+_seealso_blurbs = dict(
+
+    # Relational plots
+    scatterplot="""
+scatterplot : Plot data using points.
+    """,
+    lineplot="""
+lineplot : Plot data using lines.
+    """,
+
+    # Distribution plots
+    displot="""
+displot : Figure-level interface to distribution plot functions.
+    """,
+    histplot="""
+histplot : Plot a histogram of binned counts with optional normalization or smoothing.
+    """,
+    kdeplot="""
+kdeplot : Plot univariate or bivariate distributions using kernel density estimation.
+    """,
+    ecdfplot="""
+ecdfplot : Plot empirical cumulative distribution functions.
+    """,
+    rugplot="""
+rugplot : Plot a tick at each observation value along the x and/or y axes.
+    """,
+
+    # Categorical plots
+    stripplot="""
+stripplot : Plot a categorical scatter with jitter.
+    """,
+    swarmplot="""
+swarmplot : Plot a categorical scatter with non-overlapping points.
+    """,
+    violinplot="""
+violinplot : Draw an enhanced boxplot using kernel density estimation.
+    """,
+    pointplot="""
+pointplot : Plot point estimates and CIs using markers and lines.
+    """,
+
+    # Multiples
+    jointplot="""
+jointplot : Draw a bivariate plot with univariate marginal distributions.
+    """,
+    pairplot="""
+pairplot : Draw multiple bivariate plots with univariate marginal distributions.
+    """,
+    jointgrid="""
+JointGrid : Set up a figure with joint and marginal views on bivariate data.
+    """,
+    pairgrid="""
+PairGrid : Set up a figure with joint and marginal views on multiple variables.
+    """,
+)
+
+
+_core_docs = dict(
+    params=DocstringComponents(_core_params),
+    returns=DocstringComponents(_core_returns),
+    seealso=DocstringComponents(_seealso_blurbs),
+)
diff --git a/seaborn/_statistics.py b/seaborn/_statistics.py
new file mode 100644
index 0000000000..5815aa94ae
--- /dev/null
+++ b/seaborn/_statistics.py
@@ -0,0 +1,535 @@
+"""Statistical transformations for visualization.
+
+This module is currently private, but is being written to eventually form part
+of the public API.
+
+The classes should behave roughly in the style of scikit-learn.
+
+- All data-independent parameters should be passed to the class constructor.
+- Each class should implement a default transformation that is exposed through
+  __call__. These are currently written for vector arguments, but I think
+  consuming a whole `plot_data` DataFrame and returning it with transformed
+  variables would make more sense.
+- Some classes have data-dependent preprocessing that should be cached and used
+  multiple times (think defining histogram bins off all data and then counting
+  observations within each bin separately for each data subset). These currently
+  have unique names, but it would be good to have a common name. Not quite
+  `fit`, but something similar.
+- Alternatively, the transform interface could take some information about grouping
+  variables and do a groupby internally.
+- Some classes should define alternate transforms that might make the most sense
+  with a different function. For example, KDE usually evaluates the distribution
+  on a regular grid, but it would be useful for it to transform at the actual
+  datapoints. Then again, this could be controlled by a parameter at the time of
+  class instantiation.
+
+"""
+from numbers import Number
+import numpy as np
+import pandas as pd
+try:
+    from scipy.stats import gaussian_kde
+    _no_scipy = False
+except ImportError:
+    from .external.kde import gaussian_kde
+    _no_scipy = True
+
+from .algorithms import bootstrap
+from .utils import _check_argument
+
+
+class KDE:
+    """Univariate and bivariate kernel density estimator."""
+    def __init__(
+        self, *,
+        bw_method=None,
+        bw_adjust=1,
+        gridsize=200,
+        cut=3,
+        clip=None,
+        cumulative=False,
+    ):
+        """Initialize the estimator with its parameters.
+
+        Parameters
+        ----------
+        bw_method : string, scalar, or callable, optional
+            Method for determining the smoothing bandwidth to use; passed to
+            :class:`scipy.stats.gaussian_kde`.
+        bw_adjust : number, optional
+            Factor that multiplicatively scales the value chosen using
+            ``bw_method``. Increasing will make the curve smoother. See Notes.
+        gridsize : int, optional
+            Number of points on each dimension of the evaluation grid.
+        cut : number, optional
+            Factor, multiplied by the smoothing bandwidth, that determines how
+            far the evaluation grid extends past the extreme datapoints. When
+            set to 0, truncate the curve at the data limits.
+        clip : pair of numbers or None, or a pair of such pairs
+            Do not evaluate the density outside of these limits.
+        cumulative : bool, optional
+            If True, estimate a cumulative distribution function. Requires scipy.
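The `bw_adjust` mechanics described above amount to rescaling whatever bandwidth scipy selects, which is the same pattern the `_fit` method below uses; as a standalone sketch:

```python
import numpy as np
from scipy.stats import gaussian_kde

x = np.random.default_rng(0).normal(size=100)

kde = gaussian_kde(x, bw_method="scott")
kde.set_bandwidth(kde.factor * 2)   # the effect of bw_adjust=2: a smoother curve

grid = np.linspace(x.min() - 1, x.max() + 1, 200)
density = kde(grid)                 # evaluate on a regular support grid
```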
+ + """ + if clip is None: + clip = None, None + + self.bw_method = bw_method + self.bw_adjust = bw_adjust + self.gridsize = gridsize + self.cut = cut + self.clip = clip + self.cumulative = cumulative + + if cumulative and _no_scipy: + raise RuntimeError("Cumulative KDE evaluation requires scipy") + + self.support = None + + def _define_support_grid(self, x, bw, cut, clip, gridsize): + """Create the grid of evaluation points depending for vector x.""" + clip_lo = -np.inf if clip[0] is None else clip[0] + clip_hi = +np.inf if clip[1] is None else clip[1] + gridmin = max(x.min() - bw * cut, clip_lo) + gridmax = min(x.max() + bw * cut, clip_hi) + return np.linspace(gridmin, gridmax, gridsize) + + def _define_support_univariate(self, x, weights): + """Create a 1D grid of evaluation points.""" + kde = self._fit(x, weights) + bw = np.sqrt(kde.covariance.squeeze()) + grid = self._define_support_grid( + x, bw, self.cut, self.clip, self.gridsize + ) + return grid + + def _define_support_bivariate(self, x1, x2, weights): + """Create a 2D grid of evaluation points.""" + clip = self.clip + if clip[0] is None or np.isscalar(clip[0]): + clip = (clip, clip) + + kde = self._fit([x1, x2], weights) + bw = np.sqrt(np.diag(kde.covariance).squeeze()) + + grid1 = self._define_support_grid( + x1, bw[0], self.cut, clip[0], self.gridsize + ) + grid2 = self._define_support_grid( + x2, bw[1], self.cut, clip[1], self.gridsize + ) + + return grid1, grid2 + + def define_support(self, x1, x2=None, weights=None, cache=True): + """Create the evaluation grid for a given data set.""" + if x2 is None: + support = self._define_support_univariate(x1, weights) + else: + support = self._define_support_bivariate(x1, x2, weights) + + if cache: + self.support = support + + return support + + def _fit(self, fit_data, weights=None): + """Fit the scipy kde while adding bw_adjust logic and version check.""" + fit_kws = {"bw_method": self.bw_method} + if weights is not None: + fit_kws["weights"] = weights + + kde = gaussian_kde(fit_data, **fit_kws) + kde.set_bandwidth(kde.factor * self.bw_adjust) + + return kde + + def _eval_univariate(self, x, weights=None): + """Fit and evaluate a univariate on univariate data.""" + support = self.support + if support is None: + support = self.define_support(x, cache=False) + + kde = self._fit(x, weights) + + if self.cumulative: + s_0 = support[0] + density = np.array([ + kde.integrate_box_1d(s_0, s_i) for s_i in support + ]) + else: + density = kde(support) + + return density, support + + def _eval_bivariate(self, x1, x2, weights=None): + """Fit and evaluate a univariate on bivariate data.""" + support = self.support + if support is None: + support = self.define_support(x1, x2, cache=False) + + kde = self._fit([x1, x2], weights) + + if self.cumulative: + + grid1, grid2 = support + density = np.zeros((grid1.size, grid2.size)) + p0 = grid1.min(), grid2.min() + for i, xi in enumerate(grid1): + for j, xj in enumerate(grid2): + density[i, j] = kde.integrate_box(p0, (xi, xj)) + + else: + + xx1, xx2 = np.meshgrid(*support) + density = kde([xx1.ravel(), xx2.ravel()]).reshape(xx1.shape) + + return density, support + + def __call__(self, x1, x2=None, weights=None): + """Fit and evaluate on univariate or bivariate data.""" + if x2 is None: + return self._eval_univariate(x1, weights) + else: + return self._eval_bivariate(x1, x2, weights) + + +class Histogram: + """Univariate and bivariate histogram estimator.""" + def __init__( + self, + stat="count", + bins="auto", + binwidth=None, + binrange=None, + 
discrete=False, + cumulative=False, + ): + """Initialize the estimator with its parameters. + + Parameters + ---------- + stat : {"count", "frequency", "density", "probability", "percent"} + Aggregate statistic to compute in each bin. + + - ``count`` shows the number of observations + - ``frequency`` shows the number of observations divided by the bin width + - ``density`` normalizes counts so that the area of the histogram is 1 + - ``probability`` normalizes counts so that the sum of the bar heights is 1 + - ``percent`` normalizes counts so that the sum of the bar heights is 100 + + bins : str, number, vector, or a pair of such values + Generic bin parameter that can be the name of a reference rule, + the number of bins, or the breaks of the bins. + Passed to :func:`numpy.histogram_bin_edges`. + binwidth : number or pair of numbers + Width of each bin, overrides ``bins`` but can be used with + ``binrange``. + binrange : pair of numbers or a pair of pairs + Lowest and highest value for bin edges; can be used either + with ``bins`` or ``binwidth``. Defaults to data extremes. + discrete : bool or pair of bools + If True, set ``binwidth`` and ``binrange`` such that bin + edges cover integer values in the dataset. + cumulative : bool + If True, return the cumulative statistic. + + """ + stat_choices = ["count", "frequency", "density", "probability", "percent"] + _check_argument("stat", stat_choices, stat) + + self.stat = stat + self.bins = bins + self.binwidth = binwidth + self.binrange = binrange + self.discrete = discrete + self.cumulative = cumulative + + self.bin_edges = None + + def _define_bin_edges(self, x, weights, bins, binwidth, binrange, discrete): + """Inner function that takes bin parameters as arguments.""" + if binrange is None: + start, stop = x.min(), x.max() + else: + start, stop = binrange + + if discrete: + bin_edges = np.arange(start - .5, stop + 1.5) + elif binwidth is not None: + step = binwidth + bin_edges = np.arange(start, stop + step, step) + else: + bin_edges = np.histogram_bin_edges( + x, bins, binrange, weights, + ) + return bin_edges + + def define_bin_edges(self, x1, x2=None, weights=None, cache=True): + """Given data, return the edges of the histogram bins.""" + if x2 is None: + + bin_edges = self._define_bin_edges( + x1, weights, self.bins, self.binwidth, self.binrange, self.discrete, + ) + + else: + + bin_edges = [] + for i, x in enumerate([x1, x2]): + + # Resolve whether bin parameters are shared + # or specific to each variable + + bins = self.bins + if not bins or isinstance(bins, (str, Number)): + pass + elif isinstance(bins[i], str): + bins = bins[i] + elif len(bins) == 2: + bins = bins[i] + + binwidth = self.binwidth + if binwidth is None: + pass + elif not isinstance(binwidth, Number): + binwidth = binwidth[i] + + binrange = self.binrange + if binrange is None: + pass + elif not isinstance(binrange[0], Number): + binrange = binrange[i] + + discrete = self.discrete + if not isinstance(discrete, bool): + discrete = discrete[i] + + # Define the bins for this variable + + bin_edges.append(self._define_bin_edges( + x, weights, bins, binwidth, binrange, discrete, + )) + + bin_edges = tuple(bin_edges) + + if cache: + self.bin_edges = bin_edges + + return bin_edges + + def _eval_bivariate(self, x1, x2, weights): + """Inner function for histogram of two variables.""" + bin_edges = self.bin_edges + if bin_edges is None: + bin_edges = self.define_bin_edges(x1, x2, cache=False) + + density = self.stat == "density" + + hist, _, _ = np.histogram2d( + x1, x2, bin_edges, weights=weights, density=density + ) + + area = np.outer( +
np.diff(bin_edges[0]), + np.diff(bin_edges[1]), + ) + + if self.stat == "probability": + hist = hist.astype(float) / hist.sum() + elif self.stat == "percent": + hist = hist.astype(float) / hist.sum() * 100 + elif self.stat == "frequency": + hist = hist.astype(float) / area + + if self.cumulative: + if self.stat in ["density", "frequency"]: + hist = (hist * area).cumsum(axis=0).cumsum(axis=1) + else: + hist = hist.cumsum(axis=0).cumsum(axis=1) + + return hist, bin_edges + + def _eval_univariate(self, x, weights): + """Inner function for histogram of one variable.""" + bin_edges = self.bin_edges + if bin_edges is None: + bin_edges = self.define_bin_edges(x, weights=weights, cache=False) + + density = self.stat == "density" + hist, _ = np.histogram( + x, bin_edges, weights=weights, density=density, + ) + + if self.stat == "probability": + hist = hist.astype(float) / hist.sum() + elif self.stat == "percent": + hist = hist.astype(float) / hist.sum() * 100 + elif self.stat == "frequency": + hist = hist.astype(float) / np.diff(bin_edges) + + if self.cumulative: + if self.stat in ["density", "frequency"]: + hist = (hist * np.diff(bin_edges)).cumsum() + else: + hist = hist.cumsum() + + return hist, bin_edges + + def __call__(self, x1, x2=None, weights=None): + """Count the occurrences in each bin, and maybe normalize.""" + if x2 is None: + return self._eval_univariate(x1, weights) + else: + return self._eval_bivariate(x1, x2, weights) + + +class ECDF: + """Univariate empirical cumulative distribution estimator.""" + def __init__(self, stat="proportion", complementary=False): + """Initialize the class with its parameters. + + Parameters + ---------- + stat : {"proportion", "count"} + Distribution statistic to compute. + complementary : bool + If True, use the complementary CDF (1 - CDF). + + """ + _check_argument("stat", ["count", "proportion"], stat) + self.stat = stat + self.complementary = complementary + + def _eval_bivariate(self, x1, x2, weights): + """Inner function for ECDF of two variables.""" + raise NotImplementedError("Bivariate ECDF is not implemented") + + def _eval_univariate(self, x, weights): + """Inner function for ECDF of one variable.""" + sorter = x.argsort() + x = x[sorter] + weights = weights[sorter] + y = weights.cumsum() + + if self.stat == "proportion": + y = y / y.max() + + x = np.r_[-np.inf, x] + y = np.r_[0, y] + + if self.complementary: + y = y.max() - y + + return y, x + + def __call__(self, x1, x2=None, weights=None): + """Return proportion or count of observations below each sorted datapoint.""" + x1 = np.asarray(x1) + if weights is None: + weights = np.ones_like(x1) + else: + weights = np.asarray(weights) + + if x2 is None: + return self._eval_univariate(x1, weights) + else: + return self._eval_bivariate(x1, x2, weights) + + +class EstimateAggregator: + + def __init__(self, estimator, errorbar=None, **boot_kws): + """ + Data aggregator that produces an estimate and error bar interval. + + Parameters + ---------- + estimator : callable or string + Function (or method name) that maps a vector to a scalar. + errorbar : string, (string, number) tuple, or callable + Name of errorbar method (either "ci", "pi", "se", or "sd"), or a tuple + with a method name and a level parameter, or a function that maps from a + vector to a (min, max) interval. See the :ref:`tutorial <errorbar_tutorial>` + for more information. + boot_kws + Additional keywords are passed to bootstrap when error_method is "ci".
+ + """ + self.estimator = estimator + + method, level = _validate_errorbar_arg(errorbar) + self.error_method = method + self.error_level = level + + self.boot_kws = boot_kws + + def __call__(self, data, var): + """Aggregate over `var` column of `data` with estimate and error interval.""" + vals = data[var] + estimate = vals.agg(self.estimator) + + # Options that produce no error bars + if self.error_method is None: + err_min = err_max = np.nan + elif len(data) <= 1: + err_min = err_max = np.nan + + # Generic errorbars from use-supplied function + elif callable(self.error_method): + err_min, err_max = self.error_method(vals) + + # Parametric options + elif self.error_method == "sd": + half_interval = vals.std() * self.error_level + err_min, err_max = estimate - half_interval, estimate + half_interval + elif self.error_method == "se": + half_interval = vals.sem() * self.error_level + err_min, err_max = estimate - half_interval, estimate + half_interval + + # Nonparametric options + elif self.error_method == "pi": + err_min, err_max = _percentile_interval(vals, self.error_level) + elif self.error_method == "ci": + units = data.get("units", None) + boots = bootstrap(vals, units=units, func=self.estimator, **self.boot_kws) + err_min, err_max = _percentile_interval(boots, self.error_level) + + return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) + + +def _percentile_interval(data, width): + """Return a percentile interval from data of a given width.""" + edge = (100 - width) / 2 + percentiles = edge, 100 - edge + return np.percentile(data, percentiles) + + +def _validate_errorbar_arg(arg): + """Check type and value of errorbar argument and assign default level.""" + DEFAULT_LEVELS = { + "ci": 95, + "pi": 95, + "se": 1, + "sd": 1, + } + + usage = "`errorbar` must be a callable, string, or (string, number) tuple" + + if arg is None: + return None, None + elif callable(arg): + return arg, None + elif isinstance(arg, str): + method = arg + level = DEFAULT_LEVELS.get(method, None) + else: + try: + method, level = arg + except (ValueError, TypeError) as err: + raise err.__class__(usage) from err + + _check_argument("errorbar", list(DEFAULT_LEVELS), method) + if level is not None and not isinstance(level, Number): + raise TypeError(usage) + + return method, level diff --git a/seaborn/_testing.py b/seaborn/_testing.py new file mode 100644 index 0000000000..c6f821cbe2 --- /dev/null +++ b/seaborn/_testing.py @@ -0,0 +1,90 @@ +import numpy as np +import matplotlib as mpl +from matplotlib.colors import to_rgb, to_rgba +from numpy.testing import assert_array_equal + + +USE_PROPS = [ + "alpha", + "edgecolor", + "facecolor", + "fill", + "hatch", + "height", + "linestyle", + "linewidth", + "paths", + "xy", + "xydata", + "sizes", + "zorder", +] + + +def assert_artists_equal(list1, list2): + + assert len(list1) == len(list2) + for a1, a2 in zip(list1, list2): + assert a1.__class__ == a2.__class__ + prop1 = a1.properties() + prop2 = a2.properties() + for key in USE_PROPS: + if key not in prop1: + continue + v1 = prop1[key] + v2 = prop2[key] + if key == "paths": + for p1, p2 in zip(v1, v2): + assert_array_equal(p1.vertices, p2.vertices) + assert_array_equal(p1.codes, p2.codes) + elif key == "color": + v1 = mpl.colors.to_rgba(v1) + v2 = mpl.colors.to_rgba(v2) + assert v1 == v2 + elif isinstance(v1, np.ndarray): + assert_array_equal(v1, v2) + else: + assert v1 == v2 + + +def assert_legends_equal(leg1, leg2): + + assert leg1.get_title().get_text() == leg2.get_title().get_text() + for t1, t2 
in zip(leg1.get_texts(), leg2.get_texts()): + assert t1.get_text() == t2.get_text() + + assert_artists_equal( + leg1.get_patches(), leg2.get_patches(), + ) + assert_artists_equal( + leg1.get_lines(), leg2.get_lines(), + ) + + +def assert_plots_equal(ax1, ax2, labels=True): + + assert_artists_equal(ax1.patches, ax2.patches) + assert_artists_equal(ax1.lines, ax2.lines) + assert_artists_equal(ax1.collections, ax2.collections) + + if labels: + assert ax1.get_xlabel() == ax2.get_xlabel() + assert ax1.get_ylabel() == ax2.get_ylabel() + + +def assert_colors_equal(a, b, check_alpha=True): + + def handle_array(x): + + if isinstance(x, np.ndarray): + if x.ndim > 1: + x = np.unique(x, axis=0).squeeze() + if x.ndim > 1: + raise ValueError("Color arrays must be 1 dimensional") + return x + + a = handle_array(a) + b = handle_array(b) + + f = to_rgba if check_alpha else to_rgb + assert f(a) == f(b) diff --git a/seaborn/algorithms.py b/seaborn/algorithms.py index 5b937755a6..0bcc6c1e38 100644 --- a/seaborn/algorithms.py +++ b/seaborn/algorithms.py @@ -1,10 +1,7 @@ """Algorithms to support fitting routines in seaborn plotting functions.""" -from __future__ import division +import numbers import numpy as np -from scipy import stats import warnings -from .external.six import string_types -from .external.six.moves import range def bootstrap(*args, **kwargs): @@ -14,22 +11,19 @@ def bootstrap(*args, **kwargs): axis and pass to a summary function. Keyword arguments: - n_boot : int, default 10000 + n_boot : int, default=10000 Number of iterations - axis : int, default None + axis : int, default=None Will pass axis to ``func`` as a keyword argument. - units : array, default None + units : array, default=None Array of sampling unit IDs. When used the bootstrap resamples units and then observations within units instead of individual datapoints. - smooth : bool, default False - If True, performs a smoothed bootstrap (draws samples from a kernel - density estimate); only works for one-dimensional inputs and cannot - be used `units` is present. - func : string or callable, default np.mean - Function to call on the args that are passed in. If string, tries - to use as named method on numpy array. - random_seed : int | None, default None + func : string or callable, default="mean" + Function to call on the args that are passed in. If string, it is used + as the name of a function in the numpy namespace. If nans are present + in the data, a nan-aware version of the named function will be used. + seed : Generator | SeedSequence | RandomState | int | None Seed for the random number generator; useful if you want reproducible resamples.
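To make the revised `bootstrap` interface concrete, here is a minimal usage sketch (not part of the patch; the data values are hypothetical). It exercises the string-valued `func` default, the new `seed` argument, and the nan-aware substitution described in the docstring above:

```python
import numpy as np
from seaborn.algorithms import bootstrap

# Hypothetical sample with a few missing values
x = np.random.default_rng(0).normal(size=100)
x[:5] = np.nan

# func="mean" is looked up in the numpy namespace; because the data
# contain nans, the nan-aware variant (np.nanmean) is used instead
boot_dist = bootstrap(x, func="mean", n_boot=1000, seed=42)

# A 95% percentile interval over the bootstrap distribution
err_min, err_max = np.percentile(boot_dist, [2.5, 97.5])
```
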
@@ -46,50 +40,66 @@ def bootstrap(*args, **kwargs): # Default keyword arguments n_boot = kwargs.get("n_boot", 10000) - func = kwargs.get("func", np.mean) + func = kwargs.get("func", "mean") axis = kwargs.get("axis", None) units = kwargs.get("units", None) - smooth = kwargs.get("smooth", False) random_seed = kwargs.get("random_seed", None) + if random_seed is not None: + msg = "`random_seed` has been renamed to `seed` and will be removed" + warnings.warn(msg) + seed = kwargs.get("seed", random_seed) if axis is None: func_kwargs = dict() else: func_kwargs = dict(axis=axis) # Initialize the resampler - rs = np.random.RandomState(random_seed) + rng = _handle_random_seed(seed) # Coerce to arrays args = list(map(np.asarray, args)) if units is not None: units = np.asarray(units) - # Allow for a function that is the name of a method on an array - if isinstance(func, string_types): - def f(x): - return getattr(x, func)() + if isinstance(func, str): + + # Allow named numpy functions + f = getattr(np, func) + + # Try to use nan-aware version of function if necessary + missing_data = np.isnan(np.sum(np.column_stack(args))) + + if missing_data and not func.startswith("nan"): + nanf = getattr(np, f"nan{func}", None) + if nanf is None: + msg = f"Data contain nans but no nan-aware version of `{func}` found" + warnings.warn(msg, UserWarning) + else: + f = nanf + else: f = func - # Do the bootstrap - if smooth: - msg = "Smooth bootstraps are deprecated and will be removed." - warnings.warn(msg) - return _smooth_bootstrap(args, n_boot, f, func_kwargs) + # Handle numpy changes + try: + integers = rng.integers + except AttributeError: + integers = rng.randint + # Do the bootstrap if units is not None: return _structured_bootstrap(args, n_boot, units, f, - func_kwargs, rs) + func_kwargs, integers) boot_dist = [] for i in range(int(n_boot)): - resampler = rs.randint(0, n, n) + resampler = integers(0, n, n, dtype=np.intp) # intp is indexing dtype sample = [a.take(resampler, axis=0) for a in args] boot_dist.append(f(*sample, **func_kwargs)) return np.array(boot_dist) -def _structured_bootstrap(args, n_boot, units, func, func_kwargs, rs): +def _structured_bootstrap(args, n_boot, units, func, func_kwargs, integers): """Resample units instead of datapoints.""" unique_units = np.unique(units) n_units = len(unique_units) @@ -98,23 +108,35 @@ def _structured_bootstrap(args, n_boot, units, func, func_kwargs, rs): boot_dist = [] for i in range(int(n_boot)): - resampler = rs.randint(0, n_units, n_units) - sample = [np.take(a, resampler, axis=0) for a in args] + resampler = integers(0, n_units, n_units, dtype=np.intp) + sample = [[a[i] for i in resampler] for a in args] lengths = map(len, sample[0]) - resampler = [rs.randint(0, n, n) for n in lengths] - sample = [[c.take(r, axis=0) for c, r in zip(a, resampler)] - for a in sample] + resampler = [integers(0, n, n, dtype=np.intp) for n in lengths] + sample = [[c.take(r, axis=0) for c, r in zip(a, resampler)] for a in sample] sample = list(map(np.concatenate, sample)) boot_dist.append(func(*sample, **func_kwargs)) return np.array(boot_dist) -def _smooth_bootstrap(args, n_boot, func, func_kwargs): - """Bootstrap by resampling from a kernel density estimate.""" - n = len(args[0]) - boot_dist = [] - kde = [stats.gaussian_kde(np.transpose(a)) for a in args] - for i in range(int(n_boot)): - sample = [a.resample(n).T for a in kde] - boot_dist.append(func(*sample, **func_kwargs)) - return np.array(boot_dist) +def _handle_random_seed(seed=None): + """Given a seed in one of many 
formats, return a random number generator. + + Generalizes across the numpy 1.17 changes, preferring newer functionality. + + """ + if isinstance(seed, np.random.RandomState): + rng = seed + else: + try: + # General interface for seeding on numpy >= 1.17 + rng = np.random.default_rng(seed) + except AttributeError: + # We are on numpy < 1.17, handle options ourselves + if isinstance(seed, (numbers.Integral, np.integer)): + rng = np.random.RandomState(seed) + elif seed is None: + rng = np.random.RandomState() + else: + err = "{} cannot be used to seed the random number generator" + raise ValueError(err.format(seed)) + return rng diff --git a/seaborn/apionly.py b/seaborn/apionly.py deleted file mode 100644 index 1a27045b1c..0000000000 --- a/seaborn/apionly.py +++ /dev/null @@ -1,9 +0,0 @@ -import warnings -from seaborn import * # noqa -reset_orig() # noqa - -msg = ( - "As seaborn no longer sets a default style on import, the seaborn.apionly " - "module is deprecated. It will be removed in a future version." -) -warnings.warn(msg, UserWarning) diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 4f2925bb99..89dc980117 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -1,33 +1,51 @@ -from __future__ import division from itertools import product -from distutils.version import LooseVersion +from inspect import signature import warnings from textwrap import dedent import numpy as np import pandas as pd -from scipy import stats import matplotlib as mpl import matplotlib.pyplot as plt +from ._core import VectorPlotter, variable_type, categorical_order from . import utils +from .utils import _check_argument, adjust_legend_subtitles, _draw_figure from .palettes import color_palette, blend_palette -from .external.six import string_types -from .distributions import distplot, kdeplot, _freedman_diaconis_bins +from ._decorators import _deprecate_positional_args +from ._docstrings import ( + DocstringComponents, + _core_docs, +) __all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"] -class Grid(object): +_param_docs = DocstringComponents.from_nested_components( + core=_core_docs["params"], +) + + +class Grid: """Base class for grids of subplots.""" _margin_titles = False _legend_out = True + def __init__(self): + + self._tight_layout_rect = [0, 0, 1, 1] + self._tight_layout_pad = None + + # This attribute is set externally and is a hack to handle newer functions that + # don't add proxy artists onto the Axes. We need an overall cleaner approach. + self._extract_legend_handles = False + def set(self, **kwargs): """Set attributes on each subplot Axes.""" for ax in self.axes.flat: - ax.set(**kwargs) + if ax is not None: # Handle removed axes + ax.set(**kwargs) return self def savefig(self, *args, **kwargs): @@ -36,20 +54,32 @@ def savefig(self, *args, **kwargs): kwargs.setdefault("bbox_inches", "tight") self.fig.savefig(*args, **kwargs) + def tight_layout(self, *args, **kwargs): + """Call fig.tight_layout within rect that excludes the legend.""" + kwargs = kwargs.copy() + kwargs.setdefault("rect", self._tight_layout_rect) + if self._tight_layout_pad is not None: + kwargs.setdefault("pad", self._tight_layout_pad) + self.fig.tight_layout(*args, **kwargs) + def add_legend(self, legend_data=None, title=None, label_order=None, - **kwargs): + adjust_subtitles=False, **kwargs): """Draw a legend, maybe placing it outside axes and resizing the figure. Parameters ---------- - legend_data : dict, optional - Dictionary mapping label names to matplotlib artist handles.
The + legend_data : dict + Dictionary mapping label names (or two-element tuples where the + second element is a label name) to matplotlib artist handles. The default reads from ``self._legend_data``. - title : string, optional + title : string Title for the legend. The default reads from ``self._hue_var``. - label_order : list of labels, optional + label_order : list of labels The order that the legend entries should appear in. The default reads from ``self.hue_names``. + adjust_subtitles : bool + If True, modify entries with invisible artists to left-align + the labels and set the font size to that of a title. kwargs : key, value pairings Other keyword arguments are passed to the underlying legend methods on the Figure or Axes object. @@ -61,7 +91,8 @@ def add_legend(self, legend_data=None, title=None, label_order=None, """ # Find the data for the legend - legend_data = self._legend_data if legend_data is None else legend_data + if legend_data is None: + legend_data = self._legend_data if label_order is None: if self.hue_names is None: label_order = list(legend_data.keys()) @@ -71,10 +102,16 @@ def add_legend(self, legend_data=None, title=None, label_order=None, blank_handle = mpl.patches.Patch(alpha=0, linewidth=0) handles = [legend_data.get(l, blank_handle) for l in label_order] title = self._hue_var if title is None else title - try: - title_size = mpl.rcParams["axes.labelsize"] * .85 - except TypeError: # labelsize is something like "large" - title_size = mpl.rcParams["axes.labelsize"] + title_size = mpl.rcParams["legend.title_fontsize"] + + # Unpack nested labels from a hierarchical legend + labels = [] + for entry in label_order: + if isinstance(entry, tuple): + _, label = entry + else: + label = entry + labels.append(label) # Set default legend kwargs kwargs.setdefault("scatterpoints", 1) @@ -82,16 +119,19 @@ def add_legend(self, legend_data=None, title=None, label_order=None, if self._legend_out: kwargs.setdefault("frameon", False) + kwargs.setdefault("loc", "center right") # Draw a full-figure legend outside the grid - figlegend = self.fig.legend(handles, label_order, "center right", - **kwargs) + figlegend = self.fig.legend(handles, labels, **kwargs) + self._legend = figlegend figlegend.set_title(title, prop={"size": title_size}) + if adjust_subtitles: + adjust_legend_subtitles(figlegend) + # Draw the plot to set the bounding boxes correctly - if hasattr(self.fig.canvas, "get_renderer"): - self.fig.draw(self.fig.canvas.get_renderer()) + _draw_figure(self.fig) # Calculate and set the new width of the figure so the legend fits legend_width = figlegend.get_window_extent().width / self.fig.dpi @@ -99,8 +139,7 @@ def add_legend(self, legend_data=None, title=None, label_order=None, self.fig.set_size_inches(fig_width + legend_width, fig_height) # Draw the plot again to get the new transformations - if hasattr(self.fig.canvas, "get_renderer"): - self.fig.draw(self.fig.canvas.get_renderer()) + _draw_figure(self.fig) # Now calculate how much space we need on the right side legend_width = figlegend.get_window_extent().width / self.fig.dpi @@ -111,12 +150,19 @@ def add_legend(self, legend_data=None, title=None, label_order=None, # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) + self._tight_layout_rect[2] = right else: # Draw a legend in the first axis ax = self.axes.flat[0] - leg = ax.legend(handles, label_order, loc="best", **kwargs) + kwargs.setdefault("loc", "best") + + leg = ax.legend(handles, labels, **kwargs) leg.set_title(title, 
prop={"size": title_size}) + self._legend = leg + + if adjust_subtitles: + adjust_legend_subtitles(leg) return self @@ -129,8 +175,15 @@ def _clean_axis(self, ax): def _update_legend_data(self, ax): """Extract the legend data from an axes object and save it.""" + data = {} + if ax.legend_ is not None and self._extract_legend_handles: + handles = ax.legend_.legendHandles + labels = [t.get_text() for t in ax.legend_.texts] + data.update({l: h for h, l in zip(handles, labels)}) + handles, labels = ax.get_legend_handles_labels() - data = {l: h for h, l in zip(handles, labels)} + data.update({l: h for h, l in zip(handles, labels)}) + self._legend_data.update(data) def _get_palette(self, data, hue, hue_order, palette): @@ -139,7 +192,7 @@ def _get_palette(self, data, hue, hue_order, palette): palette = color_palette(n_colors=1) else: - hue_names = utils.categorical_order(data[hue], hue_order) + hue_names = categorical_order(data[hue], hue_order) n_colors = len(hue_names) # By default use either the current color palette or HUSL @@ -163,6 +216,14 @@ def _get_palette(self, data, hue, hue_order, palette): return palette + @property + def legend(self): + """The :class:`matplotlib.legend.Legend` object, if present.""" + try: + return self._legend + except AttributeError: + return None + _facet_docs = dict( @@ -171,8 +232,17 @@ def _get_palette(self, data, hue, hue_order, palette): Tidy ("long-form") dataframe where each column is a variable and each row is an observation.\ """), + rowcol=dedent("""\ + row, col : vectors or keys in ``data`` + Variables that define subsets to plot on different facets.\ + """), + rowcol_order=dedent("""\ + {row,col}_order : vector of strings + Specify the order in which levels of the ``row`` and/or ``col`` variables + appear in the grid of subplots.\ + """), col_wrap=dedent("""\ - col_wrap : int, optional + col_wrap : int "Wrap" the column variable at this width, so that the column facets span multiple rows. Incompatible with a ``row`` facet.\ """), @@ -182,50 +252,57 @@ def _get_palette(self, data, hue, hue_order, palette): across rows.\ """), height=dedent("""\ - height : scalar, optional + height : scalar Height (in inches) of each facet. See also: ``aspect``.\ """), aspect=dedent("""\ - aspect : scalar, optional + aspect : scalar Aspect ratio of each facet, so that ``aspect * height`` gives the width of each facet in inches.\ """), palette=dedent("""\ - palette : palette name, list, or dict, optional + palette : palette name, list, or dict Colors to use for the different levels of the ``hue`` variable. Should be something that can be interpreted by :func:`color_palette`, or a dictionary mapping hue levels to matplotlib colors.\ """), legend_out=dedent("""\ - legend_out : bool, optional + legend_out : bool If ``True``, the figure size will be extended, and the legend will be drawn outside the plot on the center right.\ """), margin_titles=dedent("""\ - margin_titles : bool, optional + margin_titles : bool If ``True``, the titles for the row variable are drawn to the right of the last column. This option is experimental and may not work in all cases.\ """), - ) + facet_kws=dedent("""\ + facet_kws : dict + Additional parameters passed to :class:`FacetGrid`. 
+ """), +) class FacetGrid(Grid): """Multi-plot grid for plotting conditional relationships.""" - def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, - sharex=True, sharey=True, height=3, aspect=1, palette=None, - row_order=None, col_order=None, hue_order=None, hue_kws=None, - dropna=True, legend_out=True, despine=True, - margin_titles=False, xlim=None, ylim=None, subplot_kws=None, - gridspec_kws=None, size=None): - - MPL_GRIDSPEC_VERSION = LooseVersion('1.4') - OLD_MPL = LooseVersion(mpl.__version__) < MPL_GRIDSPEC_VERSION + @_deprecate_positional_args + def __init__( + self, data, *, + row=None, col=None, hue=None, col_wrap=None, + sharex=True, sharey=True, height=3, aspect=1, palette=None, + row_order=None, col_order=None, hue_order=None, hue_kws=None, + dropna=False, legend_out=True, despine=True, + margin_titles=False, xlim=None, ylim=None, subplot_kws=None, + gridspec_kws=None, size=None + ): + + super(FacetGrid, self).__init__() # Handle deprecations if size is not None: height = size - msg = ("The `size` paramter has been renamed to `height`; " + msg = ("The `size` parameter has been renamed to `height`; " "please update your code.") warnings.warn(msg, UserWarning) @@ -234,7 +311,7 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, if hue is None: hue_names = None else: - hue_names = utils.categorical_order(data[hue], hue_order) + hue_names = categorical_order(data[hue], hue_order) colors = self._get_palette(data, hue, hue_order, palette) @@ -242,19 +319,19 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, if row is None: row_names = [] else: - row_names = utils.categorical_order(data[row], row_order) + row_names = categorical_order(data[row], row_order) if col is None: col_names = [] else: - col_names = utils.categorical_order(data[col], col_order) + col_names = categorical_order(data[col], col_order) # Additional dict of kwarg -> list of values for mapping the hue var hue_kws = hue_kws if hue_kws is not None else {} # Make a boolean mask that is True anywhere there is an NA # value in one of the faceting variables, but only if dropna is True - none_na = np.zeros(len(data), np.bool) + none_na = np.zeros(len(data), bool) if dropna: row_na = none_na if row is None else data[row].isnull() col_na = none_na if col is None else data[col].isnull() @@ -295,23 +372,28 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, if ylim is not None: subplot_kws["ylim"] = ylim - # Initialize the subplot grid + # --- Initialize the subplot grid if col_wrap is None: + kwargs = dict(figsize=figsize, squeeze=False, sharex=sharex, sharey=sharey, subplot_kw=subplot_kws, gridspec_kw=gridspec_kws) - if OLD_MPL: - kwargs.pop('gridspec_kw', None) - if gridspec_kws: - msg = "gridspec module only available in mpl >= {}" - warnings.warn(msg.format(MPL_GRIDSPEC_VERSION)) - fig, axes = plt.subplots(nrow, ncol, **kwargs) - self.axes = axes + + if col is None and row is None: + axes_dict = {} + elif col is None: + axes_dict = dict(zip(row_names, axes.flat)) + elif row is None: + axes_dict = dict(zip(col_names, axes.flat)) + else: + facet_product = product(row_names, col_names) + axes_dict = dict(zip(facet_product, axes.flat)) else: + # If wrapping the col variable we need to make the grid ourselves if gridspec_kws: warnings.warn("`gridspec_kws` ignored when using `col_wrap`") @@ -326,28 +408,21 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, subplot_kws["sharey"] = axes[0] for i in range(1, n_axes): axes[i] = 
fig.add_subplot(nrow, ncol, i + 1, **subplot_kws) - self.axes = axes - # Now we turn off labels on the inner axes - if sharex: - for ax in self._not_bottom_axes: - for label in ax.get_xticklabels(): - label.set_visible(False) - ax.xaxis.offsetText.set_visible(False) - if sharey: - for ax in self._not_left_axes: - for label in ax.get_yticklabels(): - label.set_visible(False) - ax.yaxis.offsetText.set_visible(False) + axes_dict = dict(zip(col_names, axes)) - # Set up the class attributes - # --------------------------- + # --- Set up the class attributes - # First the public API - self.data = data - self.fig = fig - self.axes = axes + # Attributes that are part of the public API but accessed through + # a property so that Sphinx adds them to the auto class doc + self._fig = fig + self._axes = axes + self._axes_dict = axes_dict + self._legend = None + # Public attributes that aren't explicitly documented + # (It's not obvious that having them be public was a good idea) + self.data = data self.row_names = row_names self.col_names = col_names self.hue_names = hue_names @@ -360,22 +435,37 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, self._col_var = col self._margin_titles = margin_titles + self._margin_titles_texts = [] self._col_wrap = col_wrap self._hue_var = hue_var self._colors = colors self._legend_out = legend_out - self._legend = None self._legend_data = {} self._x_var = None self._y_var = None + self._sharex = sharex + self._sharey = sharey self._dropna = dropna self._not_na = not_na - # Make the axes look good - fig.tight_layout() + # --- Make the axes look good + + self.tight_layout() if despine: self.despine() + if sharex in [True, 'col']: + for ax in self._not_bottom_axes: + for label in ax.get_xticklabels(): + label.set_visible(False) + ax.xaxis.offsetText.set_visible(False) + + if sharey in [True, 'row']: + for ax in self._not_left_axes: + for label in ax.get_yticklabels(): + label.set_visible(False) + ax.yaxis.offsetText.set_visible(False) + __init__.__doc__ = dedent("""\ Initialize the matplotlib figure and FacetGrid object. @@ -384,19 +474,13 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, The plots it produces are often called "lattice", "trellis", or "small-multiple" graphics. - It can also represent levels of a third varaible with the ``hue`` - parameter, which plots different subets of data in different colors. + It can also represent levels of a third variable with the ``hue`` + parameter, which plots different subsets of data in different colors. This uses color to resolve elements on a third dimension, but only draws subsets on top of each other and will not tailor the ``hue`` parameter for the specific visualization the way that axes-level functions that accept ``hue`` will. - When using seaborn functions that infer semantic mappings from a - dataset, care must be taken to synchronize those mappings across - facets. In most cases, it will be better to use a figure-level function - (e.g. :func:`relplot` or :func:`catplot`) than to use - :class:`FacetGrid` directly. - The basic workflow is to initialize the :class:`FacetGrid` object with the dataset and the variables that are used to structure the grid. Then one or more plotting functions can be applied to each subset by calling @@ -405,6 +489,15 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, axis labels, use different ticks, or add a legend. See the detailed code examples below for more information. + .. 
warning:: + + When using seaborn functions that infer semantic mappings from a + dataset, care must be taken to synchronize those mappings across + facets (e.g., by defining the ``hue`` mapping with a palette dict or + setting the data type of the variables to ``category``). In most cases, + it will be better to use a figure-level function (e.g. :func:`relplot` + or :func:`catplot`) than to use :class:`FacetGrid` directly. + See the :ref:`tutorial <grid_tutorial>` for more information. Parameters @@ -412,14 +505,14 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, {data} row, col, hue : strings Variables that define subsets of the data, which will be drawn on - separate facets in the grid. See the ``*_order`` parameters to + separate facets in the grid. See the ``{{var}}_order`` parameters to control the order of levels of this variable. {col_wrap} {share_xy} {height} {aspect} {palette} - {{row,col,hue}}_order : lists, optional + {{row,col,hue}}_order : lists Order for the levels of the faceting variables. By default, this will be the order that the levels appear in ``data`` or, if the variables are pandas categoricals, the category order. @@ -428,210 +521,40 @@ def __init__(self, data, row=None, col=None, hue=None, col_wrap=None, other plot attributes vary across levels of the hue variable (e.g. the markers in a scatterplot). {legend_out} - despine : boolean, optional + despine : boolean Remove the top and right spines from the plots. {margin_titles} - {{x, y}}lim: tuples, optional + {{x, y}}lim: tuples Limits for each of the axes on each facet (only relevant when - share{{x, y}} is True. - subplot_kws : dict, optional + share{{x, y}} is True). + subplot_kws : dict Dictionary of keyword arguments passed to matplotlib subplot(s) methods. - gridspec_kws : dict, optional - Dictionary of keyword arguments passed to matplotlib's ``gridspec`` - module (via ``plt.subplots``). Requires matplotlib >= 1.4 and is - ignored if ``col_wrap`` is not ``None``. + gridspec_kws : dict + Dictionary of keyword arguments passed to + :class:`matplotlib.gridspec.GridSpec` + (via :func:`matplotlib.pyplot.subplots`). + Ignored if ``col_wrap`` is not ``None``. See Also -------- - PairGrid : Subplot grid for plotting pairwise relationships. - relplot : Combine a relational plot and a :class:`FacetGrid`. - catplot : Combine a categorical plot and a :class:`FacetGrid`. - lmplot : Combine a regression plot and a :class:`FacetGrid`. + PairGrid : Subplot grid for plotting pairwise relationships + relplot : Combine a relational plot and a :class:`FacetGrid` + displot : Combine a distribution plot and a :class:`FacetGrid` + catplot : Combine a categorical plot and a :class:`FacetGrid` + lmplot : Combine a regression plot and a :class:`FacetGrid` Examples -------- - Initialize a 2x2 grid of facets using the tips dataset: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set(style="ticks", color_codes=True) - >>> tips = sns.load_dataset("tips") - >>> g = sns.FacetGrid(tips, col="time", row="smoker") - - Draw a univariate plot on each facet: - - .. plot:: - :context: close-figs - - >>> import matplotlib.pyplot as plt - >>> g = sns.FacetGrid(tips, col="time", row="smoker") - >>> g = g.map(plt.hist, "total_bill") + .. note:: - (Note that it's not necessary to re-catch the returned variable; it's - the same object, but doing so in the examples makes dealing with the - doctests somewhat less annoying).
+ These examples use seaborn functions to demonstrate some of the + advanced features of the class, but in most cases you will want + to use figure-level functions (e.g. :func:`displot`, :func:`relplot`) + to make the plots shown here. - Pass additional keyword arguments to the mapped function: - - .. plot:: - :context: close-figs - - >>> import numpy as np - >>> bins = np.arange(0, 65, 5) - >>> g = sns.FacetGrid(tips, col="time", row="smoker") - >>> g = g.map(plt.hist, "total_bill", bins=bins, color="r") - - Plot a bivariate function on each facet: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="time", row="smoker") - >>> g = g.map(plt.scatter, "total_bill", "tip", edgecolor="w") - - Assign one of the variables to the color of the plot elements: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="time", hue="smoker") - >>> g = (g.map(plt.scatter, "total_bill", "tip", edgecolor="w") - ... .add_legend()) - - Change the height and aspect ratio of each facet: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="day", height=4, aspect=.5) - >>> g = g.map(plt.hist, "total_bill", bins=bins) - - Specify the order for plot elements: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="smoker", col_order=["Yes", "No"]) - >>> g = g.map(plt.hist, "total_bill", bins=bins, color="m") - - Use a different color palette: - - .. plot:: - :context: close-figs - - >>> kws = dict(s=50, linewidth=.5, edgecolor="w") - >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette="Set1", - ... hue_order=["Dinner", "Lunch"]) - >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws) - ... .add_legend()) - - Use a dictionary mapping hue levels to colors: - - .. plot:: - :context: close-figs - - >>> pal = dict(Lunch="seagreen", Dinner="gray") - >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal, - ... hue_order=["Dinner", "Lunch"]) - >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws) - ... .add_legend()) - - Additionally use a different marker for the hue levels: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal, - ... hue_order=["Dinner", "Lunch"], - ... hue_kws=dict(marker=["^", "v"])) - >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws) - ... .add_legend()) - - "Wrap" a column variable with many levels into the rows: - - .. plot:: - :context: close-figs - - >>> att = sns.load_dataset("attention") - >>> g = sns.FacetGrid(att, col="subject", col_wrap=5, height=1.5) - >>> g = g.map(plt.plot, "solutions", "score", marker=".") - - Define a custom bivariate function to map onto the grid: - - .. plot:: - :context: close-figs - - >>> from scipy import stats - >>> def qqplot(x, y, **kwargs): - ... _, xr = stats.probplot(x, fit=False) - ... _, yr = stats.probplot(y, fit=False) - ... plt.scatter(xr, yr, **kwargs) - >>> g = sns.FacetGrid(tips, col="smoker", hue="sex") - >>> g = (g.map(qqplot, "total_bill", "tip", **kws) - ... .add_legend()) - - Define a custom function that uses a ``DataFrame`` object and accepts - column names as positional variables: - - .. plot:: - :context: close-figs - - >>> import pandas as pd - >>> df = pd.DataFrame( - ... data=np.random.randn(90, 4), - ... columns=pd.Series(list("ABCD"), name="walk"), - ... index=pd.date_range("2015-01-01", "2015-03-31", - ... name="date")) - >>> df = df.cumsum(axis=0).stack().reset_index(name="val") - >>> def dateplot(x, y, **kwargs): - ... ax = plt.gca() - ... data = kwargs.pop("data") - ...
data.plot(x=x, y=y, ax=ax, grid=False, **kwargs) - >>> g = sns.FacetGrid(df, col="walk", col_wrap=2, height=3.5) - >>> g = g.map_dataframe(dateplot, "date", "val") - - Use different axes labels after plotting: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="smoker", row="sex") - >>> g = (g.map(plt.scatter, "total_bill", "tip", color="g", **kws) - ... .set_axis_labels("Total bill (US Dollars)", "Tip")) - - Set other attributes that are shared across the facetes: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="smoker", row="sex") - >>> g = (g.map(plt.scatter, "total_bill", "tip", color="r", **kws) - ... .set(xlim=(0, 60), ylim=(0, 12), - ... xticks=[10, 30, 50], yticks=[2, 6, 10])) - - Use a different template for the facet titles: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="size", col_wrap=3) - >>> g = (g.map(plt.hist, "tip", bins=np.arange(0, 13), color="c") - ... .set_titles("{{col_name}} diners")) - - Tighten the facets: - - .. plot:: - :context: close-figs - - >>> g = sns.FacetGrid(tips, col="smoker", row="sex", - ... margin_titles=True) - >>> g = (g.map(plt.scatter, "total_bill", "tip", color="m", **kws) - ... .set(xlim=(0, 60), ylim=(0, 12), - ... xticks=[10, 30, 50], yticks=[2, 6, 10]) - ... .fig.subplots_adjust(wspace=.05, hspace=.05)) + .. include:: ../docstrings/FacetGrid.rst """).format(**_facet_docs) @@ -701,10 +624,8 @@ def map(self, func, *args, **kwargs): # If color was a keyword argument, grab it here kw_color = kwargs.pop("color", None) - if hasattr(func, "__module__"): - func_module = str(func.__module__) - else: - func_module = "" + # How we use the function depends on where it comes from + func_module = str(getattr(func, "__module__", "")) # Check for categorical plots without order information if func_module == "seaborn.categorical": @@ -727,7 +648,8 @@ def map(self, func, *args, **kwargs): continue # Get the current axis - ax = self.facet_axis(row_i, col_j) + modify_state = not func_module.startswith("seaborn") + ax = self.facet_axis(row_i, col_j, modify_state) # Decide what color to plot with kwargs["color"] = self._facet_color(hue_k, kw_color) @@ -798,7 +720,8 @@ def map_dataframe(self, func, *args, **kwargs): continue # Get the current axis - ax = self.facet_axis(row_i, col_j) + modify_state = not str(func.__module__).startswith("seaborn") + ax = self.facet_axis(row_i, col_j, modify_state) # Decide what color to plot with kwargs["color"] = self._facet_color(hue_k, kw_color) @@ -835,6 +758,13 @@ def _facet_color(self, hue_index, kw_color): def _facet_plot(self, func, ax, plot_args, plot_kwargs): # Draw the plot + if str(func.__module__).startswith("seaborn"): + plot_kwargs = plot_kwargs.copy() + semantics = ["x", "y", "hue", "size", "style"] + for key, val in zip(semantics, plot_args): + plot_kwargs[key] = val + plot_args = [] + plot_kwargs["ax"] = ax func(*plot_args, **plot_kwargs) # Sort out the supporting information @@ -845,9 +775,9 @@ def _finalize_grid(self, axlabels): """Finalize the annotations and layout.""" self.set_axis_labels(*axlabels) self.set_titles() - self.fig.tight_layout() + self.tight_layout() - def facet_axis(self, row_i, col_j): + def facet_axis(self, row_i, col_j, modify_state=True): """Make the axis identified by these indices active and return it.""" # Calculate the actual indices of the axes to plot on @@ -857,7 +787,8 @@ def facet_axis(self, row_i, col_j): ax = self.axes[row_i, col_j] # Get a reference to the axes object we want, and make it active - 
plt.sca(ax) + if modify_state: + plt.sca(ax) return ax def despine(self, **kwargs): @@ -865,35 +796,44 @@ def despine(self, **kwargs): utils.despine(self.fig, **kwargs) return self - def set_axis_labels(self, x_var=None, y_var=None): + def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: self._x_var = x_var - self.set_xlabels(x_var) + self.set_xlabels(x_var, clear_inner=clear_inner, **kwargs) if y_var is not None: self._y_var = y_var - self.set_ylabels(y_var) + self.set_ylabels(y_var, clear_inner=clear_inner, **kwargs) + return self - def set_xlabels(self, label=None, **kwargs): + def set_xlabels(self, label=None, clear_inner=True, **kwargs): """Label the x axis on the bottom row of the grid.""" if label is None: label = self._x_var for ax in self._bottom_axes: ax.set_xlabel(label, **kwargs) + if clear_inner: + for ax in self._not_bottom_axes: + ax.set_xlabel("") return self - def set_ylabels(self, label=None, **kwargs): + def set_ylabels(self, label=None, clear_inner=True, **kwargs): """Label the y axis on the left column of the grid.""" if label is None: label = self._y_var for ax in self._left_axes: ax.set_ylabel(label, **kwargs) + if clear_inner: + for ax in self._not_left_axes: + ax.set_ylabel("") return self def set_xticklabels(self, labels=None, step=None, **kwargs): """Set x axis tick labels of the grid.""" for ax in self.axes.flat: + curr_ticks = ax.get_xticks() + ax.set_xticks(curr_ticks) if labels is None: curr_labels = [l.get_text() for l in ax.get_xticklabels()] if step is not None: @@ -908,6 +848,8 @@ def set_xticklabels(self, labels=None, step=None, **kwargs): def set_yticklabels(self, labels=None, **kwargs): """Set y axis tick labels on the left column of the grid.""" for ax in self.axes.flat: + curr_ticks = ax.get_yticks() + ax.set_yticks(curr_ticks) if labels is None: curr_labels = [l.get_text() for l in ax.get_yticklabels()] ax.set_yticklabels(curr_labels, **kwargs) @@ -915,7 +857,7 @@ def set_yticklabels(self, labels=None, **kwargs): ax.set_yticklabels(labels, **kwargs) return self - def set_titles(self, template=None, row_template=None, col_template=None, + def set_titles(self, template=None, row_template=None, col_template=None, **kwargs): """Draw titles either above each facet or on the grid margins. 
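The relabeling changes above are easiest to see in use. A minimal sketch (not part of the patch) using the example tips dataset; `clear_inner=True` is the new default and is passed explicitly here only for emphasis:

```python
import matplotlib.pyplot as plt
import seaborn as sns

tips = sns.load_dataset("tips")
g = sns.FacetGrid(tips, row="sex", col="time")
g.map(plt.scatter, "total_bill", "tip")

# With clear_inner=True (the new default), interior facets lose their
# axis labels; only the bottom row and left column keep them
g.set_axis_labels("Total bill ($)", "Tip ($)", clear_inner=True)
```
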
@@ -959,16 +901,24 @@ def set_titles(self, template=None, row_template=None, col_template=None, template = utils.to_utf8(template) if self._margin_titles: + + # Remove any existing title texts + for text in self._margin_titles_texts: + text.remove() + self._margin_titles_texts = [] + if self.row_names is not None: # Draw the row titles on the right edge of the grid for i, row_name in enumerate(self.row_names): ax = self.axes[i, -1] args.update(dict(row_name=row_name)) title = row_template.format(**args) - bgcolor = self.fig.get_facecolor() - ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction", - rotation=270, ha="left", va="center", - backgroundcolor=bgcolor, **kwargs) + text = ax.annotate( + title, xy=(1.02, .5), xycoords="axes fraction", + rotation=270, ha="left", va="center", + **kwargs + ) + self._margin_titles_texts.append(text) if self.col_names is not None: # Draw the column titles as normal titles @@ -999,16 +949,42 @@ def set_titles(self, template=None, row_template=None, col_template=None, self.axes.flat[i].set_title(title, **kwargs) return self + # ------ Properties that are part of the public API and documented by Sphinx + + @property + def fig(self): + """The :class:`matplotlib.figure.Figure` with the plot.""" + return self._fig + + @property + def axes(self): + """An array of the :class:`matplotlib.axes.Axes` objects in the grid.""" + return self._axes + @property def ax(self): - """Easy access to single axes.""" + """The :class:`matplotlib.axes.Axes` when no faceting variables are assigned.""" if self.axes.shape == (1, 1): return self.axes[0, 0] else: - err = ("You must use the `.axes` attribute (an array) when " - "there is more than one plot.") + err = ( + "Use the `.axes` attribute when facet variables are assigned." + ) raise AttributeError(err) + @property + def axes_dict(self): + """A mapping of facet names to corresponding :class:`matplotlib.axes.Axes`. + + If only one of ``row`` or ``col`` is assigned, each key is a string + representing a level of that variable. If both facet dimensions are + assigned, each key is a ``({row_level}, {col_level})`` tuple. 
+ + """ + return self._axes_dict + + # ------ Private properties, that require some computation to get + @property def _inner_axes(self): """Return a flat array of the inner axes.""" @@ -1018,9 +994,11 @@ def _inner_axes(self): axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): - append = (i % self._ncol and - i < (self._ncol * (self._nrow - 1)) and - i < (self._ncol * (self._nrow - 1) - n_empty)) + append = ( + i % self._ncol + and i < (self._ncol * (self._nrow - 1)) + and i < (self._ncol * (self._nrow - 1) - n_empty) + ) if append: axes.append(ax) return np.array(axes, object).flat @@ -1058,8 +1036,10 @@ def _bottom_axes(self): axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): - append = (i >= (self._ncol * (self._nrow - 1)) or - i >= (self._ncol * (self._nrow - 1) - n_empty)) + append = ( + i >= (self._ncol * (self._nrow - 1)) + or i >= (self._ncol * (self._nrow - 1) - n_empty) + ) if append: axes.append(ax) return np.array(axes, object).flat @@ -1073,8 +1053,10 @@ def _not_bottom_axes(self): axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): - append = (i < (self._ncol * (self._nrow - 1)) and - i < (self._ncol * (self._nrow - 1) - n_empty)) + append = ( + i < (self._ncol * (self._nrow - 1)) + and i < (self._ncol * (self._nrow - 1) - n_empty) + ) if append: axes.append(ax) return np.array(axes, object).flat @@ -1083,26 +1065,25 @@ def _not_bottom_axes(self): class PairGrid(Grid): """Subplot grid for plotting pairwise relationships in a dataset. - This class maps each variable in a dataset onto a column and row in a + This object maps each variable in a dataset onto a column and row in a grid of multiple axes. Different axes-level plotting functions can be used to draw bivariate plots in the upper and lower triangles, and the the marginal distribution of each variable can be shown on the diagonal. - It can also represent an additional level of conditionalization with the - ``hue`` parameter, which plots different subets of data in different - colors. This uses color to resolve elements on a third dimension, but - only draws subsets on top of each other and will not tailor the ``hue`` - parameter for the specific visualization the way that axes-level functions - that accept ``hue`` will. + Several different common plots can be generated in a single line using + :func:`pairplot`. Use :class:`PairGrid` when you need more flexibility. See the :ref:`tutorial ` for more information. """ - - def __init__(self, data, hue=None, hue_order=None, palette=None, - hue_kws=None, vars=None, x_vars=None, y_vars=None, - corner=False, diag_sharey=True, height=2.5, aspect=1, - layout_pad=0, despine=True, dropna=True, size=None): + @_deprecate_positional_args + def __init__( + self, data, *, + hue=None, hue_order=None, palette=None, + hue_kws=None, vars=None, x_vars=None, y_vars=None, + corner=False, diag_sharey=True, height=2.5, aspect=1, + layout_pad=.5, despine=True, dropna=False, size=None + ): """Initialize the plot figure and PairGrid object. Parameters @@ -1110,7 +1091,7 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, data : DataFrame Tidy (long-form) dataframe where each column is a variable and each row is an observation. - hue : string (variable name), optional + hue : string (variable name) Variable in ``data`` to map plot aspects to different colors. This variable will be excluded from the default x and y variables. 
hue_order : list of strings @@ -1122,24 +1103,24 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, Other keyword arguments to insert into the plotting call to let other plot attributes vary across levels of the hue variable (e.g. the markers in a scatterplot). - vars : list of variable names, optional + vars : list of variable names Variables within ``data`` to use, otherwise use every column with a numeric datatype. - {x, y}_vars : lists of variable names, optional + {x, y}_vars : lists of variable names Variables within ``data`` to use separately for the rows and columns of the figure; i.e. to make a non-square plot. - corner : bool, optional + corner : bool If True, don't add axes to the upper (off-diagonal) triangle of the grid, making this a "corner" plot. - height : scalar, optional + height : scalar Height (in inches) of each facet. - aspect : scalar, optional + aspect : scalar Aspect * height gives the width (in inches) of each facet. - layout_pad : scalar, optional + layout_pad : scalar Padding between axes; passed to ``fig.tight_layout``. - despine : boolean, optional + despine : boolean Remove the top and right spines from the plots. - dropna : boolean, optional + dropna : boolean Drop missing values from the data before plotting. See Also @@ -1150,118 +1131,29 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, Examples -------- - Draw a scatterplot for each pairwise relationship: - - .. plot:: - :context: close-figs - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns; sns.set() - >>> iris = sns.load_dataset("iris") - >>> g = sns.PairGrid(iris) - >>> g = g.map(plt.scatter) - - Show a univariate distribution on the diagonal: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_diag(plt.hist) - >>> g = g.map_offdiag(plt.scatter) - - (It's not actually necessary to catch the return value every time, - as it is the same object, but it makes it easier to deal with the - doctests). - - Color the points using a categorical variable: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species") - >>> g = g.map_diag(plt.hist) - >>> g = g.map_offdiag(plt.scatter) - >>> g = g.add_legend() - - Use a different style to show multiple histograms: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species") - >>> g = g.map_diag(plt.hist, histtype="step", linewidth=3) - >>> g = g.map_offdiag(plt.scatter) - >>> g = g.add_legend() - - Plot a subset of variables - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, vars=["sepal_length", "sepal_width"]) - >>> g = g.map(plt.scatter) - - Pass additional keyword arguments to the functions - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_diag(plt.hist, edgecolor="w") - >>> g = g.map_offdiag(plt.scatter, edgecolor="w", s=40) - - Use different variables for the rows and columns: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, - ... x_vars=["sepal_length", "sepal_width"], - ... y_vars=["petal_length", "petal_width"]) - >>> g = g.map(plt.scatter) - - Use different functions on the upper and lower triangles: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_upper(plt.scatter) - >>> g = g.map_lower(sns.kdeplot, cmap="Blues_d") - >>> g = g.map_diag(sns.kdeplot, lw=3, legend=False) - - Use different colors and markers for each categorical level: - - .. 
plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species", palette="Set2", - ... hue_kws={"marker": ["o", "s", "D"]}) - >>> g = g.map(plt.scatter, linewidths=1, edgecolor="w", s=40) - >>> g = g.add_legend() + .. include:: ../docstrings/PairGrid.rst """ + super(PairGrid, self).__init__() + # Handle deprecations if size is not None: height = size - msg = ("The `size` paramter has been renamed to `height`; " + msg = ("The `size` parameter has been renamed to `height`; " "please update your code.") warnings.warn(UserWarning(msg)) # Sort out the variables that define the grid + numeric_cols = self._find_numeric_cols(data) + if hue in numeric_cols: + numeric_cols.remove(hue) if vars is not None: x_vars = list(vars) y_vars = list(vars) - elif (x_vars is not None) or (y_vars is not None): - if (x_vars is None) or (y_vars is None): - raise ValueError("Must specify `x_vars` and `y_vars`") - else: - numeric_cols = self._find_numeric_cols(data) - if hue in numeric_cols: - numeric_cols.remove(hue) + if x_vars is None: x_vars = numeric_cols + if y_vars is None: y_vars = numeric_cols if np.isscalar(x_vars): @@ -1269,10 +1161,15 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, if np.isscalar(y_vars): y_vars = [y_vars] - self.x_vars = list(x_vars) - self.y_vars = list(y_vars) + self.x_vars = x_vars = list(x_vars) + self.y_vars = y_vars = list(y_vars) self.square_grid = self.x_vars == self.y_vars + if not x_vars: + raise ValueError("No variables found for grid columns.") + if not y_vars: + raise ValueError("No variables found for grid rows.") + # Create the figure and the array of subplots figsize = len(x_vars) * height * aspect, len(y_vars) * height @@ -1290,10 +1187,7 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, if corner: hide_indices = np.triu_indices_from(axes, 1) for i, j in zip(*hide_indices): - try: - axes[i, j].remove() - except NotImplementedError: # Problem on old matplotlibs? - axes[i, j].set_axis_off() + axes[i, j].remove() axes[i, j] = None self.fig = fig @@ -1313,11 +1207,19 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, # Sort out the hue variable self._hue_var = hue if hue is None: - self.hue_names = ["_nolegend_"] + self.hue_names = hue_order = ["_nolegend_"] self.hue_vals = pd.Series(["_nolegend_"] * len(data), index=data.index) else: - hue_names = utils.categorical_order(data[hue], hue_order) + # We need hue_order and hue_names because the former is used to control + # the order of drawing and the latter is used to control the order of + # the legend. hue_names can become string-typed while hue_order must + # retain the type of the input data. This is messy but results from + # the fact that PairGrid can implement the hue-mapping logic itself + # (and was originally written exclusively that way) but now can delegate + # to the axes-level functions, while always handling legend creation. 
+ # See GH2307 + hue_names = hue_order = categorical_order(data[hue], hue_order) if dropna: # Filter NA from the list of unique hue names hue_names = list(filter(pd.notnull, hue_names)) @@ -1327,14 +1229,18 @@ def __init__(self, data, hue=None, hue_order=None, palette=None, # Additional dict of kwarg -> list of values for mapping the hue var self.hue_kws = hue_kws if hue_kws is not None else {} + self._orig_palette = palette + self._hue_order = hue_order self.palette = self._get_palette(data, hue, hue_order, palette) self._legend_data = {} # Make the plot look nice + self._tight_layout_rect = [.01, .01, .99, .99] + self._tight_layout_pad = layout_pad + self._despine = despine if despine: - self._despine = True utils.despine(fig=fig) - fig.tight_layout(pad=layout_pad) + self.tight_layout(pad=layout_pad) def map(self, func, **kwargs): """Plot with the same function in every subplot. @@ -1350,6 +1256,7 @@ def map(self, func, **kwargs): row_indices, col_indices = np.indices(self.axes.shape) indices = zip(row_indices.flat, col_indices.flat) self._map_bivariate(func, indices, **kwargs) + return self def map_lower(self, func, **kwargs): @@ -1393,10 +1300,17 @@ def map_offdiag(self, func, **kwargs): called ``color`` and ``label``. """ - - self.map_lower(func, **kwargs) - if not self._corner: - self.map_upper(func, **kwargs) + if self.square_grid: + self.map_lower(func, **kwargs) + if not self._corner: + self.map_upper(func, **kwargs) + else: + indices = [] + for i, (y_var) in enumerate(self.y_vars): + for j, (x_var) in enumerate(self.x_vars): + if x_var != y_var: + indices.append((i, j)) + self._map_bivariate(func, indices, **kwargs) return self def map_diag(self, func, **kwargs): @@ -1439,32 +1353,72 @@ def map_diag(self, func, **kwargs): # TODO add optional density ticks (on the right) # when drawing a corner plot? 
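+ # Guard against an empty list of diagonal axes, which would make the + # `diag_axes[0]` indexing below fail.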
- if self.diag_sharey: + if self.diag_sharey and diag_axes: # This may change in future matplotlibs # See https://github.com/matplotlib/matplotlib/pull/9923 group = diag_axes[0].get_shared_y_axes() for ax in diag_axes[1:]: group.join(ax, diag_axes[0]) - self.diag_vars = np.array(diag_vars, np.object) - self.diag_axes = np.array(diag_axes, np.object) + self.diag_vars = np.array(diag_vars, np.object_) + self.diag_axes = np.array(diag_axes, np.object_) + + if "hue" not in signature(func).parameters: + return self._map_diag_iter_hue(func, **kwargs) + # Loop over diagonal variables and axes, making one plot in each + for var, ax in zip(self.diag_vars, self.diag_axes): + + plot_kwargs = kwargs.copy() + if str(func.__module__).startswith("seaborn"): + plot_kwargs["ax"] = ax + else: + plt.sca(ax) + + vector = self.data[var] + if self._hue_var is not None: + hue = self.data[self._hue_var] + else: + hue = None + + if self._dropna: + not_na = vector.notna() + if hue is not None: + not_na &= hue.notna() + vector = vector[not_na] + if hue is not None: + hue = hue[not_na] + + plot_kwargs.setdefault("hue", hue) + plot_kwargs.setdefault("hue_order", self._hue_order) + plot_kwargs.setdefault("palette", self._orig_palette) + func(x=vector, **plot_kwargs) + self._clean_axis(ax) + + self._add_axis_labels() + return self + + def _map_diag_iter_hue(self, func, **kwargs): + """Put marginal plot on each diagonal axes, iterating over hue.""" # Plot on each of the diagonal axes fixed_color = kwargs.pop("color", None) for var, ax in zip(self.diag_vars, self.diag_axes): hue_grouped = self.data[var].groupby(self.hue_vals) - plt.sca(ax) + plot_kwargs = kwargs.copy() + if str(func.__module__).startswith("seaborn"): + plot_kwargs["ax"] = ax + else: + plt.sca(ax) - for k, label_k in enumerate(self.hue_names): + for k, label_k in enumerate(self._hue_order): # Attempt to get data for this level, allowing for empty try: - # TODO newer matplotlib(?) doesn't need array for hist - data_k = np.asarray(hue_grouped.get_group(label_k)) + data_k = hue_grouped.get_group(label_k) except KeyError: - data_k = np.array([]) + data_k = pd.Series([], dtype=float) if fixed_color is None: color = self.palette[k] @@ -1474,7 +1428,10 @@ def map_diag(self, func, **kwargs): if self._dropna: data_k = utils.remove_na(data_k) - func(data_k, label=label_k, color=color, **kwargs) + if str(func.__module__).startswith("seaborn"): + func(x=data_k, label=label_k, color=color, **plot_kwargs) + else: + func(data_k, label=label_k, color=color, **plot_kwargs) self._clean_axis(ax) @@ -1484,31 +1441,89 @@ def map_diag(self, func, **kwargs): def _map_bivariate(self, func, indices, **kwargs): """Draw a bivariate plot on the indicated axes.""" + # This is a hack to handle the fact that new distribution plots don't add + # their artists onto the axes. This is probably superior in general, but + # we'll need a better way to handle it in the axisgrid functions. + from .distributions import histplot, kdeplot + if func is histplot or func is kdeplot: + self._extract_legend_handles = True + kws = kwargs.copy() # Use copy as we insert other kwargs - kw_color = kws.pop("color", None) for i, j in indices: x_var = self.x_vars[j] y_var = self.y_vars[i] ax = self.axes[i, j] - self._plot_bivariate(x_var, y_var, ax, func, kw_color, **kws) + if ax is None: # i.e. 
we are in corner mode + continue + self._plot_bivariate(x_var, y_var, ax, func, **kws) self._add_axis_labels() - def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs): + if "hue" in signature(func).parameters: + self.hue_names = list(self._legend_data) + + def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs): """Draw a bivariate plot on the specified axes.""" - plt.sca(ax) + if "hue" not in signature(func).parameters: + self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs) + return + + kwargs = kwargs.copy() + if str(func.__module__).startswith("seaborn"): + kwargs["ax"] = ax + else: + plt.sca(ax) + if x_var == y_var: axes_vars = [x_var] else: axes_vars = [x_var, y_var] + + if self._hue_var is not None and self._hue_var not in axes_vars: + axes_vars.append(self._hue_var) + + data = self.data[axes_vars] + if self._dropna: + data = data.dropna() + + x = data[x_var] + y = data[y_var] + if self._hue_var is None: + hue = None + else: + hue = data.get(self._hue_var) + + kwargs.setdefault("hue", hue) + kwargs.setdefault("hue_order", self._hue_order) + kwargs.setdefault("palette", self._orig_palette) + func(x=x, y=y, **kwargs) + + self._update_legend_data(ax) + self._clean_axis(ax) + + def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs): + """Draw a bivariate plot while iterating over hue subsets.""" + kwargs = kwargs.copy() + if str(func.__module__).startswith("seaborn"): + kwargs["ax"] = ax + else: + plt.sca(ax) + + if x_var == y_var: + axes_vars = [x_var] + else: + axes_vars = [x_var, y_var] + hue_grouped = self.data.groupby(self.hue_vals) - for k, label_k in enumerate(self.hue_names): + for k, label_k in enumerate(self._hue_order): + + kws = kwargs.copy() # Attempt to get data for this level, allowing for empty try: data_k = hue_grouped.get_group(label_k) except KeyError: data_k = pd.DataFrame(columns=axes_vars, - dtype=np.float) + dtype=float) if self._dropna: data_k = data_k[axes_vars].dropna() @@ -1517,13 +1532,18 @@ def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs): y = data_k[y_var] for kw, val_list in self.hue_kws.items(): - kwargs[kw] = val_list[k] - color = self.palette[k] if kw_color is None else kw_color + kws[kw] = val_list[k] + kws.setdefault("color", self.palette[k]) + if self._hue_var is not None: + kws["label"] = label_k - func(x, y, label=label_k, color=color, **kwargs) + if str(func.__module__).startswith("seaborn"): + func(x=x, y=y, **kws) + else: + func(x, y, **kws) - self._clean_axis(ax) self._update_legend_data(ax) + self._clean_axis(ax) def _add_axis_labels(self): """Add labels to the left and bottom Axes.""" @@ -1536,148 +1556,30 @@ def _add_axis_labels(self): def _find_numeric_cols(self, data): """Find which variables in a DataFrame are numeric.""" - # This can't be the best way to do this, but I do not - # know what the best way might be, so this seems ok numeric_cols = [] for col in data: - try: - data[col].astype(np.float) + if variable_type(data[col]) == "numeric": numeric_cols.append(col) - except (ValueError, TypeError): - pass return numeric_cols class JointGrid(object): - """Grid for drawing a bivariate plot with marginal univariate plots.""" - - def __init__(self, x, y, data=None, height=6, ratio=5, space=.2, - dropna=True, xlim=None, ylim=None, size=None): - """Set up the grid of subplots. - - Parameters - ---------- - x, y : strings or vectors - Data or names of variables in ``data``. - data : DataFrame, optional - DataFrame when ``x`` and ``y`` are variable names. 
- height : numeric - Size of each side of the figure in inches (it will be square). - ratio : numeric - Ratio of joint axes size to marginal axes height. - space : numeric, optional - Space between the joint and marginal axes - dropna : bool, optional - If True, remove observations that are missing from `x` and `y`. - {x, y}lim : two-tuples, optional - Axis limits to set before plotting. + """Grid for drawing a bivariate plot with marginal univariate plots. - See Also - -------- - jointplot : High-level interface for drawing bivariate plots with - several different default plot kinds. - - Examples - -------- - - Initialize the figure but don't draw any plots onto it: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set(style="ticks", color_codes=True) - >>> tips = sns.load_dataset("tips") - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - - Add plots using default parameters: - - .. plot:: - :context: close-figs - - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - >>> g = g.plot(sns.regplot, sns.distplot) - - Draw the join and marginal plots separately, which allows finer-level - control other parameters: - - .. plot:: - :context: close-figs - - >>> import matplotlib.pyplot as plt - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - >>> g = g.plot_joint(plt.scatter, color=".5", edgecolor="white") - >>> g = g.plot_marginals(sns.distplot, kde=False, color=".5") - - Draw the two marginal plots separately: - - .. plot:: - :context: close-figs - - >>> import numpy as np - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - >>> g = g.plot_joint(plt.scatter, color="m", edgecolor="white") - >>> _ = g.ax_marg_x.hist(tips["total_bill"], color="b", alpha=.6, - ... bins=np.arange(0, 60, 5)) - >>> _ = g.ax_marg_y.hist(tips["tip"], color="r", alpha=.6, - ... orientation="horizontal", - ... bins=np.arange(0, 12, 1)) - - Add an annotation with a statistic summarizing the bivariate - relationship: - - .. plot:: - :context: close-figs - - >>> from scipy import stats - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - >>> g = g.plot_joint(plt.scatter, - ... color="g", s=40, edgecolor="white") - >>> g = g.plot_marginals(sns.distplot, kde=False, color="g") - >>> g = g.annotate(stats.pearsonr) - - Use a custom function and formatting for the annotation - - .. plot:: - :context: close-figs - - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips) - >>> g = g.plot_joint(plt.scatter, - ... color="g", s=40, edgecolor="white") - >>> g = g.plot_marginals(sns.distplot, kde=False, color="g") - >>> rsquare = lambda a, b: stats.pearsonr(a, b)[0] ** 2 - >>> g = g.annotate(rsquare, template="{stat}: {val:.2f}", - ... stat="$R^2$", loc="upper left", fontsize=12) - - Remove the space between the joint and marginal axes: - - .. plot:: - :context: close-figs + Many plots can be drawn by using the figure-level interface :func:`jointplot`. + Use this class directly when you need more flexibility. - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips, space=0) - >>> g = g.plot_joint(sns.kdeplot, cmap="Blues_d") - >>> g = g.plot_marginals(sns.kdeplot, shade=True) - - Draw a smaller plot with relatively larger marginal axes: - - .. plot:: - :context: close-figs - - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips, - ... height=5, ratio=2) - >>> g = g.plot_joint(sns.kdeplot, cmap="Reds_d") - >>> g = g.plot_marginals(sns.kdeplot, color="r", shade=True) - - Set limits on the axes: - - .. 
plot:: - :context: close-figs - - >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips, - ... xlim=(0, 50), ylim=(0, 8)) - >>> g = g.plot_joint(sns.kdeplot, cmap="Purples_d") - >>> g = g.plot_marginals(sns.kdeplot, color="m", shade=True) + """ - """ + @_deprecate_positional_args + def __init__( + self, *, + x=None, y=None, + data=None, + height=6, ratio=5, space=.2, + dropna=False, xlim=None, ylim=None, size=None, marginal_ticks=False, + hue=None, palette=None, hue_order=None, hue_norm=None, + ): # Handle deprecations if size is not None: height = size @@ -1701,213 +1603,211 @@ def __init__(self, x, y, data=None, height=6, ratio=5, space=.2, # Turn off tick visibility for the measure axis on the marginal plots plt.setp(ax_marg_x.get_xticklabels(), visible=False) plt.setp(ax_marg_y.get_yticklabels(), visible=False) + plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False) + plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False) # Turn off the ticks on the density axis for the marginal plots - plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False) - plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False) - plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False) - plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False) - plt.setp(ax_marg_x.get_yticklabels(), visible=False) - plt.setp(ax_marg_y.get_xticklabels(), visible=False) - ax_marg_x.yaxis.grid(False) - ax_marg_y.xaxis.grid(False) - - # Possibly extract the variables from a DataFrame - if data is not None: - x = data.get(x, x) - y = data.get(y, y) - - for var in [x, y]: - if isinstance(var, string_types): - err = "Could not interpret input '{}'".format(var) - raise ValueError(err) - - # Find the names of the variables - if hasattr(x, "name"): - xlabel = x.name - ax_joint.set_xlabel(xlabel) - if hasattr(y, "name"): - ylabel = y.name - ax_joint.set_ylabel(ylabel) - - # Convert the x and y data to arrays for indexing and plotting - x_array = np.asarray(x) - y_array = np.asarray(y) + if not marginal_ticks: + plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False) + plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False) + plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False) + plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False) + plt.setp(ax_marg_x.get_yticklabels(), visible=False) + plt.setp(ax_marg_y.get_xticklabels(), visible=False) + plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False) + plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False) + ax_marg_x.yaxis.grid(False) + ax_marg_y.xaxis.grid(False) + + # Process the input variables + p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue)) + plot_data = p.plot_data.loc[:, p.plot_data.notna().any()] # Possibly drop NA if dropna: - not_na = pd.notnull(x_array) & pd.notnull(y_array) - x_array = x_array[not_na] - y_array = y_array[not_na] + plot_data = plot_data.dropna() + + def get_var(var): + vector = plot_data.get(var, None) + if vector is not None: + vector = vector.rename(p.variables.get(var, None)) + return vector + + self.x = get_var("x") + self.y = get_var("y") + self.hue = get_var("hue") - self.x = x_array - self.y = y_array + for axis in "xy": + name = p.variables.get(axis, None) + if name is not None: + getattr(ax_joint, f"set_{axis}label")(name) if xlim is not None: ax_joint.set_xlim(xlim) if ylim is not None: ax_joint.set_ylim(ylim) + # Store the semantic mapping parameters for axes-level functions + self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm) 
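+ # These stored parameters are forwarded to hue-aware axes-level functions + # via `_inject_kwargs`; e.g. (hypothetical data): + # g = JointGrid(data=df, x="a", y="b", hue="c", palette="husl") + # g.plot(scatterplot, kdeplot)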
+ # Make the grid look nice utils.despine(f) - utils.despine(ax=ax_marg_x, left=True) - utils.despine(ax=ax_marg_y, bottom=True) + if not marginal_ticks: + utils.despine(ax=ax_marg_x, left=True) + utils.despine(ax=ax_marg_y, bottom=True) + for axes in [ax_marg_x, ax_marg_y]: + for axis in [axes.xaxis, axes.yaxis]: + axis.label.set_visible(False) f.tight_layout() f.subplots_adjust(hspace=space, wspace=space) - def plot(self, joint_func, marginal_func, annot_func=None): - """Shortcut to draw the full plot. + def _inject_kwargs(self, func, kws, params): + """Add params to kws if they are accepted by func.""" + func_params = signature(func).parameters + for key, val in params.items(): + if key in func_params: + kws.setdefault(key, val) + + def plot(self, joint_func, marginal_func, **kwargs): + """Draw the plot by passing functions for joint and marginal axes. - Use `plot_joint` and `plot_marginals` directly for more control. + This method passes the ``kwargs`` dictionary to both functions. If you + need more control, call :meth:`JointGrid.plot_joint` and + :meth:`JointGrid.plot_marginals` directly with specific parameters. Parameters ---------- - joint_func, marginal_func: callables - Functions to draw the bivariate and univariate plots. + joint_func, marginal_func : callables + Functions to draw the bivariate and univariate plots. See methods + referenced above for information about the required characteristics + of these functions. + kwargs + Additional keyword arguments are passed to both functions. Returns ------- - self : JointGrid instance - Returns `self`. + :class:`JointGrid` instance + Returns ``self`` for easy method chaining. """ - self.plot_marginals(marginal_func) - self.plot_joint(joint_func) - if annot_func is not None: - self.annotate(annot_func) + self.plot_marginals(marginal_func, **kwargs) + self.plot_joint(joint_func, **kwargs) return self def plot_joint(self, func, **kwargs): - """Draw a bivariate plot of `x` and `y`. + """Draw a bivariate plot on the joint axes of the grid. Parameters ---------- func : plotting callable - This must take two 1d arrays of data as the first two + If a seaborn function, it should accept ``x`` and ``y``. Otherwise, + it must accept ``x`` and ``y`` vectors of data as the first two positional arguments, and it must plot on the "current" axes. - kwargs : key, value mappings + If ``hue`` was defined in the class constructor, the function must + accept ``hue`` as a parameter. + kwargs Keyword arguments are passed to the plotting function. Returns ------- - self : JointGrid instance - Returns `self`. + :class:`JointGrid` instance + Returns ``self`` for easy method chaining. """ - plt.sca(self.ax_joint) - func(self.x, self.y, **kwargs) + kwargs = kwargs.copy() + if str(func.__module__).startswith("seaborn"): + kwargs["ax"] = self.ax_joint + else: + plt.sca(self.ax_joint) + if self.hue is not None: + kwargs["hue"] = self.hue + self._inject_kwargs(func, kwargs, self._hue_params) + + if str(func.__module__).startswith("seaborn"): + func(x=self.x, y=self.y, **kwargs) + else: + func(self.x, self.y, **kwargs) return self def plot_marginals(self, func, **kwargs): - """Draw univariate plots for `x` and `y` separately. + """Draw univariate plots on the marginal axes. Parameters ---------- func : plotting callable - This must take a 1d array of data as the first positional - argument, it must plot on the "current" axes, and it must - accept a "vertical" keyword argument to orient the measure - dimension of the plot vertically. 
- kwargs : key, value mappings + If a seaborn function, it should accept ``x`` and ``y`` and plot + when only one of them is defined. Otherwise, it must accept a vector + of data as the first positional argument and determine its orientation + using the ``vertical`` parameter, and it must plot on the "current" axes. + If ``hue`` was defined in the class constructor, it must accept ``hue`` + as a parameter. + kwargs Keyword arguments are passed to the plotting function. Returns ------- - self : JointGrid instance - Returns `self`. - - """ - kwargs["vertical"] = False - plt.sca(self.ax_marg_x) - func(self.x, **kwargs) - - kwargs["vertical"] = True - plt.sca(self.ax_marg_y) - func(self.y, **kwargs) - - return self - - def annotate(self, func, template=None, stat=None, loc="best", **kwargs): - """Annotate the plot with a statistic about the relationship. - - *Deprecated and will be removed in a future version*. - - Parameters - ---------- - func : callable - Statistical function that maps the x, y vectors either to (val, p) - or to val. - template : string format template, optional - The template must have the format keys "stat" and "val"; - if `func` returns a p value, it should also have the key "p". - stat : string, optional - Name to use for the statistic in the annotation, by default it - uses the name of `func`. - loc : string or int, optional - Matplotlib legend location code; used to place the annotation. - kwargs : key, value mappings - Other keyword arguments are passed to `ax.legend`, which formats - the annotation. - - Returns - ------- - self : JointGrid instance. - Returns `self`. + :class:`JointGrid` instance + Returns ``self`` for easy method chaining. """ - msg = ("JointGrid annotation is deprecated and will be removed " - "in a future release.") - warnings.warn(UserWarning(msg)) - - default_template = "{stat} = {val:.2g}; p = {p:.2g}" - - # Call the function and determine the form of the return value(s) - out = func(self.x, self.y) - try: - val, p = out - except TypeError: - val, p = out, None - default_template, _ = default_template.split(";") - - # Set the default template - if template is None: - template = default_template - - # Default to name of the function - if stat is None: - stat = func.__name__ + seaborn_func = ( + str(func.__module__).startswith("seaborn") + # deprecated distplot has a legacy API, special case it + and not func.__name__ == "distplot" + ) + func_params = signature(func).parameters + kwargs = kwargs.copy() + if self.hue is not None: + kwargs["hue"] = self.hue + self._inject_kwargs(func, kwargs, self._hue_params) + + if "legend" in func_params: + kwargs.setdefault("legend", False) + + if "orientation" in func_params: + # e.g. plt.hist + orient_kw_x = {"orientation": "vertical"} + orient_kw_y = {"orientation": "horizontal"} + elif "vertical" in func_params: + # e.g. sns.distplot (also how did this get backwards?) 
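+ # (`vertical=True` draws the distribution along the y axis, so the x + # marginal below gets vertical=False and the y marginal gets vertical=True)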
+ orient_kw_x = {"vertical": False} + orient_kw_y = {"vertical": True} + + if seaborn_func: + func(x=self.x, ax=self.ax_marg_x, **kwargs) + else: + plt.sca(self.ax_marg_x) + func(self.x, **orient_kw_x, **kwargs) - # Format the annotation - if p is None: - annotation = template.format(stat=stat, val=val) + if seaborn_func: + func(y=self.y, ax=self.ax_marg_y, **kwargs) else: - annotation = template.format(stat=stat, val=val, p=p) + plt.sca(self.ax_marg_y) + func(self.y, **orient_kw_y, **kwargs) - # Draw an invisible plot and use the legend to draw the annotation - # This is a bit of a hack, but `loc=best` works nicely and is not - # easily abstracted. - phantom, = self.ax_joint.plot(self.x, self.y, linestyle="", alpha=0) - self.ax_joint.legend([phantom], [annotation], loc=loc, **kwargs) - phantom.remove() + self.ax_marg_x.yaxis.get_label().set_visible(False) + self.ax_marg_y.xaxis.get_label().set_visible(False) return self def set_axis_labels(self, xlabel="", ylabel="", **kwargs): - """Set the axis labels on the bivariate axes. + """Set axis labels on the bivariate axes. Parameters ---------- xlabel, ylabel : strings Label names for the x and y variables. kwargs : key, value mappings - Other keyword arguments are passed to the set_xlabel or - set_ylabel. + Other keyword arguments are passed to the following functions: + + - :meth:`matplotlib.axes.Axes.set_xlabel` + - :meth:`matplotlib.axes.Axes.set_ylabel` Returns ------- - self : JointGrid instance - returns `self` + :class:`JointGrid` instance + Returns ``self`` for easy method chaining. """ self.ax_joint.set_xlabel(xlabel, **kwargs) @@ -1915,23 +1815,75 @@ def set_axis_labels(self, xlabel="", ylabel="", **kwargs): return self def savefig(self, *args, **kwargs): - """Wrap figure.savefig defaulting to tight bounding box.""" + """Save the figure using a "tight" bounding box by default. + + Wraps :meth:`matplotlib.figure.Figure.savefig`. + + """ kwargs.setdefault("bbox_inches", "tight") self.fig.savefig(*args, **kwargs) -def pairplot(data, hue=None, hue_order=None, palette=None, - vars=None, x_vars=None, y_vars=None, - kind="scatter", diag_kind="auto", markers=None, - height=2.5, aspect=1, corner=False, dropna=True, - plot_kws=None, diag_kws=None, grid_kws=None, size=None): +JointGrid.__init__.__doc__ = """\ +Set up the grid of subplots and store data internally for easy plotting. + +Parameters +---------- +{params.core.xy} +{params.core.data} +height : number + Size of each side of the figure in inches (it will be square). +ratio : number + Ratio of joint axes height to marginal axes height. +space : number + Space between the joint and marginal axes +dropna : bool + If True, remove missing observations before plotting. +{{x, y}}lim : pairs of numbers + Set axis limits to these values before plotting. +marginal_ticks : bool + If False, suppress ticks on the count/density axis of the marginal plots. +{params.core.hue} + Note: unlike in :class:`FacetGrid` or :class:`PairGrid`, the axes-level + functions must support ``hue`` to use it in :class:`JointGrid`. +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} + +See Also +-------- +{seealso.jointplot} +{seealso.pairgrid} +{seealso.pairplot} + +Examples +-------- + +.. 
include:: ../docstrings/JointGrid.rst + +""".format( + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) + + +@_deprecate_positional_args +def pairplot( + data, *, + hue=None, hue_order=None, palette=None, + vars=None, x_vars=None, y_vars=None, + kind="scatter", diag_kind="auto", markers=None, + height=2.5, aspect=1, corner=False, dropna=False, + plot_kws=None, diag_kws=None, grid_kws=None, size=None, +): """Plot pairwise relationships in a dataset. By default, this function will create a grid of Axes such that each numeric - variable in ``data`` will by shared in the y-axis across a single row and - in the x-axis across a single column. The diagonal Axes are treated - differently, drawing a plot to show the univariate distribution of the data - for the variable in that column. + variable in ``data`` will be shared across the y-axes across a single row and + the x-axes across a single column. The diagonal plots are treated + differently: a univariate distribution plot is drawn to show the marginal + distribution of the data in each column. It is also possible to show a subset of variables or plot different variables on the rows and columns. @@ -1942,42 +1894,42 @@ def pairplot(data, hue=None, hue_order=None, palette=None, Parameters ---------- - data : DataFrame + data : `pandas.DataFrame` Tidy (long-form) dataframe where each column is a variable and each row is an observation. - hue : string (variable name), optional + hue : name of variable in ``data`` Variable in ``data`` to map plot aspects to different colors. hue_order : list of strings Order for the levels of the hue variable in the palette palette : dict or seaborn color palette Set of colors for mapping the ``hue`` variable. If a dict, keys should be values in the ``hue`` variable. - vars : list of variable names, optional + vars : list of variable names Variables within ``data`` to use, otherwise use every column with a numeric datatype. - {x, y}_vars : lists of variable names, optional + {x, y}_vars : lists of variable names Variables within ``data`` to use separately for the rows and columns of the figure; i.e. to make a non-square plot. - kind : {'scatter', 'reg'}, optional - Kind of plot for the non-identity relationships. - diag_kind : {'auto', 'hist', 'kde', None}, optional - Kind of plot for the diagonal subplots. The default depends on whether - ``"hue"`` is used or not. - markers : single matplotlib marker code or list, optional - Either the marker to use for all datapoints or a list of markers with - a length the same as the number of levels in the hue variable so that + kind : {'scatter', 'kde', 'hist', 'reg'} + Kind of plot to make. + diag_kind : {'auto', 'hist', 'kde', None} + Kind of plot for the diagonal subplots. If 'auto', choose based on + whether or not ``hue`` is used. + markers : single matplotlib marker code or list + Either the marker to use for all scatterplot points or a list of markers + with a length the same as the number of levels in the hue variable so that differently colored points will also have different scatterplot markers. - height : scalar, optional + height : scalar Height (in inches) of each facet. - aspect : scalar, optional + aspect : scalar Aspect * height gives the width (in inches) of each facet. - corner : bool, optional + corner : bool If True, don't add axes to the upper (off-diagonal) triangle of the grid, making this a "corner" plot. - dropna : boolean, optional + dropna : boolean Drop missing values from the data before plotting. 
- {plot, diag, grid}_kws : dicts, optional + {plot, diag, grid}_kws : dicts Dictionaries of keyword arguments. ``plot_kws`` are passed to the bivariate plotting function, ``diag_kws`` are passed to the univariate plotting function, and ``grid_kws`` are passed to the :class:`PairGrid` @@ -1990,100 +1942,18 @@ def pairplot(data, hue=None, hue_order=None, palette=None, See Also -------- - PairGrid : Subplot grid for more flexible plotting of pairwise - relationships. + PairGrid : Subplot grid for more flexible plotting of pairwise relationships. + JointGrid : Grid for plotting joint and marginal distributions of two variables. Examples -------- - Draw scatterplots for joint relationships and histograms for univariate - distributions: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set(style="ticks", color_codes=True) - >>> iris = sns.load_dataset("iris") - >>> g = sns.pairplot(iris) - - Show different levels of a categorical variable by the color of plot - elements: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species") - - Use a different color palette: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species", palette="husl") - - Use different markers for each level of the hue variable: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species", markers=["o", "s", "D"]) - - Plot a subset of variables: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, vars=["sepal_width", "sepal_length"]) - - Draw larger plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, height=3, - ... vars=["sepal_width", "sepal_length"]) - - Plot different variables in the rows and columns: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, - ... x_vars=["sepal_width", "sepal_length"], - ... y_vars=["petal_width", "petal_length"]) - - Plot only the lower triangle of bivariate axes: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, corner=True) - - Use kernel density estimates for univariate plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, diag_kind="kde") - - Fit linear regression models to the scatter plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, kind="reg") - - Pass keyword arguments down to the underlying functions (it may be easier - to use :class:`PairGrid` directly): - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, diag_kind="kde", markers="+", - ... plot_kws=dict(s=50, edgecolor="b", linewidth=1), - ... diag_kws=dict(shade=True)) + .. 
include:: ../docstrings/pairplot.rst """ + # Avoid circular import + from .distributions import histplot, kdeplot + # Handle deprecations if size is not None: height = size @@ -2096,12 +1966,16 @@ def pairplot(data, hue=None, hue_order=None, palette=None, "'data' must be pandas DataFrame object, not: {typefound}".format( typefound=type(data))) - if plot_kws is None: - plot_kws = {} - if diag_kws is None: - diag_kws = {} - if grid_kws is None: - grid_kws = {} + plot_kws = {} if plot_kws is None else plot_kws.copy() + diag_kws = {} if diag_kws is None else diag_kws.copy() + grid_kws = {} if grid_kws is None else grid_kws.copy() + + # Resolve "auto" diag kind + if diag_kind == "auto": + if hue is None: + diag_kind = "kde" if kind == "kde" else "hist" + else: + diag_kind = "hist" if kind == "hist" else "kde" # Set up the PairGrid grid_kws.setdefault("diag_sharey", diag_kind == "hist") @@ -2112,32 +1986,36 @@ def pairplot(data, hue=None, hue_order=None, palette=None, # Add the markers here as PairGrid has figured out how many levels of the # hue variable are needed and we don't want to duplicate that process if markers is not None: - if grid.hue_names is None: - n_markers = 1 - else: - n_markers = len(grid.hue_names) - if not isinstance(markers, list): - markers = [markers] * n_markers - if len(markers) != n_markers: - raise ValueError(("markers must be a singleton or a list of " - "markers for each level of the hue variable")) - grid.hue_kws = {"marker": markers} - - # Maybe plot on the diagonal - if diag_kind == "auto": - diag_kind = "hist" if hue is None else "kde" - + if kind == "reg": + # Needed until regplot supports style + if grid.hue_names is None: + n_markers = 1 + else: + n_markers = len(grid.hue_names) + if not isinstance(markers, list): + markers = [markers] * n_markers + if len(markers) != n_markers: + raise ValueError(("markers must be a singleton or a list of " + "markers for each level of the hue variable")) + grid.hue_kws = {"marker": markers} + elif kind == "scatter": + if isinstance(markers, str): + plot_kws["marker"] = markers + elif hue is not None: + plot_kws["style"] = data[hue] + plot_kws["markers"] = markers + + # Draw the marginal plots on the diagonal diag_kws = diag_kws.copy() - if grid.square_grid: - if diag_kind == "hist": - grid.map_diag(plt.hist, **diag_kws) - elif diag_kind == "kde": - diag_kws.setdefault("shade", True) - diag_kws["legend"] = False - grid.map_diag(kdeplot, **diag_kws) + diag_kws.setdefault("legend", False) + if diag_kind == "hist": + grid.map_diag(histplot, **diag_kws) + elif diag_kind == "kde": + diag_kws.setdefault("fill", True) + grid.map_diag(kdeplot, **diag_kws) # Maybe plot on the off-diagonals - if grid.square_grid and diag_kind is not None: + if diag_kind is not None: plotter = grid.map_offdiag else: plotter = grid.map @@ -2148,176 +2026,150 @@ def pairplot(data, hue=None, hue_order=None, palette=None, elif kind == "reg": from .regression import regplot # Avoid circular import plotter(regplot, **plot_kws) + elif kind == "kde": + from .distributions import kdeplot # Avoid circular import + plotter(kdeplot, **plot_kws) + elif kind == "hist": + from .distributions import histplot # Avoid circular import + plotter(histplot, **plot_kws) # Add a legend if hue is not None: grid.add_legend() - return grid - - -def jointplot(x, y, data=None, kind="scatter", stat_func=None, - color=None, height=6, ratio=5, space=.2, - dropna=True, xlim=None, ylim=None, - joint_kws=None, marginal_kws=None, annot_kws=None, **kwargs): - """Draw a plot of two 
variables with bivariate and univariate graphs. - - This function provides a convenient interface to the :class:`JointGrid` - class, with several canned plot kinds. This is intended to be a fairly - lightweight wrapper; if you need more flexibility, you should use - :class:`JointGrid` directly. - - Parameters - ---------- - x, y : strings or vectors - Data or names of variables in ``data``. - data : DataFrame, optional - DataFrame when ``x`` and ``y`` are variable names. - kind : { "scatter" | "reg" | "resid" | "kde" | "hex" }, optional - Kind of plot to draw. - stat_func : callable or None, optional - *Deprecated* - color : matplotlib color, optional - Color used for the plot elements. - height : numeric, optional - Size of the figure (it will be square). - ratio : numeric, optional - Ratio of joint axes height to marginal axes height. - space : numeric, optional - Space between the joint and marginal axes - dropna : bool, optional - If True, remove observations that are missing from ``x`` and ``y``. - {x, y}lim : two-tuples, optional - Axis limits to set before plotting. - {joint, marginal, annot}_kws : dicts, optional - Additional keyword arguments for the plot components. - kwargs : key, value pairings - Additional keyword arguments are passed to the function used to - draw the plot on the joint Axes, superseding items in the - ``joint_kws`` dictionary. - - Returns - ------- - grid : :class:`JointGrid` - :class:`JointGrid` object with the plot on it. - - See Also - -------- - JointGrid : The Grid class used for drawing this plot. Use it directly if - you need more flexibility. - - Examples - -------- - - Draw a scatterplot with marginal histograms: - - .. plot:: - :context: close-figs - - >>> import numpy as np, pandas as pd; np.random.seed(0) - >>> import seaborn as sns; sns.set(style="white", color_codes=True) - >>> tips = sns.load_dataset("tips") - >>> g = sns.jointplot(x="total_bill", y="tip", data=tips) - - Add regression and kernel density fits: - - .. plot:: - :context: close-figs - - >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="reg") - - Replace the scatterplot with a joint histogram using hexagonal bins: - - .. plot:: - :context: close-figs - - >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="hex") - - Replace the scatterplots and histograms with density estimates and align - the marginal Axes tightly with the joint Axes: - - .. plot:: - :context: close-figs - - >>> iris = sns.load_dataset("iris") - >>> g = sns.jointplot("sepal_width", "petal_length", data=iris, - ... kind="kde", space=0, color="g") - - Draw a scatterplot, then add a joint density estimate: - - .. plot:: - :context: close-figs - - >>> g = (sns.jointplot("sepal_length", "sepal_width", - ... data=iris, color="k") - ... .plot_joint(sns.kdeplot, zorder=0, n_levels=6)) - - Pass vectors in directly without using Pandas, then name the axes: - - .. plot:: - :context: close-figs - - >>> x, y = np.random.randn(2, 300) - >>> g = (sns.jointplot(x, y, kind="hex") - ... .set_axis_labels("x", "y")) - - Draw a smaller figure with more space devoted to the marginal plots: + grid.tight_layout() - .. plot:: - :context: close-figs - - >>> g = sns.jointplot("total_bill", "tip", data=tips, - ... height=5, ratio=3, color="g") - - Pass keyword arguments down to the underlying plots: + return grid - .. plot:: - :context: close-figs - >>> g = sns.jointplot("petal_length", "sepal_length", data=iris, - ... marginal_kws=dict(bins=15, rug=True), - ... annot_kws=dict(stat="r"), - ... 
s=40, edgecolor="w", linewidth=1) +@_deprecate_positional_args +def jointplot( + *, + x=None, y=None, + data=None, + kind="scatter", color=None, height=6, ratio=5, space=.2, + dropna=False, xlim=None, ylim=None, marginal_ticks=False, + joint_kws=None, marginal_kws=None, + hue=None, palette=None, hue_order=None, hue_norm=None, + **kwargs +): + # Avoid circular imports + from .relational import scatterplot + from .regression import regplot, residplot + from .distributions import histplot, kdeplot, _freedman_diaconis_bins - """ # Handle deprecations if "size" in kwargs: height = kwargs.pop("size") - msg = ("The `size` paramter has been renamed to `height`; " + msg = ("The `size` parameter has been renamed to `height`; " "please update your code.") warnings.warn(msg, UserWarning) # Set up empty default kwarg dicts - if joint_kws is None: - joint_kws = {} + joint_kws = {} if joint_kws is None else joint_kws.copy() joint_kws.update(kwargs) - if marginal_kws is None: - marginal_kws = {} - if annot_kws is None: - annot_kws = {} + marginal_kws = {} if marginal_kws is None else marginal_kws.copy() + + # Handle deprecations of distplot-specific kwargs + distplot_keys = [ + "rug", "fit", "hist_kws", "norm_hist" "hist_kws", "rug_kws", + ] + unused_keys = [] + for key in distplot_keys: + if key in marginal_kws: + unused_keys.append(key) + marginal_kws.pop(key) + if unused_keys and kind != "kde": + msg = ( + "The marginal plotting function has changed to `histplot`," + " which does not accept the following argument(s): {}." + ).format(", ".join(unused_keys)) + warnings.warn(msg, UserWarning) + + # Validate the plot kind + plot_kinds = ["scatter", "hist", "hex", "kde", "reg", "resid"] + _check_argument("kind", plot_kinds, kind) + + # Raise early if using `hue` with a kind that does not support it + if hue is not None and kind in ["hex", "reg", "resid"]: + msg = ( + f"Use of `hue` with `kind='{kind}'` is not currently supported." + ) + raise ValueError(msg) # Make a colormap based off the plot color + # (Currently used only for kind="hex") if color is None: - color = color_palette()[0] + color = "C0" color_rgb = mpl.colors.colorConverter.to_rgb(color) colors = [utils.set_hls_values(color_rgb, l=l) # noqa for l in np.linspace(1, 0, 12)] cmap = blend_palette(colors, as_cmap=True) + # Matplotlib's hexbin plot is not na-robust + if kind == "hex": + dropna = True + # Initialize the JointGrid object - grid = JointGrid(x, y, data, dropna=dropna, - height=height, ratio=ratio, space=space, - xlim=xlim, ylim=ylim) + grid = JointGrid( + data=data, x=x, y=y, hue=hue, + palette=palette, hue_order=hue_order, hue_norm=hue_norm, + dropna=dropna, height=height, ratio=ratio, space=space, + xlim=xlim, ylim=ylim, marginal_ticks=marginal_ticks, + ) + + if grid.hue is not None: + marginal_kws.setdefault("legend", False) # Plot the data using the grid - if kind == "scatter": + if kind.startswith("scatter"): joint_kws.setdefault("color", color) - grid.plot_joint(plt.scatter, **joint_kws) + grid.plot_joint(scatterplot, **joint_kws) + + if grid.hue is None: + marg_func = histplot + else: + marg_func = kdeplot + marginal_kws.setdefault("fill", True) + + marginal_kws.setdefault("color", color) + grid.plot_marginals(marg_func, **marginal_kws) + + elif kind.startswith("hist"): + + # TODO process pair parameters for bins, etc. 
and pass + # to both jount and marginal plots + + joint_kws.setdefault("color", color) + grid.plot_joint(histplot, **joint_kws) marginal_kws.setdefault("kde", False) marginal_kws.setdefault("color", color) - grid.plot_marginals(distplot, **marginal_kws) + + marg_x_kws = marginal_kws.copy() + marg_y_kws = marginal_kws.copy() + + pair_keys = "bins", "binwidth", "binrange" + for key in pair_keys: + if isinstance(joint_kws.get(key), tuple): + x_val, y_val = joint_kws[key] + marg_x_kws.setdefault(key, x_val) + marg_y_kws.setdefault(key, y_val) + + histplot(data=data, x=x, hue=hue, **marg_x_kws, ax=grid.ax_marg_x) + histplot(data=data, y=y, hue=hue, **marg_y_kws, ax=grid.ax_marg_y) + + elif kind.startswith("kde"): + + joint_kws.setdefault("color", color) + grid.plot_joint(kdeplot, **joint_kws) + + marginal_kws.setdefault("color", color) + if "fill" in joint_kws: + marginal_kws.setdefault("fill", joint_kws["fill"]) + + grid.plot_marginals(kdeplot, **marginal_kws) elif kind.startswith("hex"): @@ -2331,47 +2183,86 @@ def jointplot(x, y, data=None, kind="scatter", stat_func=None, marginal_kws.setdefault("kde", False) marginal_kws.setdefault("color", color) - grid.plot_marginals(distplot, **marginal_kws) - - elif kind.startswith("kde"): - - joint_kws.setdefault("shade", True) - joint_kws.setdefault("cmap", cmap) - grid.plot_joint(kdeplot, **joint_kws) - - marginal_kws.setdefault("shade", True) - marginal_kws.setdefault("color", color) - grid.plot_marginals(kdeplot, **marginal_kws) + grid.plot_marginals(histplot, **marginal_kws) elif kind.startswith("reg"): - from .regression import regplot - marginal_kws.setdefault("color", color) - grid.plot_marginals(distplot, **marginal_kws) + marginal_kws.setdefault("kde", True) + grid.plot_marginals(histplot, **marginal_kws) joint_kws.setdefault("color", color) grid.plot_joint(regplot, **joint_kws) elif kind.startswith("resid"): - from .regression import residplot - joint_kws.setdefault("color", color) grid.plot_joint(residplot, **joint_kws) x, y = grid.ax_joint.collections[0].get_offsets().T marginal_kws.setdefault("color", color) - marginal_kws.setdefault("kde", False) - distplot(x, ax=grid.ax_marg_x, **marginal_kws) - distplot(y, vertical=True, fit=stats.norm, ax=grid.ax_marg_y, - **marginal_kws) - stat_func = None - else: - msg = "kind must be either 'scatter', 'reg', 'resid', 'kde', or 'hex'" - raise ValueError(msg) - - if stat_func is not None: - grid.annotate(stat_func, **annot_kws) + histplot(x=x, hue=hue, ax=grid.ax_marg_x, **marginal_kws) + histplot(y=y, hue=hue, ax=grid.ax_marg_y, **marginal_kws) return grid + + +jointplot.__doc__ = """\ +Draw a plot of two variables with bivariate and univariate graphs. + +This function provides a convenient interface to the :class:`JointGrid` +class, with several canned plot kinds. This is intended to be a fairly +lightweight wrapper; if you need more flexibility, you should use +:class:`JointGrid` directly. + +Parameters +---------- +{params.core.xy} +{params.core.data} +kind : {{ "scatter" | "kde" | "hist" | "hex" | "reg" | "resid" }} + Kind of plot to draw. See the examples for references to the underlying functions. +{params.core.color} +height : numeric + Size of the figure (it will be square). +ratio : numeric + Ratio of joint axes height to marginal axes height. +space : numeric + Space between the joint and marginal axes +dropna : bool + If True, remove observations that are missing from ``x`` and ``y``. +{{x, y}}lim : pairs of numbers + Axis limits to set before plotting. 
+marginal_ticks : bool + If False, suppress ticks on the count/density axis of the marginal plots. +{{joint, marginal}}_kws : dicts + Additional keyword arguments for the plot components. +{params.core.hue} + Semantic variable that is mapped to determine the color of plot elements. +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} +kwargs + Additional keyword arguments are passed to the function used to + draw the plot on the joint Axes, superseding items in the + ``joint_kws`` dictionary. + +Returns +------- +{returns.jointgrid} + +See Also +-------- +{seealso.jointgrid} +{seealso.pairgrid} +{seealso.pairplot} + +Examples +-------- + +.. include:: ../docstrings/jointplot.rst + +""".format( + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 849d0bc06b..e818098579 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -1,37 +1,384 @@ -from __future__ import division from textwrap import dedent -import colorsys +from numbers import Number +import warnings +from colorsys import rgb_to_hls +from functools import partial + import numpy as np -from scipy import stats import pandas as pd +try: + from scipy.stats import gaussian_kde + _no_scipy = False +except ImportError: + from .external.kde import gaussian_kde + _no_scipy = True + import matplotlib as mpl from matplotlib.collections import PatchCollection import matplotlib.patches as Patches import matplotlib.pyplot as plt -import warnings - -from .external.six import string_types -from .external.six.moves import range +from ._core import ( + VectorPlotter, + variable_type, + infer_orient, + categorical_order, +) from . import utils -from .utils import iqr, categorical_order, remove_na +from .utils import remove_na, _normal_quantile_func, _draw_figure, _default_color from .algorithms import bootstrap from .palettes import color_palette, husl_palette, light_palette, dark_palette from .axisgrid import FacetGrid, _facet_docs +from ._decorators import _deprecate_positional_args __all__ = [ "catplot", "factorplot", "stripplot", "swarmplot", - "boxplot", "violinplot", "boxenplot", "lvplot", + "boxplot", "violinplot", "boxenplot", "pointplot", "barplot", "countplot", ] +class _CategoricalPlotterNew(VectorPlotter): + + semantics = "x", "y", "hue", "units" + + wide_structure = {"x": "@columns", "y": "@values", "hue": "@columns"} + flat_structure = {"x": "@index", "y": "@values"} + + def __init__( + self, + data=None, + variables={}, + order=None, + orient=None, + require_numeric=False, + fixed_scale=True, + ): + + super().__init__(data=data, variables=variables) + + # This method takes care of some bookkeeping that is necessary because the + # original categorical plots (prior to the 2021 refactor) had some rules that + # don't fit exactly into the logic of _core. It may be wise to have a second + # round of refactoring that moves the logic deeper, but this will keep things + # relatively sensible for now. + + # The concept of an "orientation" is important to the original categorical + # plots, but there's no provision for it in _core, so we need to do it here. + # Note that it could be useful for the other functions in at least two ways + # (orienting a univariate distribution plot from long-form data and selecting + # the aggregation axis in lineplot), so we may want to eventually refactor it. 
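+ # For example, a numeric y paired with a categorical x infers "v", while a + # numeric x paired with a categorical y infers "h"; the `cat_axis` property + # below maps that result back to the name of the categorical axis.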
+ self.orient = infer_orient( + x=self.plot_data.get("x", None), + y=self.plot_data.get("y", None), + orient=orient, + require_numeric=require_numeric, + ) + + # Short-circuit in the case of an empty plot + if not self.has_xy_data: + return + + # For wide data, orient determines assignment to x/y differently from the + # wide_structure rules in _core. If we do decide to make orient part of the + # _core variable assignment, we'll want to figure out how to express that. + if self.input_format == "wide" and self.orient == "h": + self.plot_data = self.plot_data.rename(columns={"x": "y", "y": "x"}) + orig_x, orig_x_type = self.variables["x"], self.var_types["x"] + orig_y, orig_y_type = self.variables["y"], self.var_types["y"] + self.variables.update({"x": orig_y, "y": orig_x}) + self.var_types.update({"x": orig_y_type, "y": orig_x_type}) + + def _hue_backcompat(self, color, palette, hue_order, force_hue=False): + """Implement backwards compatibility for hue parametrization. + + Note: the force_hue parameter is used so that functions can be shown to + pass existing tests during refactoring and then tested for new behavior. + It can be removed after completion of the work. + + """ + # The original categorical functions applied a palette to the categorical axis + # by default. We want to require an explicit hue mapping, to be more consistent + # with how things work elsewhere now. I don't think there's any good way to + # do this gently -- because it's triggered by the default value of hue=None, + # users would always get a warning, unless we introduce some sentinel "default" + # argument for this change. That's possible, but asking users to set `hue=None` + # on every call is annoying. + # We are keeping the logic for implementing the old behavior in place alongside + # the current system so that (a) we can punt on that decision and (b) we can + # ensure that refactored code passes old tests. + default_behavior = color is None or palette is not None + if force_hue and "hue" not in self.variables and default_behavior: + self._redundant_hue = True + self.plot_data["hue"] = self.plot_data[self.cat_axis] + self.variables["hue"] = self.variables[self.cat_axis] + self.var_types["hue"] = "categorical" + hue_order = self.var_levels[self.cat_axis] + + # Because we convert the categorical axis variable to string, + # we need to update a dictionary palette too + if isinstance(palette, dict): + palette = {str(k): v for k, v in palette.items()} + + else: + self._redundant_hue = False + + # Previously, categorical plots had a trick where color= could seed the palette. + # Because that's an explicit parameterization, we are going to give it one + # release cycle with a warning before removing. + if "hue" in self.variables and palette is None and color is not None: + if not isinstance(color, str): + color = mpl.colors.to_hex(color) + palette = f"dark:{color}" + msg = ( + "Setting a gradient palette using color= is deprecated and will be " + f"removed in version 0.13. Set `palette='{palette}'` for same effect." 
+ ) + warnings.warn(msg, FutureWarning) + + return palette, hue_order + + @property + def cat_axis(self): + return {"v": "x", "h": "y"}[self.orient] + + def _get_gray(self, colors): + """Get a grayscale value that looks good with color.""" + if not len(colors): + return None + unique_colors = np.unique(colors, axis=0) + light_vals = [rgb_to_hls(*rgb[:3])[1] for rgb in unique_colors] + lum = min(light_vals) * .6 + return (lum, lum, lum) + + def _adjust_cat_axis(self, ax, axis): + """Set ticks and limits for a categorical variable.""" + # Note: in theory, this could happen in _attach for all categorical axes + # But two reasons not to do that: + # - If it happens before plotting, autoscaling messes up the plot limits + # - It would change existing plots from other seaborn functions + if self.var_types[axis] != "categorical": + return + + data = self.plot_data[axis] + if self.facets is not None: + share_group = getattr(ax, f"get_shared_{axis}_axes")() + shared_axes = [getattr(ax, f"{axis}axis")] + [ + getattr(other_ax, f"{axis}axis") + for other_ax in self.facets.axes.flat + if share_group.joined(ax, other_ax) + ] + data = data[self.converters[axis].isin(shared_axes)] + + if self._var_ordered[axis]: + order = categorical_order(data, self.var_levels[axis]) + else: + order = categorical_order(data) + + n = max(len(order), 1) + + if axis == "x": + ax.xaxis.grid(False) + ax.set_xlim(-.5, n - .5, auto=None) + else: + ax.yaxis.grid(False) + # Note limits that correspond to previously-inverted y axis + ax.set_ylim(n - .5, -.5, auto=None) + + @property + def _native_width(self): + """Return unit of width separating categories on native numeric scale.""" + unique_values = np.unique(self.comp_data[self.cat_axis]) + if len(unique_values) > 1: + native_width = np.nanmin(np.diff(unique_values)) + else: + native_width = 1 + return native_width + + def _nested_offsets(self, width, dodge): + """Return offsets for each hue level for dodged plots.""" + offsets = None + if "hue" in self.variables: + n_levels = len(self._hue_map.levels) + if dodge: + each_width = width / n_levels + offsets = np.linspace(0, width - each_width, n_levels) + offsets -= offsets.mean() + else: + offsets = np.zeros(n_levels) + return offsets + + # Note that the plotting methods here aim (in most cases) to produce the exact same + # artists as the original version of the code, so there is some weirdness that might + # not otherwise be clean or make sense in this context, such as adding empty artists + # for combinations of variables with no observations + + def plot_strips( + self, + jitter, + dodge, + color, + edgecolor, + plot_kws, + ): + + width = .8 * self._native_width + offsets = self._nested_offsets(width, dodge) + + if jitter is True: + jlim = 0.1 + else: + jlim = float(jitter) + if "hue" in self.variables and dodge: + jlim /= len(self._hue_map.levels) + jlim *= self._native_width + jitterer = partial(np.random.uniform, low=-jlim, high=+jlim) + + iter_vars = [self.cat_axis] + if dodge: + iter_vars.append("hue") + + ax = self.ax + dodge_move = jitter_move = 0 + + for sub_vars, sub_data in self.iter_data(iter_vars, + from_comp_data=True, + allow_empty=True): + + if offsets is not None: + dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)] + + jitter_move = jitterer(size=len(sub_data)) if len(sub_data) > 1 else 0 + + adjusted_data = sub_data[self.cat_axis] + dodge_move + jitter_move + sub_data.loc[:, self.cat_axis] = adjusted_data + + for var in "xy": + if self._log_scaled(var): + sub_data[var] = np.power(10, 
sub_data[var]) + + ax = self._get_axes(sub_vars) + points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws) + + if "hue" in self.variables: + points.set_facecolors(self._hue_map(sub_data["hue"])) + + if edgecolor == "gray": # XXX TODO change to "auto" + points.set_edgecolors(self._get_gray(points.get_facecolors())) + else: + points.set_edgecolors(edgecolor) + + # TODO XXX fully impelement legend + show_legend = not self._redundant_hue and self.input_format != "wide" + if "hue" in self.variables and show_legend: + for level in self._hue_map.levels: + color = self._hue_map(level) + ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level) + ax.legend(loc="best", title=self.variables["hue"]) + + def plot_swarms( + self, + dodge, + color, + edgecolor, + warn_thresh, + plot_kws, + ): + + width = .8 * self._native_width + offsets = self._nested_offsets(width, dodge) + + iter_vars = [self.cat_axis] + if dodge: + iter_vars.append("hue") + + ax = self.ax + point_collections = {} + dodge_move = 0 + + for sub_vars, sub_data in self.iter_data(iter_vars, + from_comp_data=True, + allow_empty=True): + + if offsets is not None: + dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)] + + if not sub_data.empty: + sub_data.loc[:, self.cat_axis] = sub_data[self.cat_axis] + dodge_move + + for var in "xy": + if self._log_scaled(var): + sub_data[var] = np.power(10, sub_data[var]) + + ax = self._get_axes(sub_vars) + points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws) + + if "hue" in self.variables: + points.set_facecolors(self._hue_map(sub_data["hue"])) + + if edgecolor == "gray": # XXX TODO change to "auto" + points.set_edgecolors(self._get_gray(points.get_facecolors())) + else: + points.set_edgecolors(edgecolor) + + if not sub_data.empty: + point_collections[sub_data[self.cat_axis].iloc[0]] = points + + beeswarm = Beeswarm( + width=width, orient=self.orient, warn_thresh=warn_thresh, + ) + for center, points in point_collections.items(): + if points.get_offsets().shape[0] > 1: + + def draw(points, renderer, *, center=center): + + beeswarm(points, center) + + ax = points.axes + if self.orient == "h": + scalex = False + scaley = ax.get_autoscaley_on() + else: + scalex = ax.get_autoscalex_on() + scaley = False + + # This prevents us from undoing the nice categorical axis limits + # set in _adjust_cat_axis, because that method currently leave + # the autoscale flag in its original setting. It may be better + # to disable autoscaling there to avoid needing to do this. 
+ fixed_scale = self.var_types[self.cat_axis] == "categorical"
+
+ ax.update_datalim(points.get_datalim(ax.transData))
+ if not fixed_scale and (scalex or scaley):
+ ax.autoscale_view(scalex=scalex, scaley=scaley)
+
+ super(points.__class__, points).draw(renderer)
+
+ points.draw = draw.__get__(points)
+
+ _draw_figure(ax.figure)
+
+ # TODO XXX fully implement legend
+ show_legend = not self._redundant_hue and self.input_format != "wide"
+ if "hue" in self.variables and show_legend: # TODO and legend:
+ for level in self._hue_map.levels:
+ color = self._hue_map(level)
+ ax.scatter([], [], s=60, color=mpl.colors.rgb2hex(color), label=level)
+ ax.legend(loc="best", title=self.variables["hue"])
+
+
+class _CategoricalFacetPlotter(_CategoricalPlotterNew):
+
+ semantics = _CategoricalPlotterNew.semantics + ("col", "row")
+
+
class _CategoricalPlotter(object):
width = .8
default_palette = "light"
+ require_numeric = True
def establish_variables(self, x=None, y=None, hue=None, data=None, orient=None, order=None, hue_order=None,
@@ -44,7 +391,7 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
# Do a sanity check on the inputs
if hue is not None:
- error = "Cannot use `hue` without `x` or `y`"
+ error = "Cannot use `hue` without `x` and `y`"
raise ValueError(error)
# No hue grouping with wide inputs
@@ -70,18 +417,15 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
order = []
# Reduce to just numeric columns
for col in data:
- try:
- data[col].astype(np.float)
+ if variable_type(data[col]) == "numeric":
order.append(col)
- except ValueError:
- pass
plot_data = data[order]
group_names = order
group_label = data.columns.name
# Convert to a list of arrays, the common representation
iter_data = plot_data.iteritems()
- plot_data = [np.asarray(s, np.float) for k, s in iter_data]
+ plot_data = [np.asarray(s, float) for k, s in iter_data]
# Option 1b:
# The input data is an array or list
@@ -127,7 +471,7 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
plot_data = data
# Convert to a list of arrays, the common representation
- plot_data = [np.asarray(d, np.float) for d in plot_data]
+ plot_data = [np.asarray(d, float) for d in plot_data]
# The group names will just be numeric indices
group_names = list(range((len(plot_data))))
@@ -149,13 +493,15 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
units = data.get(units, units)
# Validate the inputs
- for input in [x, y, hue, units]:
- if isinstance(input, string_types):
- err = "Could not interpret input '{}'".format(input)
+ for var in [x, y, hue, units]:
+ if isinstance(var, str):
+ err = "Could not interpret input '{}'".format(var)
raise ValueError(err)
# Figure out the plotting orientation
- orient = self.infer_orient(x, y, orient)
+ orient = infer_orient(
+ x, y, orient, require_numeric=self.require_numeric
+ )
# Option 2a:
# We are plotting a single set of data
@@ -243,14 +589,18 @@ def _group_longform(self, vals, grouper, order):
"""Group a long-form variable by another with correct order."""
# Ensure that the groupby will work
if not isinstance(vals, pd.Series):
- vals = pd.Series(vals)
+ if isinstance(grouper, pd.Series):
+ index = grouper.index
+ else:
+ index = None
+ vals = pd.Series(vals, index=index)
# Group the val data
grouped_vals = vals.groupby(grouper)
out_data = []
for g in order:
try:
- g_vals = np.asarray(grouped_vals.get_group(g))
+ g_vals = grouped_vals.get_group(g)
except KeyError:
g_vals = np.array([])
out_data.append(g_vals)
@@ -307,11 +657,11 @@ def
establish_colors(self, color, palette, saturation):
if saturation < 1:
colors = color_palette(colors, desat=saturation)
- # Conver the colors to a common representations
+ # Convert the colors to a common representation
rgb_colors = color_palette(colors)
# Determine the gray color to use for the lines framing the plot
- light_vals = [colorsys.rgb_to_hls(*c)[1] for c in rgb_colors]
+ light_vals = [rgb_to_hls(*c)[1] for c in rgb_colors]
lum = min(light_vals) * .6
gray = mpl.colors.rgb2hex((lum, lum, lum))
@@ -319,51 +669,6 @@ def establish_colors(self, color, palette, saturation):
self.colors = rgb_colors
self.gray = gray
- def infer_orient(self, x, y, orient=None):
- """Determine how the plot should be oriented based on the data."""
- orient = str(orient)
-
- def is_categorical(s):
- try:
- # Correct way, but does not exist in older Pandas
- try:
- return pd.api.types.is_categorical_dtype(s)
- except AttributeError:
- return pd.core.common.is_categorical_dtype(s)
- except AttributeError:
- # Also works, but feels hackier
- return str(s.dtype) == "categorical"
-
- def is_not_numeric(s):
- try:
- np.asarray(s, dtype=np.float)
- except ValueError:
- return True
- return False
-
- no_numeric = "Neither the `x` nor `y` variable appears to be numeric."
-
- if orient.startswith("v"):
- return "v"
- elif orient.startswith("h"):
- return "h"
- elif x is None:
- return "v"
- elif y is None:
- return "h"
- elif is_categorical(y):
- if is_categorical(x):
- raise ValueError(no_numeric)
- else:
- return "h"
- elif is_not_numeric(y):
- if is_not_numeric(x):
- raise ValueError(no_numeric)
- else:
- return "h"
- else:
- return "v"
-
@property
def hue_offsets(self):
"""A list of center positions for plots when hue nesting is used."""
@@ -398,12 +703,16 @@ def annotate_axes(self, ax):
if ylabel is not None:
ax.set_ylabel(ylabel)
+ group_names = self.group_names
+ if not group_names:
+ group_names = ["" for _ in range(len(self.plot_data))]
+
if self.orient == "v":
ax.set_xticks(np.arange(len(self.plot_data)))
- ax.set_xticklabels(self.group_names)
+ ax.set_xticklabels(group_names)
else:
ax.set_yticks(np.arange(len(self.plot_data)))
- ax.set_yticklabels(self.group_names)
+ ax.set_yticklabels(group_names)
if self.orient == "v":
ax.xaxis.grid(False)
@@ -413,15 +722,7 @@ def annotate_axes(self, ax):
ax.set_ylim(-.5, len(self.plot_data) - .5, auto=None)
if self.hue_names is not None:
- leg = ax.legend(loc="best")
- if self.hue_title is not None:
- # Matplotlib rcParams does not expose legend title size?
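# ----------------------------------------------------------------------
# [Editor's note] On the simplification in this hunk: the FontProperties
# workaround being removed predates matplotlib 3.0, which added direct
# control over the legend title size. A standalone sketch (assumes
# matplotlib >= 3.0):
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams["legend.title_fontsize"] = "small"  # rcParam route
fig, ax = plt.subplots()
ax.plot([0, 1], label="a")
ax.legend(loc="best", title="hue", title_fontsize=10)  # kwarg overrides rcParam
# (end of note; the removed workaround continues below)
# ----------------------------------------------------------------------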
- try: - title_size = mpl.rcParams["axes.labelsize"] * .85 - except TypeError: # labelsize is something like "large" - title_size = mpl.rcParams["axes.labelsize"] - prop = mpl.font_manager.FontProperties(size=title_size) - leg.set_title(self.hue_title, prop=prop) + ax.legend(loc="best", title=self.hue_title) def add_legend_data(self, ax, color, label): """Add a dummy patch object so we can get legend data.""" @@ -468,7 +769,7 @@ def draw_boxplot(self, ax, kws): # Draw a single box or a set of boxes # with a single level of grouping - box_data = remove_na(group_data) + box_data = np.asarray(remove_na(group_data)) # Handle case where there is no non-null data if box_data.size == 0: @@ -496,7 +797,7 @@ def draw_boxplot(self, ax, kws): continue hue_mask = self.plot_hues[i] == hue_level - box_data = remove_na(group_data[hue_mask]) + box_data = np.asarray(remove_na(group_data[hue_mask])) # Handle case where there is no non-null data if box_data.size == 0: @@ -708,15 +1009,7 @@ def estimate_densities(self, bw, cut, scale, scale_hue, gridsize): def fit_kde(self, x, bw): """Estimate a KDE for a vector of data with flexible bandwidth.""" - # Allow for the use of old scipy where `bw` is fixed - try: - kde = stats.gaussian_kde(x, bw) - except TypeError: - kde = stats.gaussian_kde(x) - if bw != "scott": # scipy default - msg = ("Ignoring bandwidth choice, " - "please upgrade scipy to use a different bandwidth.") - warnings.warn(msg, UserWarning) + kde = gaussian_kde(x, bw) # Extract the numeric bandwidth from the KDE object bw_used = kde.factor @@ -996,7 +1289,7 @@ def draw_box_lines(self, ax, data, support, density, center): """Draw boxplot information at center of the density.""" # Compute the boxplot statistics q25, q50, q75 = np.percentile(data, [25, 50, 75]) - whisker_lim = 1.5 * iqr(data) + whisker_lim = 1.5 * (q75 - q25) h1 = np.min(data[data >= (q25 - whisker_lim)]) h2 = np.max(data[data <= (q75 + whisker_lim)]) @@ -1090,364 +1383,10 @@ def plot(self, ax): ax.invert_yaxis() -class _CategoricalScatterPlotter(_CategoricalPlotter): - - default_palette = "dark" - - @property - def point_colors(self): - """Return a color for each scatter point based on group and hue.""" - colors = [] - for i, group_data in enumerate(self.plot_data): - - # Initialize the array for this group level - group_colors = np.empty((group_data.size, 3)) - - if self.plot_hues is None: - - # Use the same color for all points at this level - group_color = self.colors[i] - group_colors[:] = group_color - - else: - - # Color the points based on the hue level - for j, level in enumerate(self.hue_names): - hue_color = self.colors[j] - if group_data.size: - group_colors[self.plot_hues[i] == level] = hue_color - - colors.append(group_colors) - - return colors - - def add_legend_data(self, ax): - """Add empty scatterplot artists with labels for the legend.""" - if self.hue_names is not None: - for rgb, label in zip(self.colors, self.hue_names): - ax.scatter([], [], - color=mpl.colors.rgb2hex(rgb), - label=label, - s=60) - - -class _StripPlotter(_CategoricalScatterPlotter): - """1-d scatterplot with categorical organization.""" - def __init__(self, x, y, hue, data, order, hue_order, - jitter, dodge, orient, color, palette): - """Initialize the plotter.""" - self.establish_variables(x, y, hue, data, orient, order, hue_order) - self.establish_colors(color, palette, 1) - - # Set object attributes - self.dodge = dodge - self.width = .8 - - if jitter == 1: # Use a good default for `jitter = True` - jlim = 0.1 - else: - jlim = float(jitter) - 
if self.hue_names is not None and dodge: - jlim /= len(self.hue_names) - self.jitterer = stats.uniform(-jlim, jlim * 2).rvs - - def draw_stripplot(self, ax, kws): - """Draw the points onto `ax`.""" - # Set the default zorder to 2.1, so that the points - # will be drawn on top of line elements (like in a boxplot) - for i, group_data in enumerate(self.plot_data): - if self.plot_hues is None or not self.dodge: - - if self.hue_names is None: - hue_mask = np.ones(group_data.size, np.bool) - else: - hue_mask = np.array([h in self.hue_names - for h in self.plot_hues[i]], np.bool) - # Broken on older numpys - # hue_mask = np.in1d(self.plot_hues[i], self.hue_names) - - strip_data = group_data[hue_mask] - - # Plot the points in centered positions - cat_pos = np.ones(strip_data.size) * i - cat_pos += self.jitterer(len(strip_data)) - kws.update(c=self.point_colors[i][hue_mask]) - if self.orient == "v": - ax.scatter(cat_pos, strip_data, **kws) - else: - ax.scatter(strip_data, cat_pos, **kws) - - else: - offsets = self.hue_offsets - for j, hue_level in enumerate(self.hue_names): - hue_mask = self.plot_hues[i] == hue_level - strip_data = group_data[hue_mask] - - # Plot the points in centered positions - center = i + offsets[j] - cat_pos = np.ones(strip_data.size) * center - cat_pos += self.jitterer(len(strip_data)) - kws.update(c=self.point_colors[i][hue_mask]) - if self.orient == "v": - ax.scatter(cat_pos, strip_data, **kws) - else: - ax.scatter(strip_data, cat_pos, **kws) - - def plot(self, ax, kws): - """Make the plot.""" - self.draw_stripplot(ax, kws) - self.add_legend_data(ax) - self.annotate_axes(ax) - if self.orient == "h": - ax.invert_yaxis() - - -class _SwarmPlotter(_CategoricalScatterPlotter): - - def __init__(self, x, y, hue, data, order, hue_order, - dodge, orient, color, palette): - """Initialize the plotter.""" - self.establish_variables(x, y, hue, data, orient, order, hue_order) - self.establish_colors(color, palette, 1) - - # Set object attributes - self.dodge = dodge - self.width = .8 - - def could_overlap(self, xy_i, swarm, d): - """Return a list of all swarm points that could overlap with target. - - Assumes that swarm is a sorted list of all points below xy_i. - """ - _, y_i = xy_i - neighbors = [] - for xy_j in reversed(swarm): - _, y_j = xy_j - if (y_i - y_j) < d: - neighbors.append(xy_j) - else: - break - return np.array(list(reversed(neighbors))) - - def position_candidates(self, xy_i, neighbors, d): - """Return a list of (x, y) coordinates that might be valid.""" - candidates = [xy_i] - x_i, y_i = xy_i - left_first = True - for x_j, y_j in neighbors: - dy = y_i - y_j - dx = np.sqrt(d ** 2 - dy ** 2) * 1.05 - cl, cr = (x_j - dx, y_i), (x_j + dx, y_i) - if left_first: - new_candidates = [cl, cr] - else: - new_candidates = [cr, cl] - candidates.extend(new_candidates) - left_first = not left_first - return np.array(candidates) - - def first_non_overlapping_candidate(self, candidates, neighbors, d): - """Remove candidates from the list if they overlap with the swarm.""" - - # IF we have no neighbours, all candidates are good. 
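# ----------------------------------------------------------------------
# [Editor's aside] The acceptance test in the (removed) method below is
# easier to read in isolation: a candidate position is kept only when
# its squared distance to every neighbor is at least the squared point
# diameter d. A standalone sketch with illustrative names, not part of
# the library:
import numpy as np

def no_overlap(candidate, neighbors, d):
    dx = neighbors[:, 0] - candidate[0]
    dy = neighbors[:, 1] - candidate[1]
    return np.all(dx ** 2 + dy ** 2 >= d ** 2)

# e.g. neighbors exactly one diameter away do not overlap:
# no_overlap((0., 0.), np.array([[1., 0.], [0., 1.]]), 1.)  -> True
# (end of aside; the removed implementation continues below)
# ----------------------------------------------------------------------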
- if len(neighbors) == 0: - return candidates[0] - - neighbors_x = neighbors[:, 0] - neighbors_y = neighbors[:, 1] - - d_square = d ** 2 - - for xy_i in candidates: - x_i, y_i = xy_i - - dx = neighbors_x - x_i - dy = neighbors_y - y_i - - sq_distances = np.power(dx, 2.0) + np.power(dy, 2.0) - - # good candidate does not overlap any of neighbors - # which means that squared distance between candidate - # and any of the neighbours has to be at least - # square of the diameter - good_candidate = np.all(sq_distances >= d_square) - - if good_candidate: - return xy_i - - # If `position_candidates` works well - # this should never happen - raise Exception('No non-overlapping candidates found. ' - 'This should not happen.') - - def beeswarm(self, orig_xy, d): - """Adjust x position of points to avoid overlaps.""" - # In this method, ``x`` is always the categorical axis - # Center of the swarm, in point coordinates - midline = orig_xy[0, 0] - - # Start the swarm with the first point - swarm = [orig_xy[0]] - - # Loop over the remaining points - for xy_i in orig_xy[1:]: - - # Find the points in the swarm that could possibly - # overlap with the point we are currently placing - neighbors = self.could_overlap(xy_i, swarm, d) - - # Find positions that would be valid individually - # with respect to each of the swarm neighbors - candidates = self.position_candidates(xy_i, neighbors, d) - - # Sort candidates by their centrality - offsets = np.abs(candidates[:, 0] - midline) - candidates = candidates[np.argsort(offsets)] - - # Find the first candidate that does not overlap any neighbours - new_xy_i = self.first_non_overlapping_candidate(candidates, - neighbors, d) - - # Place it into the swarm - swarm.append(new_xy_i) - - return np.array(swarm) - - def add_gutters(self, points, center, width): - """Stop points from extending beyond their territory.""" - half_width = width / 2 - low_gutter = center - half_width - off_low = points < low_gutter - if off_low.any(): - points[off_low] = low_gutter - high_gutter = center + half_width - off_high = points > high_gutter - if off_high.any(): - points[off_high] = high_gutter - return points - - def swarm_points(self, ax, points, center, width, s, **kws): - """Find new positions on the categorical axis for each point.""" - # Convert from point size (area) to diameter - default_lw = mpl.rcParams["patch.linewidth"] - lw = kws.get("linewidth", kws.get("lw", default_lw)) - dpi = ax.figure.dpi - d = (np.sqrt(s) + lw) * (dpi / 72) - - # Transform the data coordinates to point coordinates. 
- # We'll figure out the swarm positions in the latter - # and then convert back to data coordinates and replot - orig_xy = ax.transData.transform(points.get_offsets()) - - # Order the variables so that x is the categorical axis - if self.orient == "h": - orig_xy = orig_xy[:, [1, 0]] - - # Do the beeswarm in point coordinates - new_xy = self.beeswarm(orig_xy, d) - - # Transform the point coordinates back to data coordinates - if self.orient == "h": - new_xy = new_xy[:, [1, 0]] - new_x, new_y = ax.transData.inverted().transform(new_xy).T - - # Add gutters - if self.orient == "v": - self.add_gutters(new_x, center, width) - else: - self.add_gutters(new_y, center, width) - - # Reposition the points so they do not overlap - points.set_offsets(np.c_[new_x, new_y]) - - def draw_swarmplot(self, ax, kws): - """Plot the data.""" - s = kws.pop("s") - - centers = [] - swarms = [] - - # Set the categorical axes limits here for the swarm math - if self.orient == "v": - ax.set_xlim(-.5, len(self.plot_data) - .5) - else: - ax.set_ylim(-.5, len(self.plot_data) - .5) - - # Plot each swarm - for i, group_data in enumerate(self.plot_data): - - if self.plot_hues is None or not self.dodge: - - width = self.width - - if self.hue_names is None: - hue_mask = np.ones(group_data.size, np.bool) - else: - hue_mask = np.array([h in self.hue_names - for h in self.plot_hues[i]], np.bool) - # Broken on older numpys - # hue_mask = np.in1d(self.plot_hues[i], self.hue_names) - - swarm_data = group_data[hue_mask] - - # Sort the points for the beeswarm algorithm - sorter = np.argsort(swarm_data) - swarm_data = swarm_data[sorter] - point_colors = self.point_colors[i][hue_mask][sorter] - - # Plot the points in centered positions - cat_pos = np.ones(swarm_data.size) * i - kws.update(c=point_colors) - if self.orient == "v": - points = ax.scatter(cat_pos, swarm_data, s=s, **kws) - else: - points = ax.scatter(swarm_data, cat_pos, s=s, **kws) - - centers.append(i) - swarms.append(points) - - else: - offsets = self.hue_offsets - width = self.nested_width - - for j, hue_level in enumerate(self.hue_names): - hue_mask = self.plot_hues[i] == hue_level - swarm_data = group_data[hue_mask] - - # Sort the points for the beeswarm algorithm - sorter = np.argsort(swarm_data) - swarm_data = swarm_data[sorter] - point_colors = self.point_colors[i][hue_mask][sorter] - - # Plot the points in centered positions - center = i + offsets[j] - cat_pos = np.ones(swarm_data.size) * center - kws.update(c=point_colors) - if self.orient == "v": - points = ax.scatter(cat_pos, swarm_data, s=s, **kws) - else: - points = ax.scatter(swarm_data, cat_pos, s=s, **kws) - - centers.append(center) - swarms.append(points) - - # Update the position of each point on the categorical axis - # Do this after plotting so that the numerical axis limits are correct - for center, swarm in zip(centers, swarms): - if swarm.get_offsets().size: - self.swarm_points(ax, swarm, center, width, s, **kws) - - def plot(self, ax, kws): - """Make the full plot.""" - self.draw_swarmplot(ax, kws) - self.add_legend_data(ax) - self.annotate_axes(ax) - if self.orient == "h": - ax.invert_yaxis() - - class _CategoricalStatPlotter(_CategoricalPlotter): + require_numeric = True + @property def nested_width(self): """A float with the width of plot elements when hue nesting is used.""" @@ -1457,7 +1396,7 @@ def nested_width(self): width = self.width return width - def estimate_statistic(self, estimator, ci, n_boot): + def estimate_statistic(self, estimator, ci, n_boot, seed): if self.hue_names is None: 
statistic = [] @@ -1510,7 +1449,8 @@ def estimate_statistic(self, estimator, ci, n_boot): boots = bootstrap(stat_data, func=estimator, n_boot=n_boot, - units=unit_data) + units=unit_data, + seed=seed) confint.append(utils.ci(boots, ci)) # Option 2: we are grouping by a hue layer @@ -1533,7 +1473,7 @@ def estimate_statistic(self, estimator, ci, n_boot): group_units = self.plot_units[i] have = pd.notnull( np.c_[group_data, group_units] - ).all(axis=1) + ).all(axis=1) stat_data = group_data[hue_mask & have] unit_data = group_units[hue_mask & have] @@ -1565,7 +1505,8 @@ def estimate_statistic(self, estimator, ci, n_boot): boots = bootstrap(stat_data, func=estimator, n_boot=n_boot, - units=unit_data) + units=unit_data, + seed=seed) confint[i].append(utils.ci(boots, ci)) # Save the resulting values for plotting @@ -1605,14 +1546,14 @@ class _BarPlotter(_CategoricalStatPlotter): """Show point estimates and confidence intervals with bars.""" def __init__(self, x, y, hue, data, order, hue_order, - estimator, ci, n_boot, units, + estimator, ci, n_boot, units, seed, orient, color, palette, saturation, errcolor, errwidth, capsize, dodge): """Initialize the plotter.""" self.establish_variables(x, y, hue, data, orient, order, hue_order, units) self.establish_colors(color, palette, saturation) - self.estimate_statistic(estimator, ci, n_boot) + self.estimate_statistic(estimator, ci, n_boot, seed) self.dodge = dodge @@ -1676,14 +1617,14 @@ class _PointPlotter(_CategoricalStatPlotter): """Show point estimates and confidence intervals with (joined) points.""" def __init__(self, x, y, hue, data, order, hue_order, - estimator, ci, n_boot, units, + estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth=None, capsize=None): """Initialize the plotter.""" self.establish_variables(x, y, hue, data, orient, order, hue_order, units) self.establish_colors(color, palette, 1) - self.estimate_statistic(estimator, ci, n_boot) + self.estimate_statistic(estimator, ci, n_boot, seed) # Override the default palette for single-color plots if hue is None and color is None and palette is None: @@ -1698,12 +1639,12 @@ def __init__(self, x, y, hue, data, order, hue_order, dodge = .025 * len(self.hue_names) # Make sure we have a marker for each hue level - if isinstance(markers, string_types): + if isinstance(markers, str): markers = [markers] * len(self.colors) self.markers = markers # Make sure we have a line style for each hue level - if isinstance(linestyles, string_types): + if isinstance(linestyles, str): linestyles = [linestyles] * len(self.colors) self.linestyles = linestyles @@ -1753,14 +1694,14 @@ def draw_points(self, ax): # Draw the estimate points marker = self.markers[0] - hex_colors = [mpl.colors.rgb2hex(c) for c in self.colors] + colors = [mpl.colors.colorConverter.to_rgb(c) for c in self.colors] if self.orient == "h": x, y = self.statistic, pointpos else: x, y = pointpos, self.statistic ax.scatter(x, y, linewidth=mew, marker=marker, s=markersize, - c=hex_colors, edgecolor=hex_colors) + facecolor=colors, edgecolor=colors) else: @@ -1796,19 +1737,18 @@ def draw_points(self, ax): # Draw the estimate points n_points = len(remove_na(offpos)) marker = self.markers[j] - hex_color = mpl.colors.rgb2hex(self.colors[j]) - if n_points: - point_colors = [hex_color for _ in range(n_points)] - else: - point_colors = hex_color + color = mpl.colors.colorConverter.to_rgb(self.colors[j]) + if self.orient == "h": x, y = statistic, offpos else: x, y = offpos, statistic + if not 
len(remove_na(statistic)):
- x, y = [], []
+ x = y = [np.nan] * n_points
+
ax.scatter(x, y, label=hue_level,
- c=point_colors, edgecolor=point_colors,
+ facecolor=color, edgecolor=color,
linewidth=mew, marker=marker, s=markersize,
zorder=z)
@@ -1820,72 +1760,86 @@ def plot(self, ax):
ax.invert_yaxis()
+class _CountPlotter(_BarPlotter):
+ require_numeric = False
+
+
class _LVPlotter(_CategoricalPlotter):
def __init__(self, x, y, hue, data, order, hue_order,
orient, color, palette, saturation,
- width, dodge, k_depth, linewidth, scale, outlier_prop):
+ width, dodge, k_depth, linewidth, scale, outlier_prop,
+ trust_alpha, showfliers=True):
- # TODO assigning variables for None is unneccesary
- if width is None:
- width = .8
self.width = width
- self.dodge = dodge
-
- if saturation is None:
- saturation = .75
self.saturation = saturation
- if k_depth is None:
- k_depth = 'proportion'
+ k_depth_methods = ['proportion', 'tukey', 'trustworthy', 'full']
+ if not (k_depth in k_depth_methods or isinstance(k_depth, Number)):
+ msg = (f'k_depth must be one of {k_depth_methods} or a number, '
+ f'but {k_depth} was passed.')
+ raise ValueError(msg)
self.k_depth = k_depth
if linewidth is None:
linewidth = mpl.rcParams["lines.linewidth"]
self.linewidth = linewidth
- if scale is None:
- scale = 'exponential'
+ scales = ['linear', 'exponential', 'area']
+ if scale not in scales:
+ msg = f'scale must be one of {scales}, but {scale} was passed.'
+ raise ValueError(msg)
self.scale = scale
+ if ((outlier_prop > 1) or (outlier_prop <= 0)):
+ msg = f'outlier_prop {outlier_prop} not in range (0, 1]'
+ raise ValueError(msg)
self.outlier_prop = outlier_prop
+ if not 0 < trust_alpha < 1:
+ msg = f'trust_alpha {trust_alpha} not in range (0, 1)'
+ raise ValueError(msg)
+ self.trust_alpha = trust_alpha
+
+ self.showfliers = showfliers
+
self.establish_variables(x, y, hue, data, orient, order, hue_order)
self.establish_colors(color, palette, saturation)
- def _lv_box_ends(self, vals, k_depth='proportion', outlier_prop=None):
+ def _lv_box_ends(self, vals):
"""Get the number of data points and calculate `depth` of letter-value plot."""
vals = np.asarray(vals)
- vals = vals[np.isfinite(vals)]
+ # Remove infinite values while handling an 'object' dtype
+ # that can come from pd.Float64Dtype() input
+ with pd.option_context('mode.use_inf_as_null', True):
+ vals = vals[~pd.isnull(vals)]
n = len(vals)
- # If p is not set, calculate it so that 8 points are outliers
- if not outlier_prop:
- # Conventional boxplots assume this proportion of the data are
- # outliers.
- p = 0.007
- else:
- if ((outlier_prop > 1.) or (outlier_prop < 0.)):
- raise ValueError('outlier_prop not in range [0, 1]!')
- p = outlier_prop
+ p = self.outlier_prop
+
# Select the depth, i.e. number of boxes to draw, based on the method
- k_dict = {'proportion': (np.log2(n)) - int(np.log2(n*p)) + 1,
- 'tukey': (np.log2(n)) - 3,
- 'trustworthy': (np.log2(n) -
- np.log2(2*stats.norm.ppf((1-p))**2)) + 1}
- k = k_dict[k_depth]
- try:
- k = int(k)
- except ValueError:
- k = 1
- # If the number happens to be less than 0, set k to 0
- if k < 1.:
+ if self.k_depth == 'full':
+ # extend boxes to 100% of the data
+ k = int(np.log2(n)) + 1
+ elif self.k_depth == 'tukey':
+ # This results in 5-8 points in each tail
+ k = int(np.log2(n)) - 3
+ elif self.k_depth == 'proportion':
+ k = int(np.log2(n)) - int(np.log2(n * p)) + 1
+ elif self.k_depth == 'trustworthy':
+ point_conf = 2 * _normal_quantile_func((1 - self.trust_alpha / 2)) ** 2
+ k = int(np.log2(n / point_conf)) + 1
+ else:
+ k = int(self.k_depth) # allow having k as input
+ # If the number happens to be less than 1, set k to 1
+ if k < 1:
k = 1
-
- # Calculate the upper box ends
- upper = [100*(1 - 0.5**(i+2)) for i in range(k, -1, -1)]
- # Calculate the lower box ends
- lower = [100*(0.5**(i+2)) for i in range(k, -1, -1)]
+
+ # Calculate the upper end for each of the k boxes
+ upper = [100 * (1 - 0.5 ** (i + 1)) for i in range(k, 0, -1)]
+ # Calculate the lower end for each of the k boxes
+ lower = [100 * (0.5 ** (i + 1)) for i in range(k, 0, -1)]
# Stitch the box ends together
percentile_ends = [(i, j) for i, j in zip(lower, upper)]
box_ends = [np.percentile(vals, q) for q in percentile_ends]
@@ -1893,7 +1847,8 @@ def _lv_box_ends(self, vals, k_depth='proportion', outlier_prop=None):
def _lv_outliers(self, vals, k):
"""Find the outliers based on the letter value depth."""
- perc_ends = (100*(0.5**(k+2)), 100*(1 - 0.5**(k+2)))
+ box_edge = 0.5 ** (k + 1)
+ perc_ends = (100 * box_edge, 100 * (1 - box_edge))
edges = np.percentile(vals, perc_ends)
lower_out = vals[np.where(vals < edges[0])[0]]
upper_out = vals[np.where(vals > edges[1])[0]]
@@ -1902,22 +1857,23 @@ def _lv_outliers(self, vals, k):
def _width_functions(self, width_func):
# Dictionary of functions for computing the width of the boxes
width_functions = {'linear': lambda h, i, k: (i + 1.) / k,
- 'exponential': lambda h, i, k: 2**(-k+i-1),
- 'area': lambda h, i, k: (1 - 2**(-k+i-2)) / h}
+ 'exponential': lambda h, i, k: 2**(-k + i - 1),
+ 'area': lambda h, i, k: (1 - 2**(-k + i - 2)) / h}
return width_functions[width_func]
def _lvplot(self, box_data, positions, color=[255. / 256., 185.
/ 256., 0.], - vert=True, widths=1, k_depth='proportion', - ax=None, outlier_prop=None, scale='exponential', - **kws): + widths=1, ax=None, **kws): + vert = self.orient == "v" x = positions[0] box_data = np.asarray(box_data) # If we only have one data point, plot a line if len(box_data) == 1: - kws.update({'color': self.gray, 'linestyle': '-'}) + kws.update({ + 'color': self.gray, 'linestyle': '-', 'linewidth': self.linewidth + }) ys = [box_data[0], box_data[0]] xs = [x - widths / 2, x + widths / 2] if vert: @@ -1928,12 +1884,11 @@ def _lvplot(self, box_data, positions, else: # Get the number of data points and calculate "depth" of # letter-value plot - box_ends, k = self._lv_box_ends(box_data, k_depth=k_depth, - outlier_prop=outlier_prop) + box_ends, k = self._lv_box_ends(box_data) # Anonymous functions for calculating the width and height # of the letter value boxes - width = self._width_functions(scale) + width = self._width_functions(self.scale) # Function to find height of boxes def height(b): @@ -1941,14 +1896,14 @@ def height(b): # Functions to construct the letter value boxes def vert_perc_box(x, b, i, k, w): - rect = Patches.Rectangle((x - widths*w / 2, b[0]), - widths*w, + rect = Patches.Rectangle((x - widths * w / 2, b[0]), + widths * w, height(b), fill=True) return rect def horz_perc_box(x, b, i, k, w): - rect = Patches.Rectangle((b[0], x - widths*w / 2), - height(b), widths*w, + rect = Patches.Rectangle((b[0], x - widths * w / 2), + height(b), widths * w, fill=True) return rect @@ -1960,46 +1915,63 @@ def horz_perc_box(x, b, i, k, w): # Calculate the medians y = np.median(box_data) - # Calculate the outliers and plot - outliers = self._lv_outliers(box_data, k) + # Calculate the outliers and plot (only if showfliers == True) + outliers = [] + if self.showfliers: + outliers = self._lv_outliers(box_data, k) hex_color = mpl.colors.rgb2hex(color) if vert: - boxes = [vert_perc_box(x, b[0], i, k, b[1]) - for i, b in enumerate(zip(box_ends, w_area))] + box_func = vert_perc_box + xs_median = [x - widths / 2, x + widths / 2] + ys_median = [y, y] + xs_outliers = np.full(len(outliers), x) + ys_outliers = outliers - # Plot the medians - ax.plot([x - widths / 2, x + widths / 2], [y, y], - c='.15', alpha=.45, **kws) - - ax.scatter(np.repeat(x, len(outliers)), outliers, - marker='d', c=hex_color, **kws) else: - boxes = [horz_perc_box(x, b[0], i, k, b[1]) - for i, b in enumerate(zip(box_ends, w_area))] - - # Plot the medians - ax.plot([y, y], [x - widths / 2, x + widths / 2], - c='.15', alpha=.45, **kws) + box_func = horz_perc_box + xs_median = [y, y] + ys_median = [x - widths / 2, x + widths / 2] + xs_outliers = outliers + ys_outliers = np.full(len(outliers), x) + + boxes = [box_func(x, b[0], i, k, b[1]) + for i, b in enumerate(zip(box_ends, w_area))] + + # Plot the medians + ax.plot( + xs_median, + ys_median, + c=".15", + alpha=0.45, + solid_capstyle="butt", + linewidth=self.linewidth, + **kws + ) - ax.scatter(outliers, np.repeat(x, len(outliers)), - marker='d', c=hex_color, **kws) + # Plot outliers (if any) + if len(outliers) > 0: + ax.scatter(xs_outliers, ys_outliers, marker='d', + c=self.gray, **kws) # Construct a color map from the input color - rgb = [[1, 1, 1], hex_color] + rgb = [hex_color, (1, 1, 1)] cmap = mpl.colors.LinearSegmentedColormap.from_list('new_map', rgb) - collection = PatchCollection(boxes, cmap=cmap) + # Make sure that the last boxes contain hue and are not pure white + rgb = [hex_color, cmap(.85)] + cmap = mpl.colors.LinearSegmentedColormap.from_list('new_map', rgb) + 
collection = PatchCollection( + boxes, cmap=cmap, edgecolor=self.gray, linewidth=self.linewidth + ) - # Set the color gradation - collection.set_array(np.array(np.linspace(0, 1, len(boxes)))) + # Set the color gradation, first box will have color=hex_color + collection.set_array(np.array(np.linspace(1, 0, len(boxes)))) # Plot the boxes ax.add_collection(collection) def draw_letter_value_plot(self, ax, kws): """Use matplotlib to draw a letter value plot on an Axes.""" - vert = self.orient == "v" - for i, group_data in enumerate(self.plot_data): if self.plot_hues is None: @@ -2021,12 +1993,8 @@ def draw_letter_value_plot(self, ax, kws): self._lvplot(box_data, positions=[i], color=color, - vert=vert, widths=self.width, - k_depth=self.k_depth, ax=ax, - scale=self.scale, - outlier_prop=self.outlier_prop, **kws) else: @@ -2054,14 +2022,13 @@ def draw_letter_value_plot(self, ax, kws): self._lvplot(box_data, positions=[center], color=color, - vert=vert, widths=self.nested_width, - k_depth=self.k_depth, ax=ax, - scale=self.scale, - outlier_prop=self.outlier_prop, **kws) + # Autoscale the values axis to make sure all patches are visible + ax.autoscale_view(scalex=self.orient == "h", scaley=self.orient == "v") + def plot(self, ax, boxplot_kws): """Make the plot.""" self.draw_letter_value_plot(ax, boxplot_kws) @@ -2134,14 +2101,16 @@ def plot(self, ax, boxplot_kws): intervals. units : name of variable in ``data`` or vector data, optional Identifier of sampling units, which will be used to perform a - multilevel bootstrap and account for repeated measures design.\ + multilevel bootstrap and account for repeated measures design. + seed : int, numpy.random.Generator, or numpy.random.RandomState, optional + Seed or random number generator for reproducible bootstrapping.\ """), orient=dedent("""\ orient : "v" | "h", optional Orientation of the plot (vertical or horizontal). 
This is usually - inferred from the dtype of the input variables, but can be used to - specify when the "categorical" variable is a numeric or when plotting - wide-form data.\ + inferred based on the type of the input variables, but it can be used + to resolve ambiguity when both `x` and `y` are numeric or when + plotting wide-form data.\ """), color=dedent("""\ color : matplotlib color, optional @@ -2217,21 +2186,28 @@ def plot(self, ax, boxplot_kws): glyphs.\ """), catplot=dedent("""\ - catplot : Combine a categorical plot with a class:`FacetGrid`.\ + catplot : Combine a categorical plot with a :class:`FacetGrid`.\ """), boxenplot=dedent("""\ boxenplot : An enhanced boxplot for larger datasets.\ """), - ) +) _categorical_docs.update(_facet_docs) -def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - orient=None, color=None, palette=None, saturation=.75, - width=.8, dodge=True, fliersize=5, linewidth=None, - whis=1.5, notch=False, ax=None, **kwargs): +@_deprecate_positional_args +def boxplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + orient=None, color=None, palette=None, saturation=.75, + width=.8, dodge=True, fliersize=5, linewidth=None, + whis=1.5, ax=None, + **kwargs +): plotter = _BoxPlotter(x, y, hue, data, order, hue_order, orient, color, palette, saturation, @@ -2239,7 +2215,7 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, if ax is None: ax = plt.gca() - kwargs.update(dict(whis=whis, notch=notch)) + kwargs.update(dict(whis=whis)) plotter.plot(ax, kwargs) return ax @@ -2274,18 +2250,13 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, Size of the markers used to indicate outlier observations. {linewidth} whis : float, optional - Proportion of the IQR past the low and high quartiles to extend the - plot whiskers. Points outside this range will be identified as - outliers. - notch : boolean, optional - Whether to "notch" the box to indicate a confidence interval for the - median. There are several other parameters that can control how the - notches are drawn; see the ``plt.boxplot`` help for more information - on them. + Maximum length of the plot whiskers as proportion of the + interquartile range. Whiskers extend to the furthest datapoint + within that range. More extreme points are marked as outliers. {ax_in} kwargs : key, value mappings - Other keyword arguments are passed through to ``plt.boxplot`` at draw - time. + Other keyword arguments are passed through to + :meth:`matplotlib.axes.Axes.boxplot`. Returns ------- @@ -2296,6 +2267,7 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {violinplot} {stripplot} {swarmplot} + {catplot} Examples -------- @@ -2306,7 +2278,7 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="whitegrid") + >>> sns.set_theme(style="whitegrid") >>> tips = sns.load_dataset("tips") >>> ax = sns.boxplot(x=tips["total_bill"]) @@ -2366,7 +2338,7 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, >>> ax = sns.boxplot(x="day", y="total_bill", data=tips) >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips, color=".25") - Use :func:`catplot` to combine a :func:`pointplot` and a + Use :func:`catplot` to combine a :func:`boxplot` and a :class:`FacetGrid`. This allows grouping within additional categorical variables. 
Using :func:`catplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of variable order across facets: @@ -2382,11 +2354,17 @@ def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) -def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - bw="scott", cut=2, scale="area", scale_hue=True, gridsize=100, - width=.8, inner="box", split=False, dodge=True, orient=None, - linewidth=None, color=None, palette=None, saturation=.75, - ax=None, **kwargs): +@_deprecate_positional_args +def violinplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + bw="scott", cut=2, scale="area", scale_hue=True, gridsize=100, + width=.8, inner="box", split=False, dodge=True, orient=None, + linewidth=None, color=None, palette=None, saturation=.75, + ax=None, **kwargs, +): plotter = _ViolinPlotter(x, y, hue, data, order, hue_order, bw, cut, scale, scale_hue, gridsize, @@ -2474,6 +2452,7 @@ def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {boxplot} {stripplot} {swarmplot} + {catplot} Examples -------- @@ -2484,7 +2463,7 @@ def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="whitegrid") + >>> sns.set_theme(style="whitegrid") >>> tips = sns.load_dataset("tips") >>> ax = sns.violinplot(x=tips["total_bill"]) @@ -2609,27 +2588,22 @@ def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) -def lvplot(*args, **kwargs): - """Deprecated; please use `boxenplot`.""" - - msg = ( - "The `lvplot` function has been renamed to `boxenplot`. The original " - "name will be removed in a future release. Please update your code. " - ) - warnings.warn(msg) - - return boxenplot(*args, **kwargs) - - -def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - orient=None, color=None, palette=None, saturation=.75, - width=.8, dodge=True, k_depth='proportion', linewidth=None, - scale='exponential', outlier_prop=None, ax=None, **kwargs): +@_deprecate_positional_args +def boxenplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + orient=None, color=None, palette=None, saturation=.75, + width=.8, dodge=True, k_depth='tukey', linewidth=None, + scale='exponential', outlier_prop=0.007, trust_alpha=0.05, showfliers=True, + ax=None, **kwargs +): plotter = _LVPlotter(x, y, hue, data, order, hue_order, orient, color, palette, saturation, width, dodge, k_depth, linewidth, scale, - outlier_prop) + outlier_prop, trust_alpha, showfliers) if ax is None: ax = plt.gca() @@ -2666,25 +2640,34 @@ def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {saturation} {width} {dodge} - k_depth : "proportion" | "tukey" | "trustworthy", optional + k_depth : {{"tukey", "proportion", "trustworthy", "full"}} or scalar,\ + optional The number of boxes, and by extension number of percentiles, to draw. All methods are detailed in Wickham's paper. Each makes different assumptions about the number of outliers and leverages different - statistical properties. + statistical properties. If "proportion", draw no more than + `outlier_prop` extreme observations. If "full", draw `log(n)+1` boxes. {linewidth} - scale : "linear" | "exponential" | "area" + scale : {{"exponential", "linear", "area"}}, optional Method to use for the width of the letter value boxes. All give similar results visually. 
"linear" reduces the width by a constant linear factor, "exponential" uses the proportion of data not covered, "area" is proportional to the percentage of data covered. outlier_prop : float, optional - Proportion of data believed to be outliers. Used in conjunction with - k_depth to determine the number of percentiles to draw. Defaults to - 0.007 as a proportion of outliers. Should be in range [0, 1]. + Proportion of data believed to be outliers. Must be in the range + (0, 1]. Used to determine the number of boxes to plot when + `k_depth="proportion"`. + trust_alpha : float, optional + Confidence level for a box to be plotted. Used to determine the + number of boxes to plot when `k_depth="trustworthy"`. Must be in the + range (0, 1). + showfliers : bool, optional + If False, suppress the plotting of outliers. {ax_in} kwargs : key, value mappings - Other keyword arguments are passed through to ``plt.plot`` and - ``plt.scatter`` at draw time. + Other keyword arguments are passed through to + :meth:`matplotlib.axes.Axes.plot` and + :meth:`matplotlib.axes.Axes.scatter`. Returns ------- @@ -2694,6 +2677,7 @@ def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, -------- {violinplot} {boxplot} + {catplot} Examples -------- @@ -2704,7 +2688,7 @@ def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="whitegrid") + >>> sns.set_theme(style="whitegrid") >>> tips = sns.load_dataset("tips") >>> ax = sns.boxenplot(x=tips["total_bill"]) @@ -2752,9 +2736,10 @@ def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, .. plot:: :context: close-figs - >>> ax = sns.boxenplot(x="day", y="total_bill", data=tips) + >>> ax = sns.boxenplot(x="day", y="total_bill", data=tips, + ... showfliers=False) >>> ax = sns.stripplot(x="day", y="total_bill", data=tips, - ... size=4, color="gray") + ... size=4, color=".26") Use :func:`catplot` to combine :func:`boxenplot` and a :class:`FacetGrid`. This allows grouping within additional categorical variables. Using @@ -2772,31 +2757,67 @@ def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) -def stripplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - jitter=True, dodge=False, orient=None, color=None, palette=None, - size=5, edgecolor="gray", linewidth=0, ax=None, **kwargs): - - if "split" in kwargs: - dodge = kwargs.pop("split") - msg = "The `split` parameter has been renamed to `dodge`." - warnings.warn(msg, UserWarning) +@_deprecate_positional_args +def stripplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + jitter=True, dodge=False, orient=None, color=None, palette=None, + size=5, edgecolor="gray", linewidth=0, ax=None, + hue_norm=None, fixed_scale=True, formatter=None, + **kwargs +): + + # XXX we need to add a legend= param!!! 
+ + p = _CategoricalPlotterNew( + data=data, + variables=_CategoricalPlotterNew.get_semantics(locals()), + order=order, + orient=orient, + require_numeric=False, + fixed_scale=fixed_scale, + ) - plotter = _StripPlotter(x, y, hue, data, order, hue_order, - jitter, dodge, orient, color, palette) if ax is None: ax = plt.gca() + if fixed_scale or p.var_types[p.cat_axis] == "categorical": + p.scale_categorical(p.cat_axis, order=order, formatter=formatter) + + p._attach(ax) + + palette, hue_order = p._hue_backcompat(color, palette, hue_order) + + color = _default_color(ax.scatter, hue, color, kwargs) + + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + # XXX Copying possibly bad default decisions from original code for now kwargs.setdefault("zorder", 3) size = kwargs.get("s", size) - if linewidth is None: - linewidth = size / 10 - if edgecolor == "gray": - edgecolor = plotter.gray - kwargs.update(dict(s=size ** 2, - edgecolor=edgecolor, - linewidth=linewidth)) - plotter.plot(ax, kwargs) + kwargs.update(dict( + s=size ** 2, + edgecolor=edgecolor, + linewidth=linewidth) + ) + + p.plot_strips( + jitter=jitter, + dodge=dodge, + color=color, + edgecolor=edgecolor, + plot_kws=kwargs, + ) + + # XXX this happens inside a plotting method in the distribution plots + # but maybe it's better out here? Alternatively, we have an open issue + # suggesting that _attach could add default axes labels, which seems smart. + p._add_axis_labels(ax) + p._adjust_cat_axis(ax, axis=p.cat_axis) + return ax @@ -2831,15 +2852,16 @@ def stripplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {color} {palette} size : float, optional - Diameter of the markers, in points. Although ``plt.scatter`` is used - to draw the points, the ``size`` argument here takes a "normal" - markersize and not size^2 like ``plt.scatter``. + Radius of the markers, in points. edgecolor : matplotlib color, "gray" is special-cased, optional Color of the lines around each point. If you pass ``"gray"``, the brightness is determined by the color palette used for the body of the points. {linewidth} {ax_in} + kwargs : key, value mappings + Other keyword arguments are passed through to + :meth:`matplotlib.axes.Axes.scatter`. Returns ------- @@ -2850,141 +2872,80 @@ def stripplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {swarmplot} {boxplot} {violinplot} + {catplot} Examples -------- - Draw a single horizontal strip plot: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns - >>> sns.set(style="whitegrid") - >>> tips = sns.load_dataset("tips") - >>> ax = sns.stripplot(x=tips["total_bill"]) - - Group the strips by a categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips) - - Use a smaller amount of jitter: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips, jitter=0.05) - - Draw horizontal strips: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="total_bill", y="day", data=tips) - - Draw outlines around the points: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="total_bill", y="day", data=tips, - ... linewidth=1) - - Nest the strips within a second categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="sex", y="total_bill", hue="day", data=tips) + .. include:: ../docstrings/stripplot.rst - Draw each level of the ``hue`` variable at different locations on the - major categorical axis: - - .. 
plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="day", y="total_bill", hue="smoker", - ... data=tips, palette="Set2", dodge=True) - - Control strip order by passing an explicit order: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot(x="time", y="tip", data=tips, - ... order=["Dinner", "Lunch"]) - - Draw strips with large points and different aesthetics: - - .. plot:: - :context: close-figs - - >>> ax = sns.stripplot("day", "total_bill", "smoker", data=tips, - ... palette="Set2", size=20, marker="D", - ... edgecolor="gray", alpha=.25) - - Draw strips of observations on top of a box plot: - - .. plot:: - :context: close-figs - - >>> import numpy as np - >>> ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf) - >>> ax = sns.stripplot(x="tip", y="day", data=tips, color=".3") - - Draw strips of observations on top of a violin plot: - - .. plot:: - :context: close-figs + """).format(**_categorical_docs) - >>> ax = sns.violinplot(x="day", y="total_bill", data=tips, - ... inner=None, color=".8") - >>> ax = sns.stripplot(x="day", y="total_bill", data=tips) - Use :func:`catplot` to combine a :func:`stripplot` and a - :class:`FacetGrid`. This allows grouping within additional categorical - variables. Using :func:`catplot` is safer than using :class:`FacetGrid` - directly, as it ensures synchronization of variable order across facets: +@_deprecate_positional_args +def swarmplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + dodge=False, orient=None, color=None, palette=None, + size=5, edgecolor="gray", linewidth=0, ax=None, + hue_norm=None, fixed_scale=True, formatter=None, warn_thresh=.05, + **kwargs +): + + p = _CategoricalPlotterNew( + data=data, + variables=_CategoricalPlotterNew.get_semantics(locals()), + order=order, + orient=orient, + require_numeric=False, + fixed_scale=fixed_scale, + ) - .. plot:: - :context: close-figs + if ax is None: + ax = plt.gca() - >>> g = sns.catplot(x="sex", y="total_bill", - ... hue="smoker", col="time", - ... data=tips, kind="strip", - ... height=4, aspect=.7); + if fixed_scale or p.var_types[p.cat_axis] == "categorical": + p.scale_categorical(p.cat_axis, order=order, formatter=formatter) - """).format(**_categorical_docs) + p._attach(ax) + if not p.has_xy_data: + return ax -def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - dodge=False, orient=None, color=None, palette=None, - size=5, edgecolor="gray", linewidth=0, ax=None, **kwargs): + palette, hue_order = p._hue_backcompat(color, palette, hue_order) - if "split" in kwargs: - dodge = kwargs.pop("split") - msg = "The `split` parameter has been renamed to `dodge`." 
- warnings.warn(msg, UserWarning) + color = _default_color(ax.scatter, hue, color, kwargs) - plotter = _SwarmPlotter(x, y, hue, data, order, hue_order, - dodge, orient, color, palette) - if ax is None: - ax = plt.gca() + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + # XXX Copying possibly bad default decisions from original code for now kwargs.setdefault("zorder", 3) size = kwargs.get("s", size) + if linewidth is None: linewidth = size / 10 - if edgecolor == "gray": - edgecolor = plotter.gray - kwargs.update(dict(s=size ** 2, - edgecolor=edgecolor, - linewidth=linewidth)) - plotter.plot(ax, kwargs) + kwargs.update(dict( + s=size ** 2, + linewidth=linewidth, + )) + + p.plot_swarms( + dodge=dodge, + color=color, + edgecolor=edgecolor, + warn_thresh=warn_thresh, + plot_kws=kwargs, + ) + + # XXX this happens inside a plotting method in the distribution plots + # but maybe it's better out here? Alternatively, we have an open issue + # suggesting that _attach could add default axes labels, which seems smart. + p._add_axis_labels(ax) + p._adjust_cat_axis(ax, axis=p.cat_axis) + return ax @@ -3022,9 +2983,7 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {color} {palette} size : float, optional - Diameter of the markers, in points. (Although ``plt.scatter`` is used - to draw the points, the ``size`` argument here takes a "normal" - markersize and not size^2 like ``plt.scatter``. + Radius of the markers, in points. edgecolor : matplotlib color, "gray" is special-cased, optional Color of the lines around each point. If you pass ``"gray"``, the brightness is determined by the color palette used for the body @@ -3032,8 +2991,8 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {linewidth} {ax_in} kwargs : key, value mappings - Other keyword arguments are passed through to ``plt.scatter`` at draw - time. + Other keyword arguments are passed through to + :meth:`matplotlib.axes.Axes.scatter`. Returns ------- @@ -3049,101 +3008,26 @@ def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, Examples -------- - Draw a single horizontal swarm plot: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns - >>> sns.set(style="whitegrid") - >>> tips = sns.load_dataset("tips") - >>> ax = sns.swarmplot(x=tips["total_bill"]) - - Group the swarms by a categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips) - - Draw horizontal swarms: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="total_bill", y="day", data=tips) - - Color the points using a second categorical variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="day", y="total_bill", hue="sex", data=tips) - - Split each level of the ``hue`` variable along the categorical axis: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="day", y="total_bill", hue="smoker", - ... data=tips, palette="Set2", dodge=True) - - Control swarm order by passing an explicit order: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="time", y="tip", data=tips, - ... order=["Dinner", "Lunch"]) - - Plot using larger points: - - .. plot:: - :context: close-figs - - >>> ax = sns.swarmplot(x="time", y="tip", data=tips, size=6) - - Draw swarms of observations on top of a box plot: - - .. 
plot::
- :context: close-figs
-
- >>> ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf)
- >>> ax = sns.swarmplot(x="tip", y="day", data=tips, color=".2")
-
- Draw swarms of observations on top of a violin plot:
-
- .. plot::
- :context: close-figs
-
- >>> ax = sns.violinplot(x="day", y="total_bill", data=tips, inner=None)
- >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips,
- ... color="white", edgecolor="gray")
-
- Use :func:`catplot` to combine a :func:`swarmplot` and a
- :class:`FacetGrid`. This allows grouping within additional categorical
- variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
- directly, as it ensures synchronization of variable order across facets:
-
- .. plot::
- :context: close-figs
-
- >>> g = sns.catplot(x="sex", y="total_bill",
- ... hue="smoker", col="time",
- ... data=tips, kind="swarm",
- ... height=4, aspect=.7);
+ .. include:: ../docstrings/swarmplot.rst
""").format(**_categorical_docs)
-def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
- estimator=np.mean, ci=95, n_boot=1000, units=None,
- orient=None, color=None, palette=None, saturation=.75,
- errcolor=".26", errwidth=None, capsize=None, dodge=True,
- ax=None, **kwargs):
+@_deprecate_positional_args
+def barplot(
+ *,
+ x=None, y=None,
+ hue=None, data=None,
+ order=None, hue_order=None,
+ estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
+ orient=None, color=None, palette=None, saturation=.75,
+ errcolor=".26", errwidth=None, capsize=None, dodge=True,
+ ax=None,
+ **kwargs
+):
plotter = _BarPlotter(x, y, hue, data, order, hue_order,
- estimator, ci, n_boot, units,
+ estimator, ci, n_boot, units, seed,
orient, color, palette, saturation,
errcolor, errwidth, capsize, dodge)
@@ -3195,8 +3079,8 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
{dodge}
{ax_in}
kwargs : key, value mappings
- Other keyword arguments are passed through to ``plt.bar`` at draw
- time.
+ Other keyword arguments are passed through to
+ :meth:`matplotlib.axes.Axes.bar`.
Returns
-------
@@ -3217,7 +3101,7 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
:context: close-figs
>>> import seaborn as sns
- >>> sns.set(style="whitegrid")
+ >>> sns.set_theme(style="whitegrid")
>>> tips = sns.load_dataset("tips")
>>> ax = sns.barplot(x="day", y="total_bill", data=tips)
@@ -3277,7 +3161,7 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
.. plot::
:context: close-figs
- >>> ax = sns.barplot("size", y="total_bill", data=tips,
+ >>> ax = sns.barplot(x="size", y="total_bill", data=tips,
... palette="Blues_d")
Use ``hue`` without changing bar position or width:
@@ -3294,15 +3178,15 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
.. plot::
:context: close-figs
- >>> ax = sns.barplot("size", y="total_bill", data=tips,
+ >>> ax = sns.barplot(x="size", y="total_bill", data=tips,
... color="salmon", saturation=.5)
- Use ``plt.bar`` keyword arguments to further change the aesthetic:
+ Use :meth:`matplotlib.axes.Axes.bar` parameters to control the style:
.. plot::
:context: close-figs
- >>> ax = sns.barplot("day", "total_bill", data=tips,
+ >>> ax = sns.barplot(x="day", y="total_bill", data=tips,
... linewidth=2.5, facecolor=(1, 1, 1, 0),
...
errcolor=".2", edgecolor=".2") @@ -3322,14 +3206,21 @@ def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) -def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - estimator=np.mean, ci=95, n_boot=1000, units=None, - markers="o", linestyles="-", dodge=False, join=True, scale=1, - orient=None, color=None, palette=None, errwidth=None, - capsize=None, ax=None, **kwargs): +@_deprecate_positional_args +def pointplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None, + markers="o", linestyles="-", dodge=False, join=True, scale=1, + orient=None, color=None, palette=None, errwidth=None, + capsize=None, ax=None, + **kwargs +): plotter = _PointPlotter(x, y, hue, data, order, hue_order, - estimator, ci, n_boot, units, + estimator, ci, n_boot, units, seed, markers, linestyles, dodge, join, scale, orient, color, palette, errwidth, capsize) @@ -3409,7 +3300,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="darkgrid") + >>> sns.set_theme(style="darkgrid") >>> tips = sns.load_dataset("tips") >>> ax = sns.pointplot(x="time", y="total_bill", data=tips) @@ -3458,7 +3349,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, .. plot:: :context: close-figs - >>> ax = sns.pointplot("time", y="total_bill", data=tips, + >>> ax = sns.pointplot(x="time", y="total_bill", data=tips, ... color="#bb3f3f") Use a different color palette for the points: @@ -3506,7 +3397,7 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, >>> ax = sns.pointplot(x="day", y="tip", data=tips, capsize=.2) - Use :func:`catplot` to combine a :func:`barplot` and a + Use :func:`catplot` to combine a :func:`pointplot` and a :class:`FacetGrid`. This allows grouping within additional categorical variables. 
Using :func:`catplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of variable order across facets: @@ -3523,14 +3414,21 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) -def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, - orient=None, color=None, palette=None, saturation=.75, - dodge=True, ax=None, **kwargs): +@_deprecate_positional_args +def countplot( + *, + x=None, y=None, + hue=None, data=None, + order=None, hue_order=None, + orient=None, color=None, palette=None, saturation=.75, + dodge=True, ax=None, **kwargs +): estimator = len ci = None n_boot = 0 units = None + seed = None errcolor = None errwidth = None capsize = None @@ -3542,14 +3440,14 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, orient = "v" y = x elif x is not None and y is not None: - raise TypeError("Cannot pass values for both `x` and `y`") - else: - raise TypeError("Must pass values for either `x` or `y`") + raise ValueError("Cannot pass values for both `x` and `y`") - plotter = _BarPlotter(x, y, hue, data, order, hue_order, - estimator, ci, n_boot, units, - orient, color, palette, saturation, - errcolor, errwidth, capsize, dodge) + plotter = _CountPlotter( + x, y, hue, data, order, hue_order, + estimator, ci, n_boot, units, seed, + orient, color, palette, saturation, + errcolor, errwidth, capsize, dodge + ) plotter.value_label = "count" @@ -3583,7 +3481,8 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, {dodge} {ax_in} kwargs : key, value mappings - Other keyword arguments are passed to ``plt.bar``. + Other keyword arguments are passed through to + :meth:`matplotlib.axes.Axes.bar`. Returns ------- @@ -3603,7 +3502,7 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="darkgrid") + >>> sns.set_theme(style="darkgrid") >>> titanic = sns.load_dataset("titanic") >>> ax = sns.countplot(x="class", data=titanic) @@ -3628,7 +3527,7 @@ def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, >>> ax = sns.countplot(x="who", data=titanic, palette="Set3") - Use ``plt.bar`` keyword arguments for a different look: + Use :meth:`matplotlib.axes.Axes.bar` parameters to control the style. .. 
plot:: :context: close-figs @@ -3675,13 +3574,21 @@ def factorplot(*args, **kwargs): return catplot(*args, **kwargs) -def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, - col_wrap=None, estimator=np.mean, ci=95, n_boot=1000, - units=None, order=None, hue_order=None, row_order=None, - col_order=None, kind="strip", height=5, aspect=1, - orient=None, color=None, palette=None, - legend=True, legend_out=True, sharex=True, sharey=True, - margin_titles=False, facet_kws=None, **kwargs): +@_deprecate_positional_args +def catplot( + *, + x=None, y=None, + hue=None, data=None, + row=None, col=None, # TODO move in front of data when * is enforced + col_wrap=None, estimator=np.mean, ci=95, n_boot=1000, + units=None, seed=None, order=None, hue_order=None, row_order=None, + col_order=None, kind="strip", height=5, aspect=1, + orient=None, color=None, palette=None, + legend=True, legend_out=True, sharex=True, sharey=True, + margin_titles=False, facet_kws=None, + hue_norm=None, fixed_scale=True, formatter=None, + **kwargs +): # Handle deprecations if "size" in kwargs: @@ -3697,6 +3604,132 @@ err = "Plot kind '{}' is not recognized".format(kind) raise ValueError(err) + # Check for attempt to plot onto specific axes and warn + if "ax" in kwargs: + msg = ("catplot is a figure-level function and does not accept " + f"target axes. You may wish to try {kind}plot") + warnings.warn(msg, UserWarning) + kwargs.pop("ax") + + refactored_kinds = [ + "strip", "swarm", + ] + + if kind in refactored_kinds: + + p = _CategoricalFacetPlotter( + data=data, + variables=_CategoricalFacetPlotter.get_semantics(locals()), + order=order, + orient=orient, + require_numeric=False, + fixed_scale=fixed_scale, + ) + + # XXX Copying a fair amount from displot, which is not ideal + + for var in ["row", "col"]: + # Handle faceting variables that lack name information + if var in p.variables and p.variables[var] is None: + p.variables[var] = f"_{var}_" + + # Adapt the plot_data dataframe for use with FacetGrid + data = p.plot_data.rename(columns=p.variables) + data = data.loc[:, ~data.columns.duplicated()] + + col_name = p.variables.get("col", None) + row_name = p.variables.get("row", None) + + if facet_kws is None: + facet_kws = {} + + g = FacetGrid( + data=data, row=row_name, col=col_name, + col_wrap=col_wrap, row_order=row_order, + col_order=col_order, height=height, + sharex=sharex, sharey=sharey, + aspect=aspect, + **facet_kws, + ) + + if fixed_scale or p.var_types[p.cat_axis] == "categorical": + p.scale_categorical(p.cat_axis, order=order, formatter=formatter) + + p._attach(g) + + if not p.has_xy_data: + return g + + palette, hue_order = p._hue_backcompat(color, palette, hue_order) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + if kind == "strip": + + # TODO get these defaults programmatically? + jitter = kwargs.pop("jitter", True) + dodge = kwargs.pop("dodge", False) + edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default + + plot_kws = kwargs.copy() + + # XXX Copying possibly bad default decisions from original code for now + plot_kws.setdefault("zorder", 3) + plot_kws.setdefault("s", 25) + plot_kws.setdefault("linewidth", 0) + + p.plot_strips( + jitter=jitter, + dodge=dodge, + color=color, + edgecolor=edgecolor, + plot_kws=plot_kws, + ) + + elif kind == "swarm": + + # TODO get these defaults programmatically?
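+ # As in the "strip" branch above, the pops below appear intended to mirror + # the axes-level swarmplot keyword defaults so the figure-level path + # behaves the same way.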
+ dodge = kwargs.pop("dodge", False) + edgecolor = kwargs.pop("edgecolor", "gray") # XXX TODO default + warn_thresh = kwargs.pop("warn_thresh", .05) + + plot_kws = kwargs.copy() + + # XXX Copying possibly bad default decisions from original code for now + plot_kws.setdefault("zorder", 3) + plot_kws.setdefault("s", 25) + + if plot_kws.setdefault("linewidth", 0) is None: + plot_kws["linewidth"] = np.sqrt(plot_kws["s"]) / 10 + + p.plot_swarms( + dodge=dodge, + color=color, + edgecolor=edgecolor, + warn_thresh=warn_thresh, + plot_kws=plot_kws, + ) + + # XXX best way to do this housekeeping? + for ax in g.axes.flat: + p._adjust_cat_axis(ax, axis=p.cat_axis) + + g.set_axis_labels( + p.variables.get("x", None), + p.variables.get("y", None), + ) + g.set_titles() + g.tight_layout() + + # XXX Hack to get the legend data in the right place + for ax in g.axes.flat: + g._update_legend_data(ax) + ax.legend_ = None + + if legend and (hue is not None) and (hue not in [x, row, col]): + g.add_legend(title=hue, label_order=hue_order) + + return g + # Alias the input variables to determine categorical order and palette # correctly in the case of a count plot if kind == "count": @@ -3705,15 +3738,40 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, elif y is None and x is not None: x_, y_, orient = x, x, "v" else: - raise ValueError("Either `x` or `y` must be None for count plots") + raise ValueError("Either `x` or `y` must be None for kind='count'") else: x_, y_ = x, y # Determine the order for the whole dataset, which will be used in all # facets to ensure representation of all data in the final plot + plotter_class = { + "box": _BoxPlotter, + "violin": _ViolinPlotter, + "boxen": _LVPlotter, + "bar": _BarPlotter, + "point": _PointPlotter, + "count": _CountPlotter, + }[kind] p = _CategoricalPlotter() + p.require_numeric = plotter_class.require_numeric p.establish_variables(x_, y_, hue, data, orient, order, hue_order) - order = p.group_names + if ( + order is not None + or (sharex and p.orient == "v") + or (sharey and p.orient == "h") + ): + # Sync categorical axis between facets to have the same categories + order = p.group_names + elif color is None and hue is None: + msg = ( + "Setting `{}=False` with `color=None` may cause different levels of the " + "`{}` variable to share colors. This will change in a future version." 
+ ) + if not sharex and p.orient == "v": + warnings.warn(msg.format("sharex", "x"), UserWarning) + if not sharey and p.orient == "h": + warnings.warn(msg.format("sharey", "y"), UserWarning) + hue_order = p.hue_names # Determine the palette to use @@ -3721,8 +3779,17 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, # so we need to define ``palette`` to get default behavior for the # categorical functions p.establish_colors(color, palette, 1) - if kind != "point" or hue is not None: - palette = p.colors + if ( + (kind != "point" or hue is not None) + # XXX changing this to temporarily support bad sharex=False behavior where + # cat variables could take different colors, which we already warned + # about "breaking" (aka fixing) in the future + and ((sharex and p.orient == "v") or (sharey and p.orient == "h")) + ): + if p.hue_names is None: + palette = dict(zip(p.group_names, p.colors)) + else: + palette = dict(zip(p.hue_names, p.colors)) # Determine keyword arguments for the facets facet_kws = {} if facet_kws is None else facet_kws @@ -3733,25 +3800,30 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, sharex=sharex, sharey=sharey, legend_out=legend_out, margin_titles=margin_titles, dropna=False, - ) + ) # Determine keyword arguments for the plotting function plot_kws = dict( order=order, hue_order=hue_order, orient=orient, color=color, palette=palette, - ) + ) plot_kws.update(kwargs) if kind in ["bar", "point"]: plot_kws.update( - estimator=estimator, ci=ci, n_boot=n_boot, units=units, - ) + estimator=estimator, ci=ci, n_boot=n_boot, units=units, seed=seed, + ) # Initialize the facets g = FacetGrid(**facet_kws) # Draw the plot onto the facets - g.map_dataframe(plot_func, x, y, hue, **plot_kws) + g.map_dataframe(plot_func, x=x, y=y, hue=hue, **plot_kws) + + if p.orient == "h": + g.set_axis_labels(p.value_label, p.group_label) + else: + g.set_axis_labels(p.group_label, p.value_label) # Special case axis labels for a count type plot if kind == "count": @@ -3800,7 +3872,7 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, to ``x``, ``y``, ``hue``, etc. As in the case with the underlying plot functions, if variables have a - ``categorical`` data type, the the levels of the categorical variables, and + ``categorical`` data type, the levels of the categorical variables, and their order will be inferred from the objects. Otherwise you may have to alter the dataframe sorting or use the function parameters (``orient``, ``order``, ``hue_order``, etc.) to set up the plot correctly. @@ -3822,10 +3894,10 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, row_order, col_order : lists of strings, optional Order to organize the rows and/or columns of the grid in, otherwise the orders are inferred from the data objects. - kind : string, optional - The kind of plot to draw (corresponds to the name of a categorical - plotting function. Options are: "point", "bar", "strip", "swarm", - "box", "violin", or "boxen". + kind : str, optional + The kind of plot to draw, corresponding to the name of a categorical + axes-level plotting function. Options are: "strip", "swarm", "box", "violin", + "boxen", "point", "bar", or "count".
{height} {aspect} {orient} @@ -3857,7 +3929,7 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, :context: close-figs >>> import seaborn as sns - >>> sns.set(style="ticks") + >>> sns.set_theme(style="ticks") >>> exercise = sns.load_dataset("exercise") >>> g = sns.catplot(x="time", y="pulse", hue="kind", data=exercise) @@ -3892,7 +3964,7 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, :context: close-figs >>> titanic = sns.load_dataset("titanic") - >>> g = sns.catplot("alive", col="deck", col_wrap=4, + >>> g = sns.catplot(x="alive", col="deck", col_wrap=4, ... data=titanic[titanic.deck.notnull()], ... kind="count", height=2.5, aspect=.8) @@ -3923,3 +3995,200 @@ def catplot(x=None, y=None, hue=None, data=None, row=None, col=None, """).format(**_categorical_docs) + + +class Beeswarm: + """Modifies a scatterplot artist to show a beeswarm plot.""" + def __init__(self, orient="v", width=0.8, warn_thresh=.05): + + # XXX should we keep the orient parameterization or specify the swarm axis? + + self.orient = orient + self.width = width + self.warn_thresh = warn_thresh + + def __call__(self, points, center): + """Swarm `points`, a PathCollection, around the `center` position.""" + # Convert from point size (area) to diameter + + ax = points.axes + dpi = ax.figure.dpi + + # Get the original positions of the points + orig_xy_data = points.get_offsets() + + # Reset the categorical positions to the center line + cat_idx = 1 if self.orient == "h" else 0 + orig_xy_data[:, cat_idx] = center + + # Transform the data coordinates to point coordinates. + # We'll figure out the swarm positions in the latter + # and then convert back to data coordinates and replot + orig_x_data, orig_y_data = orig_xy_data.T + orig_xy = ax.transData.transform(orig_xy_data) + + # Order the variables so that x is the categorical axis + if self.orient == "h": + orig_xy = orig_xy[:, [1, 0]] + + # Add a column with each point's radius + sizes = points.get_sizes() + if sizes.size == 1: + sizes = np.repeat(sizes, orig_xy.shape[0]) + edge = points.get_linewidth().item() + radii = (np.sqrt(sizes) + edge) / 2 * (dpi / 72) + orig_xy = np.c_[orig_xy, radii] + + # Sort along the value axis to facilitate the beeswarm + sorter = np.argsort(orig_xy[:, 1]) + orig_xyr = orig_xy[sorter] + + # Adjust points along the categorical axis to prevent overlaps + new_xyr = np.empty_like(orig_xyr) + new_xyr[sorter] = self.beeswarm(orig_xyr) + + # Transform the point coordinates back to data coordinates + if self.orient == "h": + new_xy = new_xyr[:, [1, 0]] + else: + new_xy = new_xyr[:, :2] + new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T + + swarm_axis = {"h": "y", "v": "x"}[self.orient] + log_scale = getattr(ax, f"get_{swarm_axis}scale")() == "log" + + # Add gutters + if self.orient == "h": + self.add_gutters(new_y_data, center, log_scale=log_scale) + else: + self.add_gutters(new_x_data, center, log_scale=log_scale) + + # Reposition the points so they do not overlap + if self.orient == "h": + points.set_offsets(np.c_[orig_x_data, new_y_data]) + else: + points.set_offsets(np.c_[new_x_data, orig_y_data]) + + def beeswarm(self, orig_xyr): + """Adjust x position of points to avoid overlaps.""" + # In this method, `x` is always the categorical axis + # Center of the swarm, in point coordinates + midline = orig_xyr[0, 0] + + # Start the swarm with the first point + swarm = np.atleast_2d(orig_xyr[0]) + + # Loop over the remaining points + for xyr_i in orig_xyr[1:]: + + # Find the points in the 
swarm that could possibly + # overlap with the point we are currently placing + neighbors = self.could_overlap(xyr_i, swarm) + + # Find positions that would be valid individually + # with respect to each of the swarm neighbors + candidates = self.position_candidates(xyr_i, neighbors) + + # Sort candidates by their centrality + offsets = np.abs(candidates[:, 0] - midline) + candidates = candidates[np.argsort(offsets)] + + # Find the first candidate that does not overlap any neighbors + new_xyr_i = self.first_non_overlapping_candidate(candidates, neighbors) + + # Place it into the swarm + swarm = np.vstack([swarm, new_xyr_i]) + + return swarm + + def could_overlap(self, xyr_i, swarm): + """Return a list of all swarm points that could overlap with target.""" + # Because we work backwards through the swarm and can short-circuit, + # the for-loop is faster than vectorization + _, y_i, r_i = xyr_i + neighbors = [] + for xyr_j in reversed(swarm): + _, y_j, r_j = xyr_j + if (y_i - y_j) < (r_i + r_j): + neighbors.append(xyr_j) + else: + break + return np.array(neighbors)[::-1] + + def position_candidates(self, xyr_i, neighbors): + """Return a list of coordinates that might be valid by adjusting x.""" + candidates = [xyr_i] + x_i, y_i, r_i = xyr_i + left_first = True + for x_j, y_j, r_j in neighbors: + dy = y_i - y_j + dx = np.sqrt(max((r_i + r_j) ** 2 - dy ** 2, 0)) * 1.05 + cl, cr = (x_j - dx, y_i, r_i), (x_j + dx, y_i, r_i) + if left_first: + new_candidates = [cl, cr] + else: + new_candidates = [cr, cl] + candidates.extend(new_candidates) + left_first = not left_first + return np.array(candidates) + + def first_non_overlapping_candidate(self, candidates, neighbors): + """Find the first candidate that does not overlap with the swarm.""" + + # If we have no neighbors, all candidates are good. + if len(neighbors) == 0: + return candidates[0] + + neighbors_x = neighbors[:, 0] + neighbors_y = neighbors[:, 1] + neighbors_r = neighbors[:, 2] + + for xyr_i in candidates: + + x_i, y_i, r_i = xyr_i + + dx = neighbors_x - x_i + dy = neighbors_y - y_i + sq_distances = np.square(dx) + np.square(dy) + + sep_needed = np.square(neighbors_r + r_i) + + # Good candidate does not overlap any of neighbors which means that + # squared distance between candidate and any of the neighbors has + # to be at least square of the summed radii + good_candidate = np.all(sq_distances >= sep_needed) + + if good_candidate: + return xyr_i + + raise RuntimeError( + "No non-overlapping candidates found. This should not happen." + ) + + def add_gutters(self, points, center, log_scale=False): + """Stop points from extending beyond their territory.""" + half_width = self.width / 2 + if log_scale: + low_gutter = 10 ** (np.log10(center) - half_width) + else: + low_gutter = center - half_width + off_low = points < low_gutter + if off_low.any(): + points[off_low] = low_gutter + if log_scale: + high_gutter = 10 ** (np.log10(center) + half_width) + else: + high_gutter = center + half_width + off_high = points > high_gutter + if off_high.any(): + points[off_high] = high_gutter + + gutter_prop = (off_high + off_low).sum() / len(points) + if gutter_prop > self.warn_thresh: + msg = ( + "{:.1%} of the points cannot be placed; you may want " + "to decrease the size of the markers or use stripplot." 
+ ).format(gutter_prop) + warnings.warn(msg, UserWarning) + + return points diff --git a/seaborn/cm.py b/seaborn/cm.py index c250163428..4e39fe7a67 100644 --- a/seaborn/cm.py +++ b/seaborn/cm.py @@ -1041,10 +1041,537 @@ ] -_luts = [_rocket_lut, _mako_lut, _vlag_lut, _icefire_lut] -_names = ["rocket", "mako", "vlag", "icefire"] +_flare_lut = [ + [0.92907237, 0.68878959, 0.50411509], + [0.92891402, 0.68494686, 0.50173994], + [0.92864754, 0.68116207, 0.4993754], + [0.92836112, 0.67738527, 0.49701572], + [0.9280599, 0.67361354, 0.49466044], + [0.92775569, 0.66983999, 0.49230866], + [0.9274375, 0.66607098, 0.48996097], + [0.927111, 0.66230315, 0.48761688], + [0.92677996, 0.6585342, 0.485276], + [0.92644317, 0.65476476, 0.48293832], + [0.92609759, 0.65099658, 0.48060392], + [0.925747, 0.64722729, 0.47827244], + [0.92539502, 0.64345456, 0.47594352], + [0.92503106, 0.6396848, 0.47361782], + [0.92466877, 0.6359095, 0.47129427], + [0.92429828, 0.63213463, 0.46897349], + [0.92392172, 0.62835879, 0.46665526], + [0.92354597, 0.62457749, 0.46433898], + [0.9231622, 0.6207962, 0.46202524], + [0.92277222, 0.61701365, 0.45971384], + [0.92237978, 0.61322733, 0.45740444], + [0.92198615, 0.60943622, 0.45509686], + [0.92158735, 0.60564276, 0.45279137], + [0.92118373, 0.60184659, 0.45048789], + [0.92077582, 0.59804722, 0.44818634], + [0.92036413, 0.59424414, 0.44588663], + [0.91994924, 0.5904368, 0.44358868], + [0.91952943, 0.58662619, 0.4412926], + [0.91910675, 0.58281075, 0.43899817], + [0.91868096, 0.57899046, 0.4367054], + [0.91825103, 0.57516584, 0.43441436], + [0.91781857, 0.57133556, 0.43212486], + [0.9173814, 0.56750099, 0.4298371], + [0.91694139, 0.56366058, 0.42755089], + [0.91649756, 0.55981483, 0.42526631], + [0.91604942, 0.55596387, 0.42298339], + [0.9155979, 0.55210684, 0.42070204], + [0.9151409, 0.54824485, 0.4184247], + [0.91466138, 0.54438817, 0.41617858], + [0.91416896, 0.54052962, 0.41396347], + [0.91366559, 0.53666778, 0.41177769], + [0.91315173, 0.53280208, 0.40962196], + [0.91262605, 0.52893336, 0.40749715], + [0.91208866, 0.52506133, 0.40540404], + [0.91153952, 0.52118582, 0.40334346], + [0.91097732, 0.51730767, 0.4013163], + [0.910403, 0.51342591, 0.39932342], + [0.90981494, 0.50954168, 0.39736571], + [0.90921368, 0.5056543, 0.39544411], + [0.90859797, 0.50176463, 0.39355952], + [0.90796841, 0.49787195, 0.39171297], + [0.90732341, 0.4939774, 0.38990532], + [0.90666382, 0.49008006, 0.38813773], + [0.90598815, 0.486181, 0.38641107], + [0.90529624, 0.48228017, 0.38472641], + [0.90458808, 0.47837738, 0.38308489], + [0.90386248, 0.47447348, 0.38148746], + [0.90311921, 0.4705685, 0.37993524], + [0.90235809, 0.46666239, 0.37842943], + [0.90157824, 0.46275577, 0.37697105], + [0.90077904, 0.45884905, 0.37556121], + [0.89995995, 0.45494253, 0.37420106], + [0.89912041, 0.4510366, 0.37289175], + [0.8982602, 0.44713126, 0.37163458], + [0.89737819, 0.44322747, 0.37043052], + [0.89647387, 0.43932557, 0.36928078], + [0.89554477, 0.43542759, 0.36818855], + [0.89458871, 0.4315354, 0.36715654], + [0.89360794, 0.42764714, 0.36618273], + [0.89260152, 0.42376366, 0.36526813], + [0.8915687, 0.41988565, 0.36441384], + [0.89050882, 0.41601371, 0.36362102], + [0.8894159, 0.41215334, 0.36289639], + [0.888292, 0.40830288, 0.36223756], + [0.88713784, 0.40446193, 0.36164328], + [0.88595253, 0.40063149, 0.36111438], + [0.88473115, 0.39681635, 0.3606566], + [0.88347246, 0.39301805, 0.36027074], + [0.88217931, 0.38923439, 0.35995244], + [0.880851, 0.38546632, 0.35970244], + [0.87947728, 0.38172422, 0.35953127], + 
[0.87806542, 0.37800172, 0.35942941], + [0.87661509, 0.37429964, 0.35939659], + [0.87511668, 0.37062819, 0.35944178], + [0.87357554, 0.36698279, 0.35955811], + [0.87199254, 0.3633634, 0.35974223], + [0.87035691, 0.35978174, 0.36000516], + [0.86867647, 0.35623087, 0.36033559], + [0.86694949, 0.35271349, 0.36073358], + [0.86516775, 0.34923921, 0.36120624], + [0.86333996, 0.34580008, 0.36174113], + [0.86145909, 0.3424046, 0.36234402], + [0.85952586, 0.33905327, 0.36301129], + [0.85754536, 0.33574168, 0.36373567], + [0.855514, 0.33247568, 0.36451271], + [0.85344392, 0.32924217, 0.36533344], + [0.8513284, 0.32604977, 0.36620106], + [0.84916723, 0.32289973, 0.36711424], + [0.84696243, 0.31979068, 0.36806976], + [0.84470627, 0.31673295, 0.36907066], + [0.84240761, 0.31371695, 0.37010969], + [0.84005337, 0.31075974, 0.37119284], + [0.83765537, 0.30784814, 0.3723105], + [0.83520234, 0.30499724, 0.37346726], + [0.83270291, 0.30219766, 0.37465552], + [0.83014895, 0.29946081, 0.37587769], + [0.82754694, 0.29677989, 0.37712733], + [0.82489111, 0.29416352, 0.37840532], + [0.82218644, 0.29160665, 0.37970606], + [0.81942908, 0.28911553, 0.38102921], + [0.81662276, 0.28668665, 0.38236999], + [0.81376555, 0.28432371, 0.383727], + [0.81085964, 0.28202508, 0.38509649], + [0.8079055, 0.27979128, 0.38647583], + [0.80490309, 0.27762348, 0.3878626], + [0.80185613, 0.2755178, 0.38925253], + [0.79876118, 0.27347974, 0.39064559], + [0.79562644, 0.27149928, 0.39203532], + [0.79244362, 0.2695883, 0.39342447], + [0.78922456, 0.26773176, 0.3948046], + [0.78596161, 0.26594053, 0.39617873], + [0.7826624, 0.26420493, 0.39754146], + [0.77932717, 0.26252522, 0.39889102], + [0.77595363, 0.2609049, 0.4002279], + [0.77254999, 0.25933319, 0.40154704], + [0.76911107, 0.25781758, 0.40284959], + [0.76564158, 0.25635173, 0.40413341], + [0.76214598, 0.25492998, 0.40539471], + [0.75861834, 0.25356035, 0.40663694], + [0.75506533, 0.25223402, 0.40785559], + [0.75148963, 0.2509473, 0.40904966], + [0.74788835, 0.24970413, 0.41022028], + [0.74426345, 0.24850191, 0.41136599], + [0.74061927, 0.24733457, 0.41248516], + [0.73695678, 0.24620072, 0.41357737], + [0.73327278, 0.24510469, 0.41464364], + [0.72957096, 0.24404127, 0.4156828], + [0.72585394, 0.24300672, 0.41669383], + [0.7221226, 0.24199971, 0.41767651], + [0.71837612, 0.24102046, 0.41863486], + [0.71463236, 0.24004289, 0.41956983], + [0.7108932, 0.23906316, 0.42048681], + [0.70715842, 0.23808142, 0.42138647], + [0.70342811, 0.2370976, 0.42226844], + [0.69970218, 0.23611179, 0.42313282], + [0.69598055, 0.2351247, 0.42397678], + [0.69226314, 0.23413578, 0.42480327], + [0.68854988, 0.23314511, 0.42561234], + [0.68484064, 0.23215279, 0.42640419], + [0.68113541, 0.23115942, 0.42717615], + [0.67743412, 0.23016472, 0.42792989], + [0.67373662, 0.22916861, 0.42866642], + [0.67004287, 0.22817117, 0.42938576], + [0.66635279, 0.22717328, 0.43008427], + [0.66266621, 0.22617435, 0.43076552], + [0.65898313, 0.22517434, 0.43142956], + [0.65530349, 0.22417381, 0.43207427], + [0.65162696, 0.22317307, 0.4327001], + [0.64795375, 0.22217149, 0.43330852], + [0.64428351, 0.22116972, 0.43389854], + [0.64061624, 0.22016818, 0.43446845], + [0.63695183, 0.21916625, 0.43502123], + [0.63329016, 0.21816454, 0.43555493], + [0.62963102, 0.2171635, 0.43606881], + [0.62597451, 0.21616235, 0.43656529], + [0.62232019, 0.21516239, 0.43704153], + [0.61866821, 0.21416307, 0.43749868], + [0.61501835, 0.21316435, 0.43793808], + [0.61137029, 0.21216761, 0.4383556], + [0.60772426, 0.2111715, 0.43875552], + [0.60407977, 
0.21017746, 0.43913439], + [0.60043678, 0.20918503, 0.43949412], + [0.59679524, 0.20819447, 0.43983393], + [0.59315487, 0.20720639, 0.44015254], + [0.58951566, 0.20622027, 0.44045213], + [0.58587715, 0.20523751, 0.44072926], + [0.5822395, 0.20425693, 0.44098758], + [0.57860222, 0.20328034, 0.44122241], + [0.57496549, 0.20230637, 0.44143805], + [0.57132875, 0.20133689, 0.4416298], + [0.56769215, 0.20037071, 0.44180142], + [0.5640552, 0.19940936, 0.44194923], + [0.56041794, 0.19845221, 0.44207535], + [0.55678004, 0.1975, 0.44217824], + [0.55314129, 0.19655316, 0.44225723], + [0.54950166, 0.19561118, 0.44231412], + [0.54585987, 0.19467771, 0.44234111], + [0.54221157, 0.19375869, 0.44233698], + [0.5385549, 0.19285696, 0.44229959], + [0.5348913, 0.19197036, 0.44222958], + [0.53122177, 0.1910974, 0.44212735], + [0.52754464, 0.19024042, 0.44199159], + [0.52386353, 0.18939409, 0.44182449], + [0.52017476, 0.18856368, 0.44162345], + [0.51648277, 0.18774266, 0.44139128], + [0.51278481, 0.18693492, 0.44112605], + [0.50908361, 0.18613639, 0.4408295], + [0.50537784, 0.18534893, 0.44050064], + [0.50166912, 0.18457008, 0.44014054], + [0.49795686, 0.18380056, 0.43974881], + [0.49424218, 0.18303865, 0.43932623], + [0.49052472, 0.18228477, 0.43887255], + [0.48680565, 0.1815371, 0.43838867], + [0.48308419, 0.18079663, 0.43787408], + [0.47936222, 0.18006056, 0.43733022], + [0.47563799, 0.17933127, 0.43675585], + [0.47191466, 0.17860416, 0.43615337], + [0.46818879, 0.17788392, 0.43552047], + [0.46446454, 0.17716458, 0.43486036], + [0.46073893, 0.17645017, 0.43417097], + [0.45701462, 0.17573691, 0.43345429], + [0.45329097, 0.17502549, 0.43271025], + [0.44956744, 0.17431649, 0.4319386], + [0.44584668, 0.17360625, 0.43114133], + [0.44212538, 0.17289906, 0.43031642], + [0.43840678, 0.17219041, 0.42946642], + [0.43469046, 0.17148074, 0.42859124], + [0.4309749, 0.17077192, 0.42769008], + [0.42726297, 0.17006003, 0.42676519], + [0.42355299, 0.16934709, 0.42581586], + [0.41984535, 0.16863258, 0.42484219], + [0.41614149, 0.16791429, 0.42384614], + [0.41244029, 0.16719372, 0.42282661], + [0.40874177, 0.16647061, 0.42178429], + [0.40504765, 0.16574261, 0.42072062], + [0.401357, 0.16501079, 0.41963528], + [0.397669, 0.16427607, 0.418528], + [0.39398585, 0.16353554, 0.41740053], + [0.39030735, 0.16278924, 0.41625344], + [0.3866314, 0.16203977, 0.41508517], + [0.38295904, 0.16128519, 0.41389849], + [0.37928736, 0.16052483, 0.41270599], + [0.37562649, 0.15974704, 0.41151182], + [0.37197803, 0.15895049, 0.41031532], + [0.36833779, 0.15813871, 0.40911916], + [0.36470944, 0.15730861, 0.40792149], + [0.36109117, 0.15646169, 0.40672362], + [0.35748213, 0.15559861, 0.40552633], + [0.353885, 0.15471714, 0.40432831], + [0.35029682, 0.15381967, 0.4031316], + [0.34671861, 0.1529053, 0.40193587], + [0.34315191, 0.15197275, 0.40074049], + [0.33959331, 0.15102466, 0.3995478], + [0.33604378, 0.15006017, 0.39835754], + [0.33250529, 0.14907766, 0.39716879], + [0.32897621, 0.14807831, 0.39598285], + [0.3254559, 0.14706248, 0.39480044], + [0.32194567, 0.14602909, 0.39362106], + [0.31844477, 0.14497857, 0.39244549], + [0.31494974, 0.14391333, 0.39127626], + [0.31146605, 0.14282918, 0.39011024], + [0.30798857, 0.1417297, 0.38895105], + [0.30451661, 0.14061515, 0.38779953], + [0.30105136, 0.13948445, 0.38665531], + [0.2975886, 0.1383403, 0.38552159], + [0.29408557, 0.13721193, 0.38442775] +] + + +_crest_lut = [ + [0.6468274, 0.80289262, 0.56592265], + [0.64233318, 0.80081141, 0.56639461], + [0.63791969, 0.7987162, 0.56674976], + [0.6335316, 
0.79661833, 0.56706128], + [0.62915226, 0.7945212, 0.56735066], + [0.62477862, 0.79242543, 0.56762143], + [0.62042003, 0.79032918, 0.56786129], + [0.61606327, 0.78823508, 0.56808666], + [0.61171322, 0.78614216, 0.56829092], + [0.60736933, 0.78405055, 0.56847436], + [0.60302658, 0.78196121, 0.56864272], + [0.59868708, 0.77987374, 0.56879289], + [0.59435366, 0.77778758, 0.56892099], + [0.59001953, 0.77570403, 0.56903477], + [0.58568753, 0.77362254, 0.56913028], + [0.58135593, 0.77154342, 0.56920908], + [0.57702623, 0.76946638, 0.56926895], + [0.57269165, 0.76739266, 0.5693172], + [0.56835934, 0.76532092, 0.56934507], + [0.56402533, 0.76325185, 0.56935664], + [0.55968429, 0.76118643, 0.56935732], + [0.55534159, 0.75912361, 0.56934052], + [0.55099572, 0.75706366, 0.56930743], + [0.54664626, 0.75500662, 0.56925799], + [0.54228969, 0.75295306, 0.56919546], + [0.53792417, 0.75090328, 0.56912118], + [0.53355172, 0.74885687, 0.5690324], + [0.52917169, 0.74681387, 0.56892926], + [0.52478243, 0.74477453, 0.56881287], + [0.52038338, 0.74273888, 0.56868323], + [0.5159739, 0.74070697, 0.56854039], + [0.51155269, 0.73867895, 0.56838507], + [0.50711872, 0.73665492, 0.56821764], + [0.50267118, 0.73463494, 0.56803826], + [0.49822926, 0.73261388, 0.56785146], + [0.49381422, 0.73058524, 0.56767484], + [0.48942421, 0.72854938, 0.56751036], + [0.48505993, 0.72650623, 0.56735752], + [0.48072207, 0.72445575, 0.56721583], + [0.4764113, 0.72239788, 0.56708475], + [0.47212827, 0.72033258, 0.56696376], + [0.46787361, 0.71825983, 0.56685231], + [0.46364792, 0.71617961, 0.56674986], + [0.45945271, 0.71409167, 0.56665625], + [0.45528878, 0.71199595, 0.56657103], + [0.45115557, 0.70989276, 0.5664931], + [0.44705356, 0.70778212, 0.56642189], + [0.44298321, 0.70566406, 0.56635683], + [0.43894492, 0.70353863, 0.56629734], + [0.43493911, 0.70140588, 0.56624286], + [0.43096612, 0.69926587, 0.5661928], + [0.42702625, 0.69711868, 0.56614659], + [0.42311977, 0.69496438, 0.56610368], + [0.41924689, 0.69280308, 0.56606355], + [0.41540778, 0.69063486, 0.56602564], + [0.41160259, 0.68845984, 0.56598944], + [0.40783143, 0.68627814, 0.56595436], + [0.40409434, 0.68408988, 0.56591994], + [0.40039134, 0.68189518, 0.56588564], + [0.39672238, 0.6796942, 0.56585103], + [0.39308781, 0.67748696, 0.56581581], + [0.38949137, 0.67527276, 0.56578084], + [0.38592889, 0.67305266, 0.56574422], + [0.38240013, 0.67082685, 0.56570561], + [0.37890483, 0.66859548, 0.56566462], + [0.37544276, 0.66635871, 0.56562081], + [0.37201365, 0.66411673, 0.56557372], + [0.36861709, 0.6618697, 0.5655231], + [0.36525264, 0.65961782, 0.56546873], + [0.36191986, 0.65736125, 0.56541032], + [0.35861935, 0.65509998, 0.56534768], + [0.35535621, 0.65283302, 0.56528211], + [0.35212361, 0.65056188, 0.56521171], + [0.34892097, 0.64828676, 0.56513633], + [0.34574785, 0.64600783, 0.56505539], + [0.34260357, 0.64372528, 0.5649689], + [0.33948744, 0.64143931, 0.56487679], + [0.33639887, 0.6391501, 0.56477869], + [0.33334501, 0.63685626, 0.56467661], + [0.33031952, 0.63455911, 0.564569], + [0.3273199, 0.63225924, 0.56445488], + [0.32434526, 0.62995682, 0.56433457], + [0.32139487, 0.62765201, 0.56420795], + [0.31846807, 0.62534504, 0.56407446], + [0.3155731, 0.62303426, 0.56393695], + [0.31270304, 0.62072111, 0.56379321], + [0.30985436, 0.61840624, 0.56364307], + [0.30702635, 0.61608984, 0.56348606], + [0.30421803, 0.61377205, 0.56332267], + [0.30143611, 0.61145167, 0.56315419], + [0.29867863, 0.60912907, 0.56298054], + [0.29593872, 0.60680554, 0.56280022], + [0.29321538, 
0.60448121, 0.56261376], + [0.2905079, 0.60215628, 0.56242036], + [0.28782827, 0.5998285, 0.56222366], + [0.28516521, 0.59749996, 0.56202093], + [0.28251558, 0.59517119, 0.56181204], + [0.27987847, 0.59284232, 0.56159709], + [0.27726216, 0.59051189, 0.56137785], + [0.27466434, 0.58818027, 0.56115433], + [0.2720767, 0.58584893, 0.56092486], + [0.26949829, 0.58351797, 0.56068983], + [0.26693801, 0.58118582, 0.56045121], + [0.26439366, 0.57885288, 0.56020858], + [0.26185616, 0.57652063, 0.55996077], + [0.25932459, 0.57418919, 0.55970795], + [0.25681303, 0.57185614, 0.55945297], + [0.25431024, 0.56952337, 0.55919385], + [0.25180492, 0.56719255, 0.5589305], + [0.24929311, 0.56486397, 0.5586654], + [0.24678356, 0.56253666, 0.55839491], + [0.24426587, 0.56021153, 0.55812473], + [0.24174022, 0.55788852, 0.55785448], + [0.23921167, 0.55556705, 0.55758211], + [0.23668315, 0.55324675, 0.55730676], + [0.23414742, 0.55092825, 0.55703167], + [0.23160473, 0.54861143, 0.5567573], + [0.22905996, 0.54629572, 0.55648168], + [0.22651648, 0.54398082, 0.5562029], + [0.22396709, 0.54166721, 0.55592542], + [0.22141221, 0.53935481, 0.55564885], + [0.21885269, 0.53704347, 0.55537294], + [0.21629986, 0.53473208, 0.55509319], + [0.21374297, 0.53242154, 0.5548144], + [0.21118255, 0.53011166, 0.55453708], + [0.2086192, 0.52780237, 0.55426067], + [0.20605624, 0.52549322, 0.55398479], + [0.20350004, 0.5231837, 0.55370601], + [0.20094292, 0.52087429, 0.55342884], + [0.19838567, 0.51856489, 0.55315283], + [0.19582911, 0.51625531, 0.55287818], + [0.19327413, 0.51394542, 0.55260469], + [0.19072933, 0.51163448, 0.5523289], + [0.18819045, 0.50932268, 0.55205372], + [0.18565609, 0.50701014, 0.55177937], + [0.18312739, 0.50469666, 0.55150597], + [0.18060561, 0.50238204, 0.55123374], + [0.178092, 0.50006616, 0.55096224], + [0.17558808, 0.49774882, 0.55069118], + [0.17310341, 0.49542924, 0.5504176], + [0.17063111, 0.49310789, 0.55014445], + [0.1681728, 0.49078458, 0.54987159], + [0.1657302, 0.48845913, 0.54959882], + [0.16330517, 0.48613135, 0.54932605], + [0.16089963, 0.48380104, 0.54905306], + [0.15851561, 0.48146803, 0.54877953], + [0.15615526, 0.47913212, 0.54850526], + [0.15382083, 0.47679313, 0.54822991], + [0.15151471, 0.47445087, 0.54795318], + [0.14924112, 0.47210502, 0.54767411], + [0.1470032, 0.46975537, 0.54739226], + [0.14480101, 0.46740187, 0.54710832], + [0.14263736, 0.46504434, 0.54682188], + [0.14051521, 0.46268258, 0.54653253], + [0.13843761, 0.46031639, 0.54623985], + [0.13640774, 0.45794558, 0.5459434], + [0.13442887, 0.45556994, 0.54564272], + [0.1325044, 0.45318928, 0.54533736], + [0.13063777, 0.4508034, 0.54502674], + [0.12883252, 0.44841211, 0.5447104], + [0.12709242, 0.44601517, 0.54438795], + [0.1254209, 0.44361244, 0.54405855], + [0.12382162, 0.44120373, 0.54372156], + [0.12229818, 0.43878887, 0.54337634], + [0.12085453, 0.4363676, 0.54302253], + [0.11949938, 0.43393955, 0.54265715], + [0.11823166, 0.43150478, 0.54228104], + [0.11705496, 0.42906306, 0.54189388], + [0.115972, 0.42661431, 0.54149449], + [0.11498598, 0.42415835, 0.54108222], + [0.11409965, 0.42169502, 0.54065622], + [0.11331533, 0.41922424, 0.5402155], + [0.11263542, 0.41674582, 0.53975931], + [0.1120615, 0.4142597, 0.53928656], + [0.11159738, 0.41176567, 0.53879549], + [0.11125248, 0.40926325, 0.53828203], + [0.11101698, 0.40675289, 0.53774864], + [0.11089152, 0.40423445, 0.53719455], + [0.11085121, 0.4017095, 0.53662425], + [0.11087217, 0.39917938, 0.53604354], + [0.11095515, 0.39664394, 0.53545166], + [0.11110676, 0.39410282, 
0.53484509], + [0.11131735, 0.39155635, 0.53422678], + [0.11158595, 0.38900446, 0.53359634], + [0.11191139, 0.38644711, 0.5329534], + [0.11229224, 0.38388426, 0.53229748], + [0.11273683, 0.38131546, 0.53162393], + [0.11323438, 0.37874109, 0.53093619], + [0.11378271, 0.37616112, 0.53023413], + [0.11437992, 0.37357557, 0.52951727], + [0.11502681, 0.37098429, 0.52878396], + [0.11572661, 0.36838709, 0.52803124], + [0.11646936, 0.36578429, 0.52726234], + [0.11725299, 0.3631759, 0.52647685], + [0.1180755, 0.36056193, 0.52567436], + [0.1189438, 0.35794203, 0.5248497], + [0.11984752, 0.35531657, 0.52400649], + [0.1207833, 0.35268564, 0.52314492], + [0.12174895, 0.35004927, 0.52226461], + [0.12274959, 0.34740723, 0.52136104], + [0.12377809, 0.34475975, 0.52043639], + [0.12482961, 0.34210702, 0.51949179], + [0.125902, 0.33944908, 0.51852688], + [0.12699998, 0.33678574, 0.51753708], + [0.12811691, 0.33411727, 0.51652464], + [0.12924811, 0.33144384, 0.51549084], + [0.13039157, 0.32876552, 0.51443538], + [0.13155228, 0.32608217, 0.51335321], + [0.13272282, 0.32339407, 0.51224759], + [0.13389954, 0.32070138, 0.51111946], + [0.13508064, 0.31800419, 0.50996862], + [0.13627149, 0.31530238, 0.50878942], + [0.13746376, 0.31259627, 0.50758645], + [0.13865499, 0.30988598, 0.50636017], + [0.13984364, 0.30717161, 0.50511042], + [0.14103515, 0.30445309, 0.50383119], + [0.14222093, 0.30173071, 0.50252813], + [0.14339946, 0.2990046, 0.50120127], + [0.14456941, 0.29627483, 0.49985054], + [0.14573579, 0.29354139, 0.49847009], + [0.14689091, 0.29080452, 0.49706566], + [0.1480336, 0.28806432, 0.49563732], + [0.1491628, 0.28532086, 0.49418508], + [0.15028228, 0.28257418, 0.49270402], + [0.15138673, 0.27982444, 0.49119848], + [0.15247457, 0.27707172, 0.48966925], + [0.15354487, 0.2743161, 0.48811641], + [0.15459955, 0.27155765, 0.4865371], + [0.15563716, 0.26879642, 0.4849321], + [0.1566572, 0.26603191, 0.48330429], + [0.15765823, 0.26326032, 0.48167456], + [0.15862147, 0.26048295, 0.48005785], + [0.15954301, 0.25770084, 0.47845341], + [0.16043267, 0.25491144, 0.4768626], + [0.16129262, 0.25211406, 0.4752857], + [0.1621119, 0.24931169, 0.47372076], + [0.16290577, 0.24649998, 0.47217025], + [0.16366819, 0.24368054, 0.47063302], + [0.1644021, 0.24085237, 0.46910949], + [0.16510882, 0.2380149, 0.46759982], + [0.16579015, 0.23516739, 0.46610429], + [0.1664433, 0.2323105, 0.46462219], + [0.16707586, 0.22944155, 0.46315508], + [0.16768475, 0.22656122, 0.46170223], + [0.16826815, 0.22366984, 0.46026308], + [0.16883174, 0.22076514, 0.45883891], + [0.16937589, 0.21784655, 0.45742976], + [0.16990129, 0.21491339, 0.45603578], + [0.1704074, 0.21196535, 0.45465677], + [0.17089473, 0.20900176, 0.4532928], + [0.17136819, 0.20602012, 0.45194524], + [0.17182683, 0.20302012, 0.45061386], + [0.17227059, 0.20000106, 0.44929865], + [0.17270583, 0.19695949, 0.44800165], + [0.17313804, 0.19389201, 0.44672488], + [0.17363177, 0.19076859, 0.44549087] +] -for _lut, _name in zip(_luts, _names): + +_lut_dict = dict( + rocket=_rocket_lut, + mako=_mako_lut, + icefire=_icefire_lut, + vlag=_vlag_lut, + flare=_flare_lut, + crest=_crest_lut, + +) + +for _name, _lut in _lut_dict.items(): _cmap = colors.ListedColormap(_lut, _name) locals()[_name] = _cmap @@ -1054,3 +1581,5 @@ mpl_cm.register_cmap(_name, _cmap) mpl_cm.register_cmap(_name + "_r", _cmap_r) + +del colors, mpl_cm \ No newline at end of file diff --git a/seaborn/colors/__init__.py b/seaborn/colors/__init__.py index 191705c6f5..3d0bf1d56b 100644 --- a/seaborn/colors/__init__.py +++ 
b/seaborn/colors/__init__.py @@ -1,2 +1,2 @@ -from .xkcd_rgb import xkcd_rgb -from .crayons import crayons +from .xkcd_rgb import xkcd_rgb # noqa: F401 +from .crayons import crayons # noqa: F401 diff --git a/seaborn/conftest.py b/seaborn/conftest.py index 8a0425d678..335d673b7d 100644 --- a/seaborn/conftest.py +++ b/seaborn/conftest.py @@ -1,8 +1,47 @@ import numpy as np +import pandas as pd +import datetime +import matplotlib as mpl import matplotlib.pyplot as plt + import pytest +def has_verdana(): + """Helper to verify if Verdana font is present""" + # This import is relatively lengthy, so to prevent its import for + # testing other tests in this module not requiring this knowledge, + # import font_manager here + import matplotlib.font_manager as mplfm + try: + verdana_font = mplfm.findfont('Verdana', fallback_to_default=False) + except: # noqa + # if https://github.com/matplotlib/matplotlib/pull/3435 + # gets accepted + return False + # otherwise check if not matching the logic for a 'default' one + try: + unlikely_font = mplfm.findfont("very_unlikely_to_exist1234", + fallback_to_default=False) + except: # noqa + # if matched verdana but not unlikely, Verdana must exist + return True + # otherwise -- if they match, must be the same default + return verdana_font != unlikely_font + + +@pytest.fixture(scope="session", autouse=True) +def remove_pandas_unit_conversion(): + # Prior to pandas 1.0, it registered its own datetime converters, + # but they are less powerful than what matplotlib added in 2.2, + # and we rely on that functionality in seaborn. + # https://github.com/matplotlib/matplotlib/pull/9779 + # https://github.com/pandas-dev/pandas/issues/27036 + mpl.units.registry[np.datetime64] = mpl.dates.DateConverter() + mpl.units.registry[datetime.date] = mpl.dates.DateConverter() + mpl.units.registry[datetime.datetime] = mpl.dates.DateConverter() + + @pytest.fixture(autouse=True) def close_figs(): yield @@ -11,4 +50,167 @@ def close_figs(): @pytest.fixture(autouse=True) def random_seed(): - np.random.seed(47) + seed = sum(map(ord, "seaborn random global")) + np.random.seed(seed) + + +@pytest.fixture() +def rng(): + seed = sum(map(ord, "seaborn random object")) + return np.random.RandomState(seed) + + +@pytest.fixture +def wide_df(rng): + + columns = list("abc") + index = pd.Int64Index(np.arange(10, 50, 2), name="wide_index") + values = rng.normal(size=(len(index), len(columns))) + return pd.DataFrame(values, index=index, columns=columns) + + +@pytest.fixture +def wide_array(wide_df): + + return wide_df.to_numpy() + + +@pytest.fixture +def flat_series(rng): + + index = pd.Int64Index(np.arange(10, 30), name="t") + return pd.Series(rng.normal(size=20), index, name="s") + + +@pytest.fixture +def flat_array(flat_series): + + return flat_series.to_numpy() + + +@pytest.fixture +def flat_list(flat_series): + + return flat_series.to_list() + + +@pytest.fixture(params=["series", "array", "list"]) +def flat_data(rng, request): + + index = pd.Int64Index(np.arange(10, 30), name="t") + series = pd.Series(rng.normal(size=20), index, name="s") + if request.param == "series": + data = series + elif request.param == "array": + data = series.to_numpy() + elif request.param == "list": + data = series.to_list() + return data + + +@pytest.fixture +def wide_list_of_series(rng): + + return [pd.Series(rng.normal(size=20), np.arange(20), name="a"), + pd.Series(rng.normal(size=10), np.arange(5, 15), name="b")] + + +@pytest.fixture +def wide_list_of_arrays(wide_list_of_series): + + return [s.to_numpy() for s in 
wide_list_of_series] + + +@pytest.fixture +def wide_list_of_lists(wide_list_of_series): + + return [s.to_list() for s in wide_list_of_series] + + +@pytest.fixture +def wide_dict_of_series(wide_list_of_series): + + return {s.name: s for s in wide_list_of_series} + + +@pytest.fixture +def wide_dict_of_arrays(wide_list_of_series): + + return {s.name: s.to_numpy() for s in wide_list_of_series} + + +@pytest.fixture +def wide_dict_of_lists(wide_list_of_series): + + return {s.name: s.to_list() for s in wide_list_of_series} + + +@pytest.fixture +def long_df(rng): + + n = 100 + df = pd.DataFrame(dict( + x=rng.uniform(0, 20, n).round().astype("int"), + y=rng.normal(size=n), + z=rng.lognormal(size=n), + a=rng.choice(list("abc"), n), + b=rng.choice(list("mnop"), n), + c=rng.choice([0, 1], n, [.3, .7]), + d=rng.choice(np.arange("2004-07-30", "2007-07-30", dtype="datetime64[Y]"), n), + t=rng.choice(np.arange("2004-07-30", "2004-07-31", dtype="datetime64[m]"), n), + s=rng.choice([2, 4, 8], n), + f=rng.choice([0.2, 0.3], n), + )) + + a_cat = df["a"].astype("category") + new_categories = np.roll(a_cat.cat.categories, 1) + df["a_cat"] = a_cat.cat.reorder_categories(new_categories) + + df["s_cat"] = df["s"].astype("category") + df["s_str"] = df["s"].astype(str) + + return df + + +@pytest.fixture +def long_dict(long_df): + + return long_df.to_dict() + + +@pytest.fixture +def repeated_df(rng): + + n = 100 + return pd.DataFrame(dict( + x=np.tile(np.arange(n // 2), 2), + y=rng.normal(size=n), + a=rng.choice(list("abc"), n), + u=np.repeat(np.arange(2), n // 2), + )) + + +@pytest.fixture +def missing_df(rng, long_df): + + df = long_df.copy() + for col in df: + idx = rng.permutation(df.index)[:10] + df.loc[idx, col] = np.nan + return df + + +@pytest.fixture +def object_df(rng, long_df): + + df = long_df.copy() + # objectify numeric columns + for col in ["c", "s", "f"]: + df[col] = df[col].astype(object) + return df + + +@pytest.fixture +def null_series(flat_series): + + return pd.Series(index=flat_series.index, dtype='float64') diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 9a52cc9c7f..54d4fa46fc 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -1,28 +1,2383 @@ """Plotting functions for visualizing distributions.""" -from __future__ import division +from numbers import Number +from functools import partial +import math +import warnings + import numpy as np -from scipy import stats import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.transforms as tx +from matplotlib.colors import to_rgba from matplotlib.collections import LineCollection -import warnings -from distutils.version import LooseVersion -from six import string_types +from ._core import ( + VectorPlotter, +) +from ._statistics import ( + KDE, + Histogram, + ECDF, +) +from .axisgrid import ( + FacetGrid, + _facet_docs, +) +from .utils import ( + remove_na, + _kde_support, + _normalize_kwargs, + _check_argument, + _assign_default_kwargs, + _default_color, +) +from .palettes import color_palette +from .external import husl +from .external.kde import gaussian_kde +from ._decorators import _deprecate_positional_args +from ._docstrings import ( + DocstringComponents, + _core_docs, +) + + +__all__ = ["displot", "histplot", "kdeplot", "ecdfplot", "rugplot", "distplot"] + +# ==================================================================================== # +# Module documentation +# ==================================================================================== # + 
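+# These snippets are interpolated into the public docstrings defined later in +# this module (via `DocstringComponents` and `str.format`), the same pattern +# the categorical module uses with `_categorical_docs`.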
+_dist_params = dict( + + multiple=""" +multiple : {{"layer", "stack", "fill"}} + Method for drawing multiple elements when semantic mapping creates subsets. + Only relevant with univariate data. + """, + log_scale=""" +log_scale : bool or number, or pair of bools or numbers + Set axis scale(s) to log. A single value sets the data axis for univariate + distributions and both axes for bivariate distributions. A pair of values + sets each axis independently. Numeric values are interpreted as the desired + base (default 10). If `False`, defer to the existing Axes scale. + """, + legend=""" +legend : bool + If False, suppress the legend for semantic variables. + """, + cbar=""" +cbar : bool + If True, add a colorbar to annotate the color mapping in a bivariate plot. + Note: Does not currently support plots with a ``hue`` variable well. + """, + cbar_ax=""" +cbar_ax : :class:`matplotlib.axes.Axes` + Pre-existing axes for the colorbar. + """, + cbar_kws=""" +cbar_kws : dict + Additional parameters passed to :meth:`matplotlib.figure.Figure.colorbar`. + """, +) + +_param_docs = DocstringComponents.from_nested_components( + core=_core_docs["params"], + facets=DocstringComponents(_facet_docs), + dist=DocstringComponents(_dist_params), + kde=DocstringComponents.from_function_params(KDE.__init__), + hist=DocstringComponents.from_function_params(Histogram.__init__), + ecdf=DocstringComponents.from_function_params(ECDF.__init__), +) + + +# ==================================================================================== # +# Internal API +# ==================================================================================== # + + +class _DistributionPlotter(VectorPlotter): + + semantics = "x", "y", "hue", "weights" + + wide_structure = {"x": "@values", "hue": "@columns"} + flat_structure = {"x": "@values"} + + def __init__( + self, + data=None, + variables={}, + ): + + super().__init__(data=data, variables=variables) + + @property + def univariate(self): + """Return True if only x or y are used.""" + # TODO this could go down to core, but putting it here now. + # We'd want to be conceptually clear that univariate only applies + # to x/y and not to other semantics, which can exist. + # We haven't settled on a good conceptual name for x/y. + return bool({"x", "y"} - set(self.variables)) + + @property + def data_variable(self): + """Return the variable with data for univariate plots.""" + # TODO This could also be in core, but it should have a better name. 
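+        # e.g. with only `x` assigned this returns "x"; with only `y`, it returns "y"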
+ if not self.univariate: + raise AttributeError("This is not a univariate plot") + return {"x", "y"}.intersection(self.variables).pop() + + @property + def has_xy_data(self): + """Return True if at least one of x or y is defined.""" + # TODO see above points about where this should go + return bool({"x", "y"} & set(self.variables)) + + def _add_legend( + self, + ax_obj, artist, fill, element, multiple, alpha, artist_kws, legend_kws, + ): + """Add artists that reflect semantic mappings and put them in a legend.""" + # TODO note that this doesn't handle numeric mappings like the relational plots + handles = [] + labels = [] + for level in self._hue_map.levels: + color = self._hue_map(level) + + kws = self._artist_kws( + artist_kws, fill, element, multiple, color, alpha + ) + + # color gets added to the kws to work around an issue with barplot's color + # cycle integration but it causes problems in this context where we are + # setting artist properties directly, so pop it off here + if "facecolor" in kws: + kws.pop("color", None) + + handles.append(artist(**kws)) + labels.append(level) + + if isinstance(ax_obj, mpl.axes.Axes): + ax_obj.legend(handles, labels, title=self.variables["hue"], **legend_kws) + else: # i.e. a FacetGrid. TODO make this better + legend_data = dict(zip(labels, handles)) + ax_obj.add_legend( + legend_data, + title=self.variables["hue"], + label_order=self.var_levels["hue"], + **legend_kws + ) + + def _artist_kws(self, kws, fill, element, multiple, color, alpha): + """Handle differences between artists in filled/unfilled plots.""" + kws = kws.copy() + if fill: + kws = _normalize_kwargs(kws, mpl.collections.PolyCollection) + kws.setdefault("facecolor", to_rgba(color, alpha)) + + if element == "bars": + # Make bar() interface with property cycle correctly + # https://github.com/matplotlib/matplotlib/issues/19385 + kws["color"] = "none" + + if multiple in ["stack", "fill"] or element == "bars": + kws.setdefault("edgecolor", mpl.rcParams["patch.edgecolor"]) + else: + kws.setdefault("edgecolor", to_rgba(color, 1)) + elif element == "bars": + kws["facecolor"] = "none" + kws["edgecolor"] = to_rgba(color, alpha) + else: + kws["color"] = to_rgba(color, alpha) + return kws + + def _quantile_to_level(self, data, quantile): + """Return data levels corresponding to quantile cuts of mass.""" + isoprop = np.asarray(quantile) + values = np.ravel(data) + sorted_values = np.sort(values)[::-1] + normalized_values = np.cumsum(sorted_values) / values.sum() + idx = np.searchsorted(normalized_values, 1 - isoprop) + levels = np.take(sorted_values, idx, mode="clip") + return levels + + def _cmap_from_color(self, color): + """Return a sequential colormap given a color seed.""" + # Like so much else here, this is broadly useful, but keeping it + # in this class to signify that I haven't thought overly hard about it...
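+        # The ramp below holds the seed color's hue fixed, modulates saturation + # with a cosine, and sweeps lightness from 35 to 80; the final reversal puts + # the light end at the low end of the colormap.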
+ r, g, b, _ = to_rgba(color) + h, s, _ = husl.rgb_to_husl(r, g, b) + xx = np.linspace(-1, 1, int(1.15 * 256))[:256] + ramp = np.zeros((256, 3)) + ramp[:, 0] = h + ramp[:, 1] = s * np.cos(xx) + ramp[:, 2] = np.linspace(35, 80, 256) + colors = np.clip([husl.husl_to_rgb(*hsl) for hsl in ramp], 0, 1) + return mpl.colors.ListedColormap(colors[::-1]) + + def _default_discrete(self): + """Find default values for discrete hist estimation based on variable type.""" + if self.univariate: + discrete = self.var_types[self.data_variable] == "categorical" + else: + discrete_x = self.var_types["x"] == "categorical" + discrete_y = self.var_types["y"] == "categorical" + discrete = discrete_x, discrete_y + return discrete + + def _resolve_multiple(self, curves, multiple): + """Modify the density data structure to handle multiple densities.""" + + # Default baselines have all densities starting at 0 + baselines = {k: np.zeros_like(v) for k, v in curves.items()} + + # TODO we should have some central clearinghouse for checking if any + # "grouping" (terminology?) semantics have been assigned + if "hue" not in self.variables: + return curves, baselines + + if multiple in ("stack", "fill"): + + # Setting stack or fill means that the curves share a + # support grid / set of bin edges, so we can make a dataframe + # Reverse the column order to plot from top to bottom + curves = pd.DataFrame(curves).iloc[:, ::-1] + + # Find column groups that are nested within col/row variables + column_groups = {} + for i, keyd in enumerate(map(dict, curves.columns.tolist())): + facet_key = keyd.get("col", None), keyd.get("row", None) + column_groups.setdefault(facet_key, []) + column_groups[facet_key].append(i) + + baselines = curves.copy() + for cols in column_groups.values(): + + norm_constant = curves.iloc[:, cols].sum(axis="columns") + + # Take the cumulative sum to stack + curves.iloc[:, cols] = curves.iloc[:, cols].cumsum(axis="columns") + + # Normalize by row sum to fill + if multiple == "fill": + curves.iloc[:, cols] = (curves + .iloc[:, cols] + .div(norm_constant, axis="index")) + + # Define where each segment starts + baselines.iloc[:, cols] = (curves + .iloc[:, cols] + .shift(1, axis=1) + .fillna(0)) + + if multiple == "dodge": + + # Account for the unique semantic (non-faceting) levels + # This will require rethinking if we add other semantics!
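+            # For example, with n = 3 hue levels, a bin of width .9 at edge 2.0 + # becomes three side-by-side bars of width .3 whose edges fall at 2.0, + # 2.3, and 2.6 (each level is offset by its index times the new width).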
+ hue_levels = self.var_levels["hue"] + n = len(hue_levels) + for key in curves: + level = dict(key)["hue"] + hist = curves[key].reset_index(name="heights") + hist["widths"] /= n + hist["edges"] += hue_levels.index(level) * hist["widths"] + + curves[key] = hist.set_index(["edges", "widths"])["heights"] + + return curves, baselines + + # -------------------------------------------------------------------------------- # + # Computation + # -------------------------------------------------------------------------------- # + + def _compute_univariate_density( + self, + data_variable, + common_norm, + common_grid, + estimate_kws, + log_scale, + ): + + # Initialize the estimator object + estimator = KDE(**estimate_kws) + + all_data = self.plot_data.dropna() + + if set(self.variables) - {"x", "y"}: + if common_grid: + all_observations = self.comp_data.dropna() + estimator.define_support(all_observations[data_variable]) + else: + common_norm = False + + densities = {} + + for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True): + + # Extract the data points from this sub set and remove nulls + observations = sub_data[data_variable] + + observation_variance = observations.var() + if math.isclose(observation_variance, 0) or np.isnan(observation_variance): + msg = "Dataset has 0 variance; skipping density estimate." + warnings.warn(msg, UserWarning) + continue + + # Extract the weights for this subset of observations + if "weights" in self.variables: + weights = sub_data["weights"] + else: + weights = None + + # Estimate the density of observations at this level + density, support = estimator(observations, weights=weights) + + if log_scale: + support = np.power(10, support) + + # Apply a scaling factor so that the integral over all subsets is 1 + if common_norm: + density *= len(sub_data) / len(all_data) + + # Store the density for this level + key = tuple(sub_vars.items()) + densities[key] = pd.Series(density, index=support) + + return densities + + # -------------------------------------------------------------------------------- # + # Plotting + # -------------------------------------------------------------------------------- # + + def plot_univariate_histogram( + self, + multiple, + element, + fill, + common_norm, + common_bins, + shrink, + kde, + kde_kws, + color, + legend, + line_kws, + estimate_kws, + **plot_kws, + ): + + # -- Default keyword dicts + kde_kws = {} if kde_kws is None else kde_kws.copy() + line_kws = {} if line_kws is None else line_kws.copy() + estimate_kws = {} if estimate_kws is None else estimate_kws.copy() + + # -- Input checking + _check_argument("multiple", ["layer", "stack", "fill", "dodge"], multiple) + _check_argument("element", ["bars", "step", "poly"], element) + + if estimate_kws["discrete"] and element != "bars": + raise ValueError("`element` must be 'bars' when `discrete` is True") + + auto_bins_with_weights = ( + "weights" in self.variables + and estimate_kws["bins"] == "auto" + and estimate_kws["binwidth"] is None + and not estimate_kws["discrete"] + ) + if auto_bins_with_weights: + msg = ( + "`bins` cannot be 'auto' when using weights. " + "Setting `bins=10`, but you will likely want to adjust." 
+ ) + warnings.warn(msg, UserWarning) + estimate_kws["bins"] = 10 + + # Simplify downstream code if we are not normalizing + if estimate_kws["stat"] == "count": + common_norm = False + + # Now initialize the Histogram estimator + estimator = Histogram(**estimate_kws) + histograms = {} + + # Do pre-compute housekeeping related to multiple groups + # TODO best way to account for facet/semantic? + if set(self.variables) - {"x", "y"}: + + all_data = self.comp_data.dropna() + + if common_bins: + all_observations = all_data[self.data_variable] + estimator.define_bin_edges( + all_observations, + weights=all_data.get("weights", None), + ) + + else: + common_norm = False + + # Estimate the smoothed kernel densities, for use later + if kde: + # TODO alternatively, clip at min/max bins? + kde_kws.setdefault("cut", 0) + kde_kws["cumulative"] = estimate_kws["cumulative"] + log_scale = self._log_scaled(self.data_variable) + densities = self._compute_univariate_density( + self.data_variable, + common_norm, + common_bins, + kde_kws, + log_scale, + ) + + # First pass through the data to compute the histograms + for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True): + + # Prepare the relevant data + key = tuple(sub_vars.items()) + observations = sub_data[self.data_variable] + + if "weights" in self.variables: + weights = sub_data["weights"] + else: + weights = None + + # Do the histogram computation + heights, edges = estimator(observations, weights=weights) + + # Rescale the smoothed curve to match the histogram + if kde and key in densities: + density = densities[key] + if estimator.cumulative: + hist_norm = heights.max() + else: + hist_norm = (heights * np.diff(edges)).sum() + densities[key] *= hist_norm + + # Convert edges back to original units for plotting + if self._log_scaled(self.data_variable): + edges = np.power(10, edges) + + # Pack the histogram data and metadata together + orig_widths = np.diff(edges) + widths = shrink * orig_widths + edges = edges[:-1] + (1 - shrink) / 2 * orig_widths + index = pd.MultiIndex.from_arrays([ + pd.Index(edges, name="edges"), + pd.Index(widths, name="widths"), + ]) + hist = pd.Series(heights, index=index, name="heights") + + # Apply scaling to normalize across groups + if common_norm: + hist *= len(sub_data) / len(all_data) + + # Store the finalized histogram data for future plotting + histograms[key] = hist + + # Modify the histogram and density data to resolve multiple groups + histograms, baselines = self._resolve_multiple(histograms, multiple) + if kde: + densities, _ = self._resolve_multiple( + densities, None if multiple == "dodge" else multiple + ) + + # Set autoscaling-related meta + sticky_stat = (0, 1) if multiple == "fill" else (0, np.inf) + if multiple == "fill": + # Filled plots should not have any margins + bin_vals = histograms.index.to_frame() + edges = bin_vals["edges"] + widths = bin_vals["widths"] + sticky_data = ( + edges.min(), + edges.max() + widths.loc[edges.idxmax()] + ) + else: + sticky_data = [] + + # --- Handle default visual attributes + + # Note: default linewidth is determined after plotting + + # Default alpha should depend on other parameters + if fill: + # Note: will need to account for other grouping semantics if added + if "hue" in self.variables and multiple == "layer": + default_alpha = .5 if element == "bars" else .25 + elif kde: + default_alpha = .5 + else: + default_alpha = .75 + else: + default_alpha = 1 + alpha = plot_kws.pop("alpha", default_alpha) # TODO make parameter? 
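+
+        # (The drawing loop below iterates hue levels in reverse so that
+        # stacked elements are drawn from top to bottom, matching the
+        # reversed column order set up in _resolve_multiple.)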
+ + hist_artists = [] + + # Go back through the dataset and draw the plots + for sub_vars, _ in self.iter_data("hue", reverse=True): + + key = tuple(sub_vars.items()) + hist = histograms[key].rename("heights").reset_index() + bottom = np.asarray(baselines[key]) + + ax = self._get_axes(sub_vars) + + # Define the matplotlib attributes that depend on semantic mapping + if "hue" in self.variables: + sub_color = self._hue_map(sub_vars["hue"]) + else: + sub_color = color + + artist_kws = self._artist_kws( + plot_kws, fill, element, multiple, sub_color, alpha + ) + + if element == "bars": + + # Use matplotlib bar plotting + + plot_func = ax.bar if self.data_variable == "x" else ax.barh + artists = plot_func( + hist["edges"], + hist["heights"] - bottom, + hist["widths"], + bottom, + align="edge", + **artist_kws, + ) + + for bar in artists: + if self.data_variable == "x": + bar.sticky_edges.x[:] = sticky_data + bar.sticky_edges.y[:] = sticky_stat + else: + bar.sticky_edges.x[:] = sticky_stat + bar.sticky_edges.y[:] = sticky_data + + hist_artists.extend(artists) + + else: + + # Use either fill_between or plot to draw hull of histogram + if element == "step": + + final = hist.iloc[-1] + x = np.append(hist["edges"], final["edges"] + final["widths"]) + y = np.append(hist["heights"], final["heights"]) + b = np.append(bottom, bottom[-1]) + + if self.data_variable == "x": + step = "post" + drawstyle = "steps-post" + else: + step = "post" # fillbetweenx handles mapping internally + drawstyle = "steps-pre" + + elif element == "poly": + + x = hist["edges"] + hist["widths"] / 2 + y = hist["heights"] + b = bottom + + step = None + drawstyle = None + + if self.data_variable == "x": + if fill: + artist = ax.fill_between(x, b, y, step=step, **artist_kws) + else: + artist, = ax.plot(x, y, drawstyle=drawstyle, **artist_kws) + artist.sticky_edges.x[:] = sticky_data + artist.sticky_edges.y[:] = sticky_stat + else: + if fill: + artist = ax.fill_betweenx(x, b, y, step=step, **artist_kws) + else: + artist, = ax.plot(y, x, drawstyle=drawstyle, **artist_kws) + artist.sticky_edges.x[:] = sticky_stat + artist.sticky_edges.y[:] = sticky_data + + hist_artists.append(artist) + + if kde: + + # Add in the density curves + + try: + density = densities[key] + except KeyError: + continue + support = density.index + + if "x" in self.variables: + line_args = support, density + sticky_x, sticky_y = None, (0, np.inf) + else: + line_args = density, support + sticky_x, sticky_y = (0, np.inf), None + + line_kws["color"] = to_rgba(sub_color, 1) + line, = ax.plot( + *line_args, **line_kws, + ) + + if sticky_x is not None: + line.sticky_edges.x[:] = sticky_x + if sticky_y is not None: + line.sticky_edges.y[:] = sticky_y + + if element == "bars" and "linewidth" not in plot_kws: + + # Now we handle linewidth, which depends on the scaling of the plot + + # We will base everything on the minimum bin width + hist_metadata = pd.concat([ + # Use .items for generality over dict or df + h.index.to_frame() for _, h in histograms.items() + ]).reset_index(drop=True) + thin_bar_idx = hist_metadata["widths"].idxmin() + binwidth = hist_metadata.loc[thin_bar_idx, "widths"] + left_edge = hist_metadata.loc[thin_bar_idx, "edges"] + + # Set initial value + default_linewidth = math.inf + + # Loop through subsets based only on facet variables + for sub_vars, _ in self.iter_data(): + + ax = self._get_axes(sub_vars) + + # Needed in some cases to get valid transforms. + # Innocuous in other cases? 
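+                # (After autoscaling, transData is up to date, so the bin
+                # width can be measured in display units: both bin edges are
+                # transformed to pixels below and converted to points via
+                # 72 / dpi.)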
+                ax.autoscale_view()
+
+                # Convert binwidth from data coordinates to pixels
+                pts_x, pts_y = 72 / ax.figure.dpi * abs(
+                    ax.transData.transform([left_edge + binwidth] * 2)
+                    - ax.transData.transform([left_edge] * 2)
+                )
+                if self.data_variable == "x":
+                    binwidth_points = pts_x
+                else:
+                    binwidth_points = pts_y
+
+                # The relative size of the lines depends on the appearance
+                # This is a provisional value and may need more tweaking
+                default_linewidth = min(.1 * binwidth_points, default_linewidth)
+
+            # Set the attributes
+            for bar in hist_artists:
+
+                # Don't let the lines get too thick
+                max_linewidth = bar.get_linewidth()
+                if not fill:
+                    max_linewidth *= 1.5
+
+                linewidth = min(default_linewidth, max_linewidth)
+
+                # If not filling, don't let lines disappear
+                if not fill:
+                    min_linewidth = .5
+                    linewidth = max(linewidth, min_linewidth)
+
+                bar.set_linewidth(linewidth)
+
+        # --- Finalize the plot ----
+
+        # Axis labels
+        ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
+        default_x = default_y = ""
+        if self.data_variable == "x":
+            default_y = estimator.stat.capitalize()
+        if self.data_variable == "y":
+            default_x = estimator.stat.capitalize()
+        self._add_axis_labels(ax, default_x, default_y)
+
+        # Legend for semantic variables
+        if "hue" in self.variables and legend:
+
+            if fill or element == "bars":
+                artist = partial(mpl.patches.Patch)
+            else:
+                artist = partial(mpl.lines.Line2D, [], [])
+
+            ax_obj = self.ax if self.ax is not None else self.facets
+            self._add_legend(
+                ax_obj, artist, fill, element, multiple, alpha, plot_kws, {},
+            )
+
+    def plot_bivariate_histogram(
+        self,
+        common_bins, common_norm,
+        thresh, pthresh, pmax,
+        color, legend,
+        cbar, cbar_ax, cbar_kws,
+        estimate_kws,
+        **plot_kws,
+    ):
+
+        # Default keyword dicts
+        cbar_kws = {} if cbar_kws is None else cbar_kws.copy()
+
+        # Now initialize the Histogram estimator
+        estimator = Histogram(**estimate_kws)
+
+        # Do pre-compute housekeeping related to multiple groups
+        if set(self.variables) - {"x", "y"}:
+            all_data = self.comp_data.dropna()
+            if common_bins:
+                estimator.define_bin_edges(
+                    all_data["x"],
+                    all_data["y"],
+                    all_data.get("weights", None),
+                )
+        else:
+            common_norm = False
+
+        # -- Determine colormap threshold and norm based on the full data
+
+        full_heights = []
+        for _, sub_data in self.iter_data(from_comp_data=True):
+            sub_heights, _ = estimator(
+                sub_data["x"], sub_data["y"], sub_data.get("weights", None)
+            )
+            full_heights.append(sub_heights)
+
+        common_color_norm = not set(self.variables) - {"x", "y"} or common_norm
+
+        if pthresh is not None and common_color_norm:
+            thresh = self._quantile_to_level(full_heights, pthresh)
+
+        plot_kws.setdefault("vmin", 0)
+        if common_color_norm:
+            if pmax is not None:
+                vmax = self._quantile_to_level(full_heights, pmax)
+            else:
+                vmax = plot_kws.pop("vmax", np.max(full_heights))
+        else:
+            vmax = None
+
+        # Get a default color
+        # (We won't follow the color cycle here, as multiple plots are unlikely)
+        if color is None:
+            color = "C0"
+
+        # --- Loop over data (subsets) and draw the histograms
+        for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
+
+            if sub_data.empty:
+                continue
+
+            # Do the histogram computation
+            heights, (x_edges, y_edges) = estimator(
+                sub_data["x"],
+                sub_data["y"],
+                weights=sub_data.get("weights", None),
+            )
+
+            # Check for log scaling on the data axis
+            if self._log_scaled("x"):
+                x_edges = np.power(10, x_edges)
+            if self._log_scaled("y"):
+                y_edges = np.power(10, y_edges)
+
+            # Apply scaling to normalize
across groups + if estimator.stat != "count" and common_norm: + heights *= len(sub_data) / len(all_data) + + # Define the specific kwargs for this artist + artist_kws = plot_kws.copy() + if "hue" in self.variables: + color = self._hue_map(sub_vars["hue"]) + cmap = self._cmap_from_color(color) + artist_kws["cmap"] = cmap + else: + cmap = artist_kws.pop("cmap", None) + if isinstance(cmap, str): + cmap = color_palette(cmap, as_cmap=True) + elif cmap is None: + cmap = self._cmap_from_color(color) + artist_kws["cmap"] = cmap + + # Set the upper norm on the colormap + if not common_color_norm and pmax is not None: + vmax = self._quantile_to_level(heights, pmax) + if vmax is not None: + artist_kws["vmax"] = vmax + + # Make cells at or below the threshold transparent + if not common_color_norm and pthresh: + thresh = self._quantile_to_level(heights, pthresh) + if thresh is not None: + heights = np.ma.masked_less_equal(heights, thresh) + + # Get the axes for this plot + ax = self._get_axes(sub_vars) + + # pcolormesh is going to turn the grid off, but we want to keep it + # I'm not sure if there's a better way to get the grid state + x_grid = any([l.get_visible() for l in ax.xaxis.get_gridlines()]) + y_grid = any([l.get_visible() for l in ax.yaxis.get_gridlines()]) + + mesh = ax.pcolormesh( + x_edges, + y_edges, + heights.T, + **artist_kws, + ) + + # pcolormesh sets sticky edges, but we only want them if not thresholding + if thresh is not None: + mesh.sticky_edges.x[:] = [] + mesh.sticky_edges.y[:] = [] + + # Add an optional colorbar + # Note, we want to improve this. When hue is used, it will stack + # multiple colorbars with redundant ticks in an ugly way. + # But it's going to take some work to have multiple colorbars that + # share ticks nicely. + if cbar: + ax.figure.colorbar(mesh, cbar_ax, ax, **cbar_kws) + + # Reset the grid state + if x_grid: + ax.grid(True, axis="x") + if y_grid: + ax.grid(True, axis="y") + + # --- Finalize the plot + + ax = self.ax if self.ax is not None else self.facets.axes.flat[0] + self._add_axis_labels(ax) + + if "hue" in self.variables and legend: + + # TODO if possible, I would like to move the contour + # intensity information into the legend too and label the + # iso proportions rather than the raw density values + + artist_kws = {} + artist = partial(mpl.patches.Patch) + ax_obj = self.ax if self.ax is not None else self.facets + self._add_legend( + ax_obj, artist, True, False, "layer", 1, artist_kws, {}, + ) + + def plot_univariate_density( + self, + multiple, + common_norm, + common_grid, + fill, + color, + legend, + estimate_kws, + **plot_kws, + ): + + # Handle conditional defaults + if fill is None: + fill = multiple in ("stack", "fill") + + # Preprocess the matplotlib keyword dictionaries + if fill: + artist = mpl.collections.PolyCollection + else: + artist = mpl.lines.Line2D + plot_kws = _normalize_kwargs(plot_kws, artist) + + # Input checking + _check_argument("multiple", ["layer", "stack", "fill"], multiple) + + # Always share the evaluation grid when stacking + subsets = bool(set(self.variables) - {"x", "y"}) + if subsets and multiple in ("stack", "fill"): + common_grid = True + + # Check if the data axis is log scaled + log_scale = self._log_scaled(self.data_variable) + + # Do the computation + densities = self._compute_univariate_density( + self.data_variable, + common_norm, + common_grid, + estimate_kws, + log_scale, + ) + + # Adjust densities based on the `multiple` rule + densities, baselines = self._resolve_multiple(densities, multiple) + + # 
Control the interaction with autoscaling by defining sticky_edges + # i.e. we don't want autoscale margins below the density curve + sticky_density = (0, 1) if multiple == "fill" else (0, np.inf) + + if multiple == "fill": + # Filled plots should not have any margins + sticky_support = densities.index.min(), densities.index.max() + else: + sticky_support = [] + + if fill: + if multiple == "layer": + default_alpha = .25 + else: + default_alpha = .75 + else: + default_alpha = 1 + alpha = plot_kws.pop("alpha", default_alpha) # TODO make parameter? + + # Now iterate through the subsets and draw the densities + # We go backwards so stacked densities read from top-to-bottom + for sub_vars, _ in self.iter_data("hue", reverse=True): + + # Extract the support grid and density curve for this level + key = tuple(sub_vars.items()) + try: + density = densities[key] + except KeyError: + continue + support = density.index + fill_from = baselines[key] + + ax = self._get_axes(sub_vars) + + if "hue" in self.variables: + sub_color = self._hue_map(sub_vars["hue"]) + else: + sub_color = color + + artist_kws = self._artist_kws( + plot_kws, fill, False, multiple, sub_color, alpha + ) + + # Either plot a curve with observation values on the x axis + if "x" in self.variables: + + if fill: + artist = ax.fill_between(support, fill_from, density, **artist_kws) + + else: + artist, = ax.plot(support, density, **artist_kws) + + artist.sticky_edges.x[:] = sticky_support + artist.sticky_edges.y[:] = sticky_density + + # Or plot a curve with observation values on the y axis + else: + if fill: + artist = ax.fill_betweenx(support, fill_from, density, **artist_kws) + else: + artist, = ax.plot(density, support, **artist_kws) + + artist.sticky_edges.x[:] = sticky_density + artist.sticky_edges.y[:] = sticky_support + + # --- Finalize the plot ---- + + ax = self.ax if self.ax is not None else self.facets.axes.flat[0] + default_x = default_y = "" + if self.data_variable == "x": + default_y = "Density" + if self.data_variable == "y": + default_x = "Density" + self._add_axis_labels(ax, default_x, default_y) + + if "hue" in self.variables and legend: + + if fill: + artist = partial(mpl.patches.Patch) + else: + artist = partial(mpl.lines.Line2D, [], []) + + ax_obj = self.ax if self.ax is not None else self.facets + self._add_legend( + ax_obj, artist, fill, False, multiple, alpha, plot_kws, {}, + ) + + def plot_bivariate_density( + self, + common_norm, + fill, + levels, + thresh, + color, + legend, + cbar, + cbar_ax, + cbar_kws, + estimate_kws, + **contour_kws, + ): + + contour_kws = contour_kws.copy() + + estimator = KDE(**estimate_kws) + + if not set(self.variables) - {"x", "y"}: + common_norm = False + + all_data = self.plot_data.dropna() + + # Loop through the subsets and estimate the KDEs + densities, supports = {}, {} + + for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True): + + # Extract the data points from this sub set and remove nulls + observations = sub_data[["x", "y"]] + + # Extract the weights for this subset of observations + if "weights" in self.variables: + weights = sub_data["weights"] + else: + weights = None + + # Check that KDE will not error out + variance = observations[["x", "y"]].var() + if any(math.isclose(x, 0) for x in variance) or variance.isna().any(): + msg = "Dataset has 0 variance; skipping density estimate." 
+ warnings.warn(msg, UserWarning) + continue + + # Estimate the density of observations at this level + observations = observations["x"], observations["y"] + density, support = estimator(*observations, weights=weights) + + # Transform the support grid back to the original scale + xx, yy = support + if self._log_scaled("x"): + xx = np.power(10, xx) + if self._log_scaled("y"): + yy = np.power(10, yy) + support = xx, yy + + # Apply a scaling factor so that the integral over all subsets is 1 + if common_norm: + density *= len(sub_data) / len(all_data) + + key = tuple(sub_vars.items()) + densities[key] = density + supports[key] = support + + # Define a grid of iso-proportion levels + if thresh is None: + thresh = 0 + if isinstance(levels, Number): + levels = np.linspace(thresh, 1, levels) + else: + if min(levels) < 0 or max(levels) > 1: + raise ValueError("levels must be in [0, 1]") + + # Transform from iso-proportions to iso-densities + if common_norm: + common_levels = self._quantile_to_level( + list(densities.values()), levels, + ) + draw_levels = {k: common_levels for k in densities} + else: + draw_levels = { + k: self._quantile_to_level(d, levels) + for k, d in densities.items() + } + + # Get a default single color from the attribute cycle + if self.ax is None: + default_color = "C0" if color is None else color + else: + scout, = self.ax.plot([], color=color) + default_color = scout.get_color() + scout.remove() + + # Define the coloring of the contours + if "hue" in self.variables: + for param in ["cmap", "colors"]: + if param in contour_kws: + msg = f"{param} parameter ignored when using hue mapping." + warnings.warn(msg, UserWarning) + contour_kws.pop(param) + else: + + # Work out a default coloring of the contours + coloring_given = set(contour_kws) & {"cmap", "colors"} + if fill and not coloring_given: + cmap = self._cmap_from_color(default_color) + contour_kws["cmap"] = cmap + if not fill and not coloring_given: + contour_kws["colors"] = [default_color] + + # Use our internal colormap lookup + cmap = contour_kws.pop("cmap", None) + if isinstance(cmap, str): + cmap = color_palette(cmap, as_cmap=True) + if cmap is not None: + contour_kws["cmap"] = cmap + + # Loop through the subsets again and plot the data + for sub_vars, _ in self.iter_data("hue"): + + if "hue" in sub_vars: + color = self._hue_map(sub_vars["hue"]) + if fill: + contour_kws["cmap"] = self._cmap_from_color(color) + else: + contour_kws["colors"] = [color] + + ax = self._get_axes(sub_vars) + + # Choose the function to plot with + # TODO could add a pcolormesh based option as well + # Which would look something like element="raster" + if fill: + contour_func = ax.contourf + else: + contour_func = ax.contour + + key = tuple(sub_vars.items()) + if key not in densities: + continue + density = densities[key] + xx, yy = supports[key] + + label = contour_kws.pop("label", None) + + cset = contour_func( + xx, yy, density, + levels=draw_levels[key], + **contour_kws, + ) + + if "hue" not in self.variables: + cset.collections[0].set_label(label) + + # Add a color bar representing the contour heights + # Note: this shows iso densities, not iso proportions + # See more notes in histplot about how this could be improved + if cbar: + cbar_kws = {} if cbar_kws is None else cbar_kws + ax.figure.colorbar(cset, cbar_ax, ax, **cbar_kws) + + # --- Finalize the plot + ax = self.ax if self.ax is not None else self.facets.axes.flat[0] + self._add_axis_labels(ax) + + if "hue" in self.variables and legend: + + # TODO if possible, I would like to move 
the contour + # intensity information into the legend too and label the + # iso proportions rather than the raw density values + + artist_kws = {} + if fill: + artist = partial(mpl.patches.Patch) + else: + artist = partial(mpl.lines.Line2D, [], []) + + ax_obj = self.ax if self.ax is not None else self.facets + self._add_legend( + ax_obj, artist, fill, False, "layer", 1, artist_kws, {}, + ) + + def plot_univariate_ecdf(self, estimate_kws, legend, **plot_kws): + + estimator = ECDF(**estimate_kws) + + # Set the draw style to step the right way for the data variable + drawstyles = dict(x="steps-post", y="steps-pre") + plot_kws["drawstyle"] = drawstyles[self.data_variable] + + # Loop through the subsets, transform and plot the data + for sub_vars, sub_data in self.iter_data( + "hue", reverse=True, from_comp_data=True, + ): + + # Compute the ECDF + if sub_data.empty: + continue + + observations = sub_data[self.data_variable] + weights = sub_data.get("weights", None) + stat, vals = estimator(observations, weights=weights) + + # Assign attributes based on semantic mapping + artist_kws = plot_kws.copy() + if "hue" in self.variables: + artist_kws["color"] = self._hue_map(sub_vars["hue"]) + + # Return the data variable to the linear domain + # This needs an automatic solution; see GH2409 + if self._log_scaled(self.data_variable): + vals = np.power(10, vals) + vals[0] = -np.inf + + # Work out the orientation of the plot + if self.data_variable == "x": + plot_args = vals, stat + stat_variable = "y" + else: + plot_args = stat, vals + stat_variable = "x" + + if estimator.stat == "count": + top_edge = len(observations) + else: + top_edge = 1 + + # Draw the line for this subset + ax = self._get_axes(sub_vars) + artist, = ax.plot(*plot_args, **artist_kws) + sticky_edges = getattr(artist.sticky_edges, stat_variable) + sticky_edges[:] = 0, top_edge + + # --- Finalize the plot ---- + ax = self.ax if self.ax is not None else self.facets.axes.flat[0] + stat = estimator.stat.capitalize() + default_x = default_y = "" + if self.data_variable == "x": + default_y = stat + if self.data_variable == "y": + default_x = stat + self._add_axis_labels(ax, default_x, default_y) + + if "hue" in self.variables and legend: + artist = partial(mpl.lines.Line2D, [], []) + alpha = plot_kws.get("alpha", 1) + ax_obj = self.ax if self.ax is not None else self.facets + self._add_legend( + ax_obj, artist, False, False, None, alpha, plot_kws, {}, + ) + + def plot_rug(self, height, expand_margins, legend, **kws): + + for sub_vars, sub_data, in self.iter_data(from_comp_data=True): + + ax = self._get_axes(sub_vars) + + kws.setdefault("linewidth", 1) + + if expand_margins: + xmarg, ymarg = ax.margins() + if "x" in self.variables: + ymarg += height * 2 + if "y" in self.variables: + xmarg += height * 2 + ax.margins(x=xmarg, y=ymarg) + + if "hue" in self.variables: + kws.pop("c", None) + kws.pop("color", None) + + if "x" in self.variables: + self._plot_single_rug(sub_data, "x", height, ax, kws) + if "y" in self.variables: + self._plot_single_rug(sub_data, "y", height, ax, kws) + + # --- Finalize the plot + self._add_axis_labels(ax) + if "hue" in self.variables and legend: + # TODO ideally i'd like the legend artist to look like a rug + legend_artist = partial(mpl.lines.Line2D, [], []) + self._add_legend( + ax, legend_artist, False, False, None, 1, {}, {}, + ) + + def _plot_single_rug(self, sub_data, var, height, ax, kws): + """Draw a rugplot along one axis of the plot.""" + vector = sub_data[var] + n = len(vector) + + # Return data to linear 
domain + # This needs an automatic solution; see GH2409 + if self._log_scaled(var): + vector = np.power(10, vector) + + # We'll always add a single collection with varying colors + if "hue" in self.variables: + colors = self._hue_map(sub_data["hue"]) + else: + colors = None + + # Build the array of values for the LineCollection + if var == "x": + + trans = tx.blended_transform_factory(ax.transData, ax.transAxes) + xy_pairs = np.column_stack([ + np.repeat(vector, 2), np.tile([0, height], n) + ]) + + if var == "y": + + trans = tx.blended_transform_factory(ax.transAxes, ax.transData) + xy_pairs = np.column_stack([ + np.tile([0, height], n), np.repeat(vector, 2) + ]) + + # Draw the lines on the plot + line_segs = xy_pairs.reshape([n, 2, 2]) + ax.add_collection(LineCollection( + line_segs, transform=trans, colors=colors, **kws + )) + + ax.autoscale_view(scalex=var == "x", scaley=var == "y") + + +class _DistributionFacetPlotter(_DistributionPlotter): + + semantics = _DistributionPlotter.semantics + ("col", "row") + + +# ==================================================================================== # +# External API +# ==================================================================================== # + +def histplot( + data=None, *, + # Vector variables + x=None, y=None, hue=None, weights=None, + # Histogram computation parameters + stat="count", bins="auto", binwidth=None, binrange=None, + discrete=None, cumulative=False, common_bins=True, common_norm=True, + # Histogram appearance parameters + multiple="layer", element="bars", fill=True, shrink=1, + # Histogram smoothing with a kernel density estimate + kde=False, kde_kws=None, line_kws=None, + # Bivariate histogram parameters + thresh=0, pthresh=None, pmax=None, cbar=False, cbar_ax=None, cbar_kws=None, + # Hue mapping parameters + palette=None, hue_order=None, hue_norm=None, color=None, + # Axes information + log_scale=None, legend=True, ax=None, + # Other appearance keywords + **kwargs, +): + + p = _DistributionPlotter( + data=data, + variables=_DistributionPlotter.get_semantics(locals()) + ) + + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + if ax is None: + ax = plt.gca() + + p._attach(ax, log_scale=log_scale) + + if p.univariate: # Note, bivariate plots won't cycle + if fill: + method = ax.bar if element == "bars" else ax.fill_between + else: + method = ax.plot + color = _default_color(method, hue, color, kwargs) + + if not p.has_xy_data: + return ax + + # Default to discrete bins for categorical variables + if discrete is None: + discrete = p._default_discrete() + + estimate_kws = dict( + stat=stat, + bins=bins, + binwidth=binwidth, + binrange=binrange, + discrete=discrete, + cumulative=cumulative, + ) + + if p.univariate: + + p.plot_univariate_histogram( + multiple=multiple, + element=element, + fill=fill, + shrink=shrink, + common_norm=common_norm, + common_bins=common_bins, + kde=kde, + kde_kws=kde_kws, + color=color, + legend=legend, + estimate_kws=estimate_kws, + line_kws=line_kws, + **kwargs, + ) + + else: + + p.plot_bivariate_histogram( + common_bins=common_bins, + common_norm=common_norm, + thresh=thresh, + pthresh=pthresh, + pmax=pmax, + color=color, + legend=legend, + cbar=cbar, + cbar_ax=cbar_ax, + cbar_kws=cbar_kws, + estimate_kws=estimate_kws, + **kwargs, + ) + + return ax + + +histplot.__doc__ = """\ +Plot univariate or bivariate histograms to show distributions of datasets. 
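+
+For example, a minimal call might look like the following (a sketch using the
+``penguins`` example dataset; see the Examples section for fuller coverage):
+
+    >>> import seaborn as sns
+    >>> penguins = sns.load_dataset("penguins")
+    >>> ax = sns.histplot(data=penguins, x="flipper_length_mm", hue="species")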
+
+A histogram is a classic visualization tool that represents the distribution
+of one or more variables by counting the number of observations that fall within
+discrete bins.
+
+This function can normalize the statistic computed within each bin to estimate
+frequency, density or probability mass, and it can add a smooth curve obtained
+using a kernel density estimate, similar to :func:`kdeplot`.
+
+More information is provided in the :ref:`user guide `.
+
+Parameters
+----------
+{params.core.data}
+{params.core.xy}
+{params.core.hue}
+weights : vector or key in ``data``
+    If provided, weight the contribution of the corresponding data points
+    towards the count in each bin by these factors.
+{params.hist.stat}
+{params.hist.bins}
+{params.hist.binwidth}
+{params.hist.binrange}
+discrete : bool
+    If True, default to ``binwidth=1`` and draw the bars so that they are
+    centered on their corresponding data points. This avoids "gaps" that may
+    otherwise appear when using discrete (integer) data.
+cumulative : bool
+    If True, plot the cumulative counts as bins increase.
+common_bins : bool
+    If True, use the same bins when semantic variables produce multiple
+    plots. If using a reference rule to determine the bins, it will be computed
+    with the full dataset.
+common_norm : bool
+    If True and using a normalized statistic, the normalization will apply over
+    the full dataset. Otherwise, normalize each histogram independently.
+multiple : {{"layer", "dodge", "stack", "fill"}}
+    Approach to resolving multiple elements when semantic mapping creates subsets.
+    Only relevant with univariate data.
+element : {{"bars", "step", "poly"}}
+    Visual representation of the histogram statistic.
+    Only relevant with univariate data.
+fill : bool
+    If True, fill in the space under the histogram.
+    Only relevant with univariate data.
+shrink : number
+    Scale the width of each bar relative to the binwidth by this factor.
+    Only relevant with univariate data.
+kde : bool
+    If True, compute a kernel density estimate to smooth the distribution
+    and show on the plot as (one or more) line(s).
+    Only relevant with univariate data.
+kde_kws : dict
+    Parameters that control the KDE computation, as in :func:`kdeplot`.
+line_kws : dict
+    Parameters that control the KDE visualization, passed to
+    :meth:`matplotlib.axes.Axes.plot`.
+thresh : number or None
+    Cells with a statistic less than or equal to this value will be transparent.
+    Only relevant with bivariate data.
+pthresh : number or None
+    Like ``thresh``, but a value in [0, 1] such that cells with aggregate counts
+    (or other statistics, when used) up to this proportion of the total will be
+    transparent.
+pmax : number or None
+    A value in [0, 1] that sets the saturation point for the colormap at a value
+    such that cells below it constitute this proportion of the total count (or
+    other statistic, when used).
+{params.dist.cbar}
+{params.dist.cbar_ax}
+{params.dist.cbar_kws}
+{params.core.palette}
+{params.core.hue_order}
+{params.core.hue_norm}
+{params.core.color}
+{params.dist.log_scale}
+{params.dist.legend}
+{params.core.ax}
+kwargs
+    Other keyword arguments are passed to one of the following matplotlib
+    functions:
+
+    - :meth:`matplotlib.axes.Axes.bar` (univariate, element="bars")
+    - :meth:`matplotlib.axes.Axes.fill_between` (univariate, other element, fill=True)
+    - :meth:`matplotlib.axes.Axes.plot` (univariate, other element, fill=False)
+    - :meth:`matplotlib.axes.Axes.pcolormesh` (bivariate)
+
+Returns
+-------
+{returns.ax}
+
+See Also
+--------
+{seealso.displot}
+{seealso.kdeplot}
+{seealso.rugplot}
+{seealso.ecdfplot}
+{seealso.jointplot}
+
+Notes
+-----
+
+The choice of bins for computing and plotting a histogram can exert
+substantial influence on the insights that one is able to draw from the
+visualization. If the bins are too large, they may erase important features.
+On the other hand, bins that are too small may be dominated by random
+variability, obscuring the shape of the true underlying distribution. The
+default bin size is determined using a reference rule that depends on the
+sample size and variance. This works well in many cases (i.e., with
+"well-behaved" data) but it fails in others. It is always a good idea to try
+different bin sizes to be sure that you are not missing something important.
+This function allows you to specify bins in several different ways, such as
+by setting the total number of bins to use, the width of each bin, or the
+specific locations where the bins should break.
+
+Examples
+--------
+
+.. include:: ../docstrings/histplot.rst
+
+""".format(
+    params=_param_docs,
+    returns=_core_docs["returns"],
+    seealso=_core_docs["seealso"],
+)
+
+
+@_deprecate_positional_args
+def kdeplot(
+    x=None,  # Allow positional x, because behavior will not change with reorg
+    *,
+    y=None,
+    shade=None,  # Note "soft" deprecation, explained below
+    vertical=False,  # Deprecated
+    kernel=None,  # Deprecated
+    bw=None,  # Deprecated
+    gridsize=200,  # TODO maybe depend on uni/bivariate?
+    cut=3, clip=None, legend=True, cumulative=False,
+    shade_lowest=None,  # Deprecated, controlled with levels now
+    cbar=False, cbar_ax=None, cbar_kws=None,
+    ax=None,
+
+    # New params
+    weights=None,  # TODO note that weights is grouped with semantics
+    hue=None, palette=None, hue_order=None, hue_norm=None,
+    multiple="layer", common_norm=True, common_grid=False,
+    levels=10, thresh=.05,
+    bw_method="scott", bw_adjust=1, log_scale=None,
+    color=None, fill=None,
+
+    # Renamed params
+    data=None, data2=None,
+
+    **kwargs,
+):
+
+    # Handle deprecation of `data2` as name for y variable
+    if data2 is not None:
+
+        y = data2
+
+        # If `data2` is present, we need to check for the `data` kwarg being
+        # used to pass a vector for `x`. We'll reassign the vectors and warn.
+        # We need this check because just passing a vector to `data` is now
+        # technically valid.
+
+        x_passed_as_data = (
+            x is None
+            and data is not None
+            and np.ndim(data) == 1
+        )
+
+        if x_passed_as_data:
+            msg = "Use `x` and `y` rather than `data` and `data2`"
+            x = data
+        else:
+            msg = "The `data2` param is now named `y`; please update your code"
+
+        warnings.warn(msg, FutureWarning)
+
+    # Handle deprecation of `vertical`
+    if vertical:
+        msg = (
+            "The `vertical` parameter is deprecated and will be removed in a "
+            "future version. Assign the data to the `y` variable instead."
+        )
+        warnings.warn(msg, FutureWarning)
+        x, y = y, x
+
+    # Handle deprecation of `bw`
+    if bw is not None:
+        msg = (
+            "The `bw` parameter is deprecated in favor of `bw_method` and "
+            f"`bw_adjust`. Using {bw} for `bw_method`, but please "
+            "see the docs for the new parameters and update your code."
+        )
+        warnings.warn(msg, FutureWarning)
+        bw_method = bw
+
+    # Handle deprecation of `kernel`
+    if kernel is not None:
+        msg = (
+            "Support for alternate kernels has been removed. "
+            "Using Gaussian kernel."
+        )
+        warnings.warn(msg, UserWarning)
+
+    # Handle deprecation of shade_lowest
+    if shade_lowest is not None:
+        if shade_lowest:
+            thresh = 0
+        msg = (
+            "`shade_lowest` is now deprecated in favor of `thresh`. "
+            f"Setting `thresh={thresh}`, but please update your code."
+        )
+        warnings.warn(msg, UserWarning)
+
+    # Handle `n_levels`
+    # This was never in the formal API but it was processed, and appeared in an
+    # example. We can treat as an alias for `levels` now and deprecate later.
+    levels = kwargs.pop("n_levels", levels)
+
+    # Handle "soft" deprecation of shade. `shade` is not really the right
+    # terminology here, but unlike some of the other deprecated parameters it
+    # is probably very commonly used and much harder to remove. This is
+    # therefore going to be a longer process where, first, `fill` will be
+    # introduced and be used throughout the documentation. In 0.12, when
+    # kwarg-only enforcement hits, we can remove shade/shade_lowest from the
+    # function signature altogether and pull them out of the kwargs. Then we
+    # can actually fire a FutureWarning, and eventually remove.
+    if shade is not None:
+        fill = shade
+
+    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+    p = _DistributionPlotter(
+        data=data,
+        variables=_DistributionPlotter.get_semantics(locals()),
+    )
+
+    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
+
+    if ax is None:
+        ax = plt.gca()
+
+    p._attach(ax, allowed_types=["numeric", "datetime"], log_scale=log_scale)
+
+    method = ax.fill_between if fill else ax.plot
+    color = _default_color(method, hue, color, kwargs)
+
+    if not p.has_xy_data:
+        return ax
+
+    # Pack the kwargs for statistics.KDE
+    estimate_kws = dict(
+        bw_method=bw_method,
+        bw_adjust=bw_adjust,
+        gridsize=gridsize,
+        cut=cut,
+        clip=clip,
+        cumulative=cumulative,
+    )
+
+    if p.univariate:
+
+        plot_kws = kwargs.copy()
+
+        p.plot_univariate_density(
+            multiple=multiple,
+            common_norm=common_norm,
+            common_grid=common_grid,
+            fill=fill,
+            color=color,
+            legend=legend,
+            estimate_kws=estimate_kws,
+            **plot_kws,
+        )
+
+    else:
+
+        p.plot_bivariate_density(
+            common_norm=common_norm,
+            fill=fill,
+            levels=levels,
+            thresh=thresh,
+            legend=legend,
+            color=color,
+            cbar=cbar,
+            cbar_ax=cbar_ax,
+            cbar_kws=cbar_kws,
+            estimate_kws=estimate_kws,
+            **kwargs,
+        )
+
+    return ax
+
+
+kdeplot.__doc__ = """\
+Plot univariate or bivariate distributions using kernel density estimation.
+
+A kernel density estimate (KDE) plot is a method for visualizing the
+distribution of observations in a dataset, analogous to a histogram. KDE
+represents the data using a continuous probability density curve in one or
+more dimensions.
+
+The approach is explained further in the :ref:`user guide `.
+
+Relative to a histogram, KDE can produce a plot that is less cluttered and
+more interpretable, especially when drawing multiple distributions. But it
+has the potential to introduce distortions if the underlying distribution is
+bounded or not smooth.
+Like a histogram, the quality of the representation also depends on the
+selection of good smoothing parameters.
+
+Parameters
+----------
+{params.core.xy}
+shade : bool
+    Alias for ``fill``. Using ``fill`` is recommended.
+vertical : bool
+    Orientation parameter.
+
+    .. deprecated:: 0.11.0
+       specify orientation by assigning the ``x`` or ``y`` variables.
+
+kernel : str
+    Function that defines the kernel.
+
+    .. deprecated:: 0.11.0
+       support for non-Gaussian kernels has been removed.
+
+bw : str, number, or callable
+    Smoothing parameter.
+
+    .. deprecated:: 0.11.0
+       see ``bw_method`` and ``bw_adjust``.
+
+gridsize : int
+    Number of points on each dimension of the evaluation grid.
+{params.kde.cut}
+{params.kde.clip}
+{params.dist.legend}
+{params.kde.cumulative}
+shade_lowest : bool
+    If False, the area below the lowest contour will be transparent.
+
+    .. deprecated:: 0.11.0
+       see ``thresh``.
+
+{params.dist.cbar}
+{params.dist.cbar_ax}
+{params.dist.cbar_kws}
+{params.core.ax}
+weights : vector or key in ``data``
+    If provided, weight the kernel density estimation using these values.
+{params.core.hue}
+{params.core.palette}
+{params.core.hue_order}
+{params.core.hue_norm}
+{params.dist.multiple}
+common_norm : bool
+    If True, scale each conditional density by the number of observations
+    such that the total area under all densities sums to 1. Otherwise,
+    normalize each density independently.
+common_grid : bool
+    If True, use the same evaluation grid for each kernel density estimate.
+    Only relevant with univariate data.
+levels : int or vector
+    Number of contour levels or values to draw contours at. A vector argument
+    must have increasing values in [0, 1]. Levels correspond to iso-proportions
+    of the density: e.g., 20% of the probability mass will lie below the
+    contour drawn for 0.2. Only relevant with bivariate data.
+thresh : number in [0, 1]
+    Lowest iso-proportion level at which to draw a contour line. Ignored when
+    ``levels`` is a vector. Only relevant with bivariate data.
+{params.kde.bw_method}
+{params.kde.bw_adjust}
+{params.dist.log_scale}
+{params.core.color}
+fill : bool or None
+    If True, fill in the area under univariate density curves or between
+    bivariate contours. If None, the default depends on ``multiple``.
+{params.core.data}
+kwargs
+    Other keyword arguments are passed to one of the following matplotlib
+    functions:
+
+    - :meth:`matplotlib.axes.Axes.plot` (univariate, ``fill=False``),
+    - :meth:`matplotlib.axes.Axes.fill_between` (univariate, ``fill=True``),
+    - :meth:`matplotlib.axes.Axes.contour` (bivariate, ``fill=False``),
+    - :meth:`matplotlib.axes.Axes.contourf` (bivariate, ``fill=True``).
+
+Returns
+-------
+{returns.ax}
+
+See Also
+--------
+{seealso.displot}
+{seealso.histplot}
+{seealso.ecdfplot}
+{seealso.jointplot}
+{seealso.violinplot}
+
+Notes
+-----
+
+The *bandwidth*, or standard deviation of the smoothing kernel, is an
+important parameter. Misspecification of the bandwidth can produce a
+distorted representation of the data. Much like the choice of bin width in a
+histogram, an over-smoothed curve can erase true features of a
+distribution, while an under-smoothed curve can create false features out of
+random variability. The rule-of-thumb that sets the default bandwidth works
+best when the true distribution is smooth, unimodal, and roughly bell-shaped.
+It is always a good idea to check the default behavior by using ``bw_adjust``
+to increase or decrease the amount of smoothing.
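+
+For example (a sketch using the ``tips`` example dataset):
+
+    >>> import seaborn as sns
+    >>> tips = sns.load_dataset("tips")
+    >>> ax = sns.kdeplot(data=tips, x="total_bill")  # rule-of-thumb bandwidth
+    >>> ax = sns.kdeplot(data=tips, x="total_bill", bw_adjust=.5)  # rougher
+    >>> ax = sns.kdeplot(data=tips, x="total_bill", bw_adjust=2)   # smoother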
+ +Because the smoothing algorithm uses a Gaussian kernel, the estimated density +curve can extend to values that do not make sense for a particular dataset. +For example, the curve may be drawn over negative values when smoothing data +that are naturally positive. The ``cut`` and ``clip`` parameters can be used +to control the extent of the curve, but datasets that have many observations +close to a natural boundary may be better served by a different visualization +method. + +Similar considerations apply when a dataset is naturally discrete or "spiky" +(containing many repeated observations of the same value). Kernel density +estimation will always produce a smooth curve, which would be misleading +in these situations. + +The units on the density axis are a common source of confusion. While kernel +density estimation produces a probability distribution, the height of the curve +at each point gives a density, not a probability. A probability can be obtained +only by integrating the density across a range. The curve is normalized so +that the integral over all possible values is 1, meaning that the scale of +the density axis depends on the data values. + +Examples +-------- + +.. include:: ../docstrings/kdeplot.rst + +""".format( + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) + + +def ecdfplot( + data=None, *, + # Vector variables + x=None, y=None, hue=None, weights=None, + # Computation parameters + stat="proportion", complementary=False, + # Hue mapping parameters + palette=None, hue_order=None, hue_norm=None, + # Axes information + log_scale=None, legend=True, ax=None, + # Other appearance keywords + **kwargs, +): + + p = _DistributionPlotter( + data=data, + variables=_DistributionPlotter.get_semantics(locals()) + ) + + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + + # We could support other semantics (size, style) here fairly easily + # But it would make distplot a bit more complicated. + # It's always possible to add features like that later, so I am going to defer. + # It will be even easier to wait until after there is a more general/abstract + # way to go from semantic specs to artist attributes. + + if ax is None: + ax = plt.gca() + + p._attach(ax, log_scale=log_scale) + + color = kwargs.pop("color", kwargs.pop("c", None)) + kwargs["color"] = _default_color(ax.plot, hue, color, kwargs) + + if not p.has_xy_data: + return ax + + # We could add this one day, but it's of dubious value + if not p.univariate: + raise NotImplementedError("Bivariate ECDF plots are not implemented") + + estimate_kws = dict( + stat=stat, + complementary=complementary, + ) + + p.plot_univariate_ecdf( + estimate_kws=estimate_kws, + legend=legend, + **kwargs, + ) + + return ax + + +ecdfplot.__doc__ = """\ +Plot empirical cumulative distribution functions. + +An ECDF represents the proportion or count of observations falling below each +unique value in a dataset. Compared to a histogram or density plot, it has the +advantage that each observation is visualized directly, meaning that there are +no binning or smoothing parameters that need to be adjusted. It also aids direct +comparisons between multiple distributions. A downside is that the relationship +between the appearance of the plot and the basic properties of the distribution +(such as its central tendency, variance, and the presence of any bimodality) +may not be as intuitive. + +More information is provided in the :ref:`user guide `. 
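+
+For example, a minimal call might look like the following (a sketch using the
+``penguins`` example dataset):
+
+    >>> import seaborn as sns
+    >>> penguins = sns.load_dataset("penguins")
+    >>> ax = sns.ecdfplot(data=penguins, x="bill_length_mm", hue="species")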
+
+Parameters
+----------
+{params.core.data}
+{params.core.xy}
+{params.core.hue}
+weights : vector or key in ``data``
+    If provided, weight the contribution of the corresponding data points
+    towards the cumulative distribution using these values.
+{params.ecdf.stat}
+{params.ecdf.complementary}
+{params.core.palette}
+{params.core.hue_order}
+{params.core.hue_norm}
+{params.dist.log_scale}
+{params.dist.legend}
+{params.core.ax}
+kwargs
+    Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.plot`.
+
+Returns
+-------
+{returns.ax}
+
+See Also
+--------
+{seealso.displot}
+{seealso.histplot}
+{seealso.kdeplot}
+{seealso.rugplot}
+
+Examples
+--------
+
+.. include:: ../docstrings/ecdfplot.rst
+
+""".format(
+    params=_param_docs,
+    returns=_core_docs["returns"],
+    seealso=_core_docs["seealso"],
+)
+
+
+@_deprecate_positional_args
+def rugplot(
+    x=None,  # Allow positional x, because behavior won't change
+    *,
+    height=.025, axis=None, ax=None,
+
+    # New parameters
+    data=None, y=None, hue=None,
+    palette=None, hue_order=None, hue_norm=None,
+    expand_margins=True,
+    legend=True,  # TODO or maybe default to False?
+
+    # Renamed parameter
+    a=None,
+
+    **kwargs
+):
+
+    # A note: I think it would make sense to add multiple= to rugplot and allow
+    # rugs for different hue variables to be shifted orthogonal to the data axis
+    # But is this stacking, or dodging?
+
+    # A note: if we want to add a style semantic to rugplot,
+    # we could make an option that draws the rug using scatterplot
+
+    # A note, it would also be nice to offer some kind of histogram/density
+    # rugplot, since alpha blending doesn't work great in the large n regime
+
+    # Handle deprecation of `a`
+    if a is not None:
+        msg = "The `a` parameter is now called `x`. Please update your code."
+        warnings.warn(msg, FutureWarning)
+        x = a
+        del a
+
+    # Handle deprecation of "axis"
+    if axis is not None:
+        msg = (
+            "The `axis` variable is no longer used and will be removed. "
+            "Instead, assign variables directly to `x` or `y`."
+        )
+        warnings.warn(msg, FutureWarning)
+
+    # Handle deprecation of "vertical"
+    if kwargs.pop("vertical", axis == "y"):
+        x, y = None, x
+        msg = (
+            "Using `vertical=True` to control the orientation of the plot "
+            "is deprecated. Instead, assign the data directly to `y`."
+        )
+        warnings.warn(msg, FutureWarning)
+
+    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+    weights = None
+    p = _DistributionPlotter(
+        data=data,
+        variables=_DistributionPlotter.get_semantics(locals()),
+    )
+    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
+
+    if ax is None:
+        ax = plt.gca()
+
+    p._attach(ax)
+
+    color = kwargs.pop("color", kwargs.pop("c", None))
+    kwargs["color"] = _default_color(ax.plot, hue, color, kwargs)
+
+    if not p.has_xy_data:
+        return ax
+
+    p.plot_rug(height, expand_margins, legend, **kwargs)
+
+    return ax
+
+
+rugplot.__doc__ = """\
+Plot marginal distributions by drawing ticks along the x and y axes.
+
+This function is intended to complement other plots by showing the location
+of individual observations in an unobtrusive way.
+
+Parameters
+----------
+{params.core.xy}
+height : number
+    Proportion of axes extent covered by each rug element.
+axis : {{"x", "y"}}
+    Axis to draw the rug on.
+
+    .. deprecated:: 0.11.0
+       specify axis by assigning the ``x`` or ``y`` variables.
+
+{params.core.ax}
+{params.core.data}
+{params.core.hue}
+{params.core.palette}
+{params.core.hue_order}
+{params.core.hue_norm}
+expand_margins : bool
+    If True, increase the axes margins by the height of the rug to avoid
+    overlap with other elements.
+legend : bool
+    If False, do not add a legend for semantic variables.
+kwargs
+    Other keyword arguments are passed to
+    :class:`matplotlib.collections.LineCollection`.
+
+Returns
+-------
+{returns.ax}
+
+Examples
+--------
+
+.. include:: ../docstrings/rugplot.rst
+
+""".format(
+    params=_param_docs,
+    returns=_core_docs["returns"],
+    seealso=_core_docs["seealso"],
+)
+
+
+def displot(
+    data=None, *,
+    # Vector variables
+    x=None, y=None, hue=None, row=None, col=None, weights=None,
+    # Other plot parameters
+    kind="hist", rug=False, rug_kws=None, log_scale=None, legend=True,
+    # Hue-mapping parameters
+    palette=None, hue_order=None, hue_norm=None, color=None,
+    # Faceting parameters
+    col_wrap=None, row_order=None, col_order=None,
+    height=5, aspect=1, facet_kws=None,
+    **kwargs,
+):
+
+    p = _DistributionFacetPlotter(
+        data=data,
+        variables=_DistributionFacetPlotter.get_semantics(locals())
+    )
+
+    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
+
+    _check_argument("kind", ["hist", "kde", "ecdf"], kind)
+
+    # --- Initialize the FacetGrid object
+
+    # Check for attempt to plot onto specific axes and warn
+    if "ax" in kwargs:
+        msg = (
+            "`displot` is a figure-level function and does not accept "
+            "the ax= parameter. You may wish to try {}plot.".format(kind)
+        )
+        warnings.warn(msg, UserWarning)
+        kwargs.pop("ax")
+
+    for var in ["row", "col"]:
+        # Handle faceting variables that lack name information
+        if var in p.variables and p.variables[var] is None:
+            p.variables[var] = f"_{var}_"
+
+    # Adapt the plot_data dataframe for use with FacetGrid
+    data = p.plot_data.rename(columns=p.variables)
+    data = data.loc[:, ~data.columns.duplicated()]
+
+    col_name = p.variables.get("col", None)
+    row_name = p.variables.get("row", None)
+
+    if facet_kws is None:
+        facet_kws = {}
+
+    g = FacetGrid(
+        data=data, row=row_name, col=col_name,
+        col_wrap=col_wrap, row_order=row_order,
+        col_order=col_order, height=height,
+        aspect=aspect,
+        **facet_kws,
+    )
+
+    # Now attach the axes object to the plotter object
+    if kind == "kde":
+        allowed_types = ["numeric", "datetime"]
+    else:
+        allowed_types = None
+    p._attach(g, allowed_types=allowed_types, log_scale=log_scale)
+
+    # Check for a specification that lacks x/y data and return early
+    if not p.has_xy_data:
+        return g
+
+    if color is None and hue is None:
+        color = "C0"
+    # XXX else warn if hue is not None?
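+
+    # (From here, each kind peels its estimator parameters out of `kwargs`,
+    # using the matching axes-level function to supply defaults, and forwards
+    # the rest to the corresponding plot_* method.)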
+ + kwargs["legend"] = legend + + # --- Draw the plots + + if kind == "hist": + + hist_kws = kwargs.copy() + + # Extract the parameters that will go directly to Histogram + estimate_defaults = {} + _assign_default_kwargs(estimate_defaults, Histogram.__init__, histplot) + + estimate_kws = {} + for key, default_val in estimate_defaults.items(): + estimate_kws[key] = hist_kws.pop(key, default_val) + + # Handle derivative defaults + if estimate_kws["discrete"] is None: + estimate_kws["discrete"] = p._default_discrete() + + hist_kws["estimate_kws"] = estimate_kws + + hist_kws.setdefault("color", color) + + if p.univariate: + + _assign_default_kwargs(hist_kws, p.plot_univariate_histogram, histplot) + p.plot_univariate_histogram(**hist_kws) + + else: + + _assign_default_kwargs(hist_kws, p.plot_bivariate_histogram, histplot) + p.plot_bivariate_histogram(**hist_kws) + + elif kind == "kde": + + kde_kws = kwargs.copy() + + # Extract the parameters that will go directly to KDE + estimate_defaults = {} + _assign_default_kwargs(estimate_defaults, KDE.__init__, kdeplot) + + estimate_kws = {} + for key, default_val in estimate_defaults.items(): + estimate_kws[key] = kde_kws.pop(key, default_val) -try: - import statsmodels.nonparametric.api as smnp - _has_statsmodels = True -except ImportError: - _has_statsmodels = False + kde_kws["estimate_kws"] = estimate_kws + kde_kws["color"] = color -from .utils import iqr, _kde_support -from .palettes import color_palette, light_palette, dark_palette, blend_palette + if p.univariate: + _assign_default_kwargs(kde_kws, p.plot_univariate_density, kdeplot) + p.plot_univariate_density(**kde_kws) -__all__ = ["distplot", "kdeplot", "rugplot"] + else: + + _assign_default_kwargs(kde_kws, p.plot_bivariate_density, kdeplot) + p.plot_bivariate_density(**kde_kws) + + elif kind == "ecdf": + + ecdf_kws = kwargs.copy() + + # Extract the parameters that will go directly to the estimator + estimate_kws = {} + estimate_defaults = {} + _assign_default_kwargs(estimate_defaults, ECDF.__init__, ecdfplot) + for key, default_val in estimate_defaults.items(): + estimate_kws[key] = ecdf_kws.pop(key, default_val) + + ecdf_kws["estimate_kws"] = estimate_kws + ecdf_kws["color"] = color + + if p.univariate: + + _assign_default_kwargs(ecdf_kws, p.plot_univariate_ecdf, ecdfplot) + p.plot_univariate_ecdf(**ecdf_kws) + + else: + + raise NotImplementedError("Bivariate ECDF plots are not implemented") + + # All plot kinds can include a rug + if rug: + # TODO with expand_margins=True, each facet expands margins... annoying! + if rug_kws is None: + rug_kws = {} + _assign_default_kwargs(rug_kws, p.plot_rug, rugplot) + rug_kws["legend"] = False + if color is not None: + rug_kws["color"] = color + p.plot_rug(**rug_kws) + + # Call FacetGrid annotation methods + # Note that the legend is currently set inside the plotting method + g.set_axis_labels( + x_var=p.variables.get("x", g.axes.flat[0].get_xlabel()), + y_var=p.variables.get("y", g.axes.flat[0].get_ylabel()), + ) + g.set_titles() + g.tight_layout() + + return g + + +displot.__doc__ = """\ +Figure-level interface for drawing distribution plots onto a FacetGrid. + +This function provides access to several approaches for visualizing the +univariate or bivariate distribution of data, including subsets of data +defined by semantic mapping and faceting across multiple subplots. 
The +``kind`` parameter selects the approach to use: + +- :func:`histplot` (with ``kind="hist"``; the default) +- :func:`kdeplot` (with ``kind="kde"``) +- :func:`ecdfplot` (with ``kind="ecdf"``; univariate-only) + +Additionally, a :func:`rugplot` can be added to any kind of plot to show +individual observations. + +Extra keyword arguments are passed to the underlying function, so you should +refer to the documentation for each to understand the complete set of options +for making plots with this interface. + +See the :doc:`distribution plots tutorial <../tutorial/distributions>` for a more +in-depth discussion of the relative strengths and weaknesses of each approach. +The distinction between figure-level and axes-level functions is explained +further in the :doc:`user guide <../tutorial/function_overview>`. + +Parameters +---------- +{params.core.data} +{params.core.xy} +{params.core.hue} +{params.facets.rowcol} +kind : {{"hist", "kde", "ecdf"}} + Approach for visualizing the data. Selects the underlying plotting function + and determines the additional set of valid parameters. +rug : bool + If True, show each observation with marginal ticks (as in :func:`rugplot`). +rug_kws : dict + Parameters to control the appearance of the rug plot. +{params.dist.log_scale} +{params.dist.legend} +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} +{params.core.color} +{params.facets.col_wrap} +{params.facets.rowcol_order} +{params.facets.height} +{params.facets.aspect} +{params.facets.facet_kws} +kwargs + Other keyword arguments are documented with the relevant axes-level function: + + - :func:`histplot` (with ``kind="hist"``) + - :func:`kdeplot` (with ``kind="kde"``) + - :func:`ecdfplot` (with ``kind="ecdf"``) + +Returns +------- +{returns.facetgrid} + +See Also +-------- +{seealso.histplot} +{seealso.kdeplot} +{seealso.rugplot} +{seealso.ecdfplot} +{seealso.jointplot} + +Examples +-------- + +See the API documentation for the axes-level functions for more details +about the breadth of options available for each plot kind. + +.. include:: ../docstrings/displot.rst + +""".format( + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) + + +# =========================================================================== # +# DEPRECATED FUNCTIONS LIVE BELOW HERE +# =========================================================================== # def _freedman_diaconis_bins(a): @@ -31,7 +2386,8 @@ def _freedman_diaconis_bins(a): a = np.asarray(a) if len(a) < 2: return 1 - h = 2 * iqr(a) / (len(a) ** (1 / 3)) + iqr = np.subtract.reduce(np.nanpercentile(a, [75, 25])) + h = 2 * iqr / (len(a) ** (1 / 3)) # fall back to sqrt(a) bins if iqr is 0 if h == 0: return int(np.sqrt(a.size)) @@ -39,11 +2395,20 @@ def _freedman_diaconis_bins(a): return int(np.ceil((a.max() - a.min()) / h)) -def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, +def distplot(a=None, bins=None, hist=True, kde=True, rug=False, fit=None, hist_kws=None, kde_kws=None, rug_kws=None, fit_kws=None, color=None, vertical=False, norm_hist=False, axlabel=None, - label=None, ax=None): - """Flexibly plot a univariate distribution of observations. + label=None, ax=None, x=None): + """DEPRECATED: Flexibly plot a univariate distribution of observations. + + .. warning:: + This function is deprecated and will be removed in a future version. 
+ Please adapt your code to use one of two new functions: + + - :func:`displot`, a figure-level function with a similar flexibility + over the kind of plot to draw + - :func:`histplot`, an axes-level function for plotting histograms, + including with kernel density smoothing This function combines the matplotlib ``hist`` function (with automatic calculation of a good default bin size) with the seaborn :func:`kdeplot` @@ -52,7 +2417,6 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, Parameters ---------- - a : Series, 1d-array, or list. Observed data. If this is a Series object with a ``name`` attribute, the name will be used to label the data axis. @@ -69,8 +2433,12 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, An object with `fit` method, returning a tuple that can be passed to a `pdf` method a positional arguments following a grid of values to evaluate the pdf on. - {hist, kde, rug, fit}_kws : dictionaries, optional - Keyword arguments for underlying plotting functions. + hist_kws : dict, optional + Keyword arguments for :meth:`matplotlib.axes.Axes.hist`. + kde_kws : dict, optional + Keyword arguments for :func:`kdeplot`. + rug_kws : dict, optional + Keyword arguments for :func:`rugplot`. color : matplotlib color, optional Color to plot everything but the fitted curve in. vertical : bool, optional @@ -80,7 +2448,7 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, This is implied if a KDE or fitted density is plotted. axlabel : string, False, or None, optional Name for the support axis label. If None, will try to get it - from a.namel if False, do not set a label. + from a.name if False, do not set a label. label : string, optional Legend label for the relevant component of the plot. ax : matplotlib axis, optional @@ -108,7 +2476,7 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, :context: close-figs >>> import seaborn as sns, numpy as np - >>> sns.set(); np.random.seed(0) + >>> sns.set_theme(); np.random.seed(0) >>> x = np.random.randn(100) >>> ax = sns.distplot(x) @@ -163,6 +2531,23 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, ... "alpha": 1, "color": "g"}) """ + + if kde and not hist: + axes_level_suggestion = ( + "`kdeplot` (an axes-level function for kernel density plots)." + ) + else: + axes_level_suggestion = ( + "`histplot` (an axes-level function for histograms)." + ) + + msg = ( + "`distplot` is a deprecated function and will be removed in a future version. 
" + "Please adapt your code to use either `displot` (a figure-level function with " + "similar flexibility) or " + axes_level_suggestion + ) + warnings.warn(msg, FutureWarning) + if ax is None: ax = plt.gca() @@ -173,23 +2558,26 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, if axlabel is not None: label_ax = True - # Make a a 1-d array - a = np.asarray(a) + # Support new-style API + if x is not None: + a = x + + # Make a a 1-d float array + a = np.asarray(a, float) if a.ndim > 1: a = a.squeeze() + # Drop null values from array + a = remove_na(a) + # Decide if the hist is normed norm_hist = norm_hist or kde or (fit is not None) # Handle dictionary defaults - if hist_kws is None: - hist_kws = dict() - if kde_kws is None: - kde_kws = dict() - if rug_kws is None: - rug_kws = dict() - if fit_kws is None: - fit_kws = dict() + hist_kws = {} if hist_kws is None else hist_kws.copy() + kde_kws = {} if kde_kws is None else kde_kws.copy() + rug_kws = {} if rug_kws is None else rug_kws.copy() + fit_kws = {} if fit_kws is None else fit_kws.copy() # Get the color from the current color cycle if color is None: @@ -215,10 +2603,7 @@ def distplot(a, bins=None, hist=True, kde=True, rug=False, fit=None, if bins is None: bins = min(_freedman_diaconis_bins(a), 50) hist_kws.setdefault("alpha", 0.4) - if LooseVersion(mpl.__version__) < LooseVersion("2.2"): - hist_kws.setdefault("normed", norm_hist) - else: - hist_kws.setdefault("density", norm_hist) + hist_kws.setdefault("density", norm_hist) orientation = "horizontal" if vertical else "vertical" hist_color = hist_kws.pop("color", color) @@ -249,7 +2634,7 @@ def pdf(x): gridsize = fit_kws.pop("gridsize", 200) cut = fit_kws.pop("cut", 3) clip = fit_kws.pop("clip", (-np.inf, np.inf)) - bw = stats.gaussian_kde(a).scotts_factor() * a.std(ddof=1) + bw = gaussian_kde(a).scotts_factor() * a.std(ddof=1) x = _kde_support(a, bw, gridsize, cut, clip) params = fit.fit(a) y = pdf(x) @@ -266,485 +2651,3 @@ def pdf(x): ax.set_xlabel(axlabel) return ax - - -def _univariate_kdeplot(data, shade, vertical, kernel, bw, gridsize, cut, - clip, legend, ax, cumulative=False, **kwargs): - """Plot a univariate kernel density estimate on one of the axes.""" - - # Sort out the clipping - if clip is None: - clip = (-np.inf, np.inf) - - # Calculate the KDE - - if np.nan_to_num(data.var()) == 0: - # Don't try to compute KDE on singular data - msg = "Data must have variance to compute a kernel density estimate." - warnings.warn(msg, UserWarning) - x, y = np.array([]), np.array([]) - - elif _has_statsmodels: - # Prefer using statsmodels for kernel flexibility - x, y = _statsmodels_univariate_kde(data, kernel, bw, - gridsize, cut, clip, - cumulative=cumulative) - else: - # Fall back to scipy if missing statsmodels - if kernel != "gau": - kernel = "gau" - msg = "Kernel other than `gau` requires statsmodels." - warnings.warn(msg, UserWarning) - if cumulative: - raise ImportError("Cumulative distributions are currently " - "only implemented in statsmodels. 
" - "Please install statsmodels.") - x, y = _scipy_univariate_kde(data, bw, gridsize, cut, clip) - - # Make sure the density is nonnegative - y = np.amax(np.c_[np.zeros_like(y), y], axis=1) - - # Flip the data if the plot should be on the y axis - if vertical: - x, y = y, x - - # Check if a label was specified in the call - label = kwargs.pop("label", None) - - # Otherwise check if the data object has a name - if label is None and hasattr(data, "name"): - label = data.name - - # Decide if we're going to add a legend - legend = label is not None and legend - label = "_nolegend_" if label is None else label - - # Use the active color cycle to find the plot color - facecolor = kwargs.pop("facecolor", None) - line, = ax.plot(x, y, **kwargs) - color = line.get_color() - line.remove() - kwargs.pop("color", None) - facecolor = color if facecolor is None else facecolor - - # Draw the KDE plot and, optionally, shade - ax.plot(x, y, color=color, label=label, **kwargs) - shade_kws = dict( - facecolor=facecolor, - alpha=kwargs.get("alpha", 0.25), - clip_on=kwargs.get("clip_on", True), - zorder=kwargs.get("zorder", 1), - ) - if shade: - if vertical: - ax.fill_betweenx(y, 0, x, **shade_kws) - else: - ax.fill_between(x, 0, y, **shade_kws) - - # Set the density axis minimum to 0 - if vertical: - ax.set_xlim(0, auto=None) - else: - ax.set_ylim(0, auto=None) - - # Draw the legend here - handles, labels = ax.get_legend_handles_labels() - if legend and handles: - ax.legend(loc="best") - - return ax - - -def _statsmodels_univariate_kde(data, kernel, bw, gridsize, cut, clip, - cumulative=False): - """Compute a univariate kernel density estimate using statsmodels.""" - fft = kernel == "gau" - kde = smnp.KDEUnivariate(data) - kde.fit(kernel, bw, fft, gridsize=gridsize, cut=cut, clip=clip) - if cumulative: - grid, y = kde.support, kde.cdf - else: - grid, y = kde.support, kde.density - return grid, y - - -def _scipy_univariate_kde(data, bw, gridsize, cut, clip): - """Compute a univariate kernel density estimate using scipy.""" - try: - kde = stats.gaussian_kde(data, bw_method=bw) - except TypeError: - kde = stats.gaussian_kde(data) - if bw != "scott": # scipy default - msg = ("Ignoring bandwidth choice, " - "please upgrade scipy to use a different bandwidth.") - warnings.warn(msg, UserWarning) - if isinstance(bw, string_types): - bw = "scotts" if bw == "scott" else bw - bw = getattr(kde, "%s_factor" % bw)() * np.std(data) - grid = _kde_support(data, bw, gridsize, cut, clip) - y = kde(grid) - return grid, y - - -def _bivariate_kdeplot(x, y, filled, fill_lowest, - kernel, bw, gridsize, cut, clip, - axlabel, cbar, cbar_ax, cbar_kws, ax, **kwargs): - """Plot a joint KDE estimate as a bivariate contour plot.""" - # Determine the clipping - if clip is None: - clip = [(-np.inf, np.inf), (-np.inf, np.inf)] - elif np.ndim(clip) == 1: - clip = [clip, clip] - - # Calculate the KDE - if _has_statsmodels: - xx, yy, z = _statsmodels_bivariate_kde(x, y, bw, gridsize, cut, clip) - else: - xx, yy, z = _scipy_bivariate_kde(x, y, bw, gridsize, cut, clip) - - # Plot the contours - n_levels = kwargs.pop("n_levels", 10) - - scout, = ax.plot([], []) - default_color = scout.get_color() - scout.remove() - - color = kwargs.pop("color", default_color) - cmap = kwargs.pop("cmap", None) - if cmap is None: - if filled: - cmap = light_palette(color, as_cmap=True) - else: - cmap = dark_palette(color, as_cmap=True) - if isinstance(cmap, string_types): - if cmap.endswith("_d"): - pal = ["#333333"] - pal.extend(color_palette(cmap.replace("_d", 
"_r"), 2)) - cmap = blend_palette(pal, as_cmap=True) - else: - cmap = mpl.cm.get_cmap(cmap) - - label = kwargs.pop("label", None) - - kwargs["cmap"] = cmap - contour_func = ax.contourf if filled else ax.contour - cset = contour_func(xx, yy, z, n_levels, **kwargs) - if filled and not fill_lowest: - cset.collections[0].set_alpha(0) - kwargs["n_levels"] = n_levels - - if cbar: - cbar_kws = {} if cbar_kws is None else cbar_kws - ax.figure.colorbar(cset, cbar_ax, ax, **cbar_kws) - - # Label the axes - if hasattr(x, "name") and axlabel: - ax.set_xlabel(x.name) - if hasattr(y, "name") and axlabel: - ax.set_ylabel(y.name) - - if label is not None: - legend_color = cmap(.95) if color is None else color - if filled: - ax.fill_between([], [], color=legend_color, label=label) - else: - ax.plot([], [], color=legend_color, label=label) - - return ax - - -def _statsmodels_bivariate_kde(x, y, bw, gridsize, cut, clip): - """Compute a bivariate kde using statsmodels.""" - if isinstance(bw, string_types): - bw_func = getattr(smnp.bandwidths, "bw_" + bw) - x_bw = bw_func(x) - y_bw = bw_func(y) - bw = [x_bw, y_bw] - elif np.isscalar(bw): - bw = [bw, bw] - - if isinstance(x, pd.Series): - x = x.values - if isinstance(y, pd.Series): - y = y.values - - kde = smnp.KDEMultivariate([x, y], "cc", bw) - x_support = _kde_support(x, kde.bw[0], gridsize, cut, clip[0]) - y_support = _kde_support(y, kde.bw[1], gridsize, cut, clip[1]) - xx, yy = np.meshgrid(x_support, y_support) - z = kde.pdf([xx.ravel(), yy.ravel()]).reshape(xx.shape) - return xx, yy, z - - -def _scipy_bivariate_kde(x, y, bw, gridsize, cut, clip): - """Compute a bivariate kde using scipy.""" - data = np.c_[x, y] - kde = stats.gaussian_kde(data.T, bw_method=bw) - data_std = data.std(axis=0, ddof=1) - if isinstance(bw, string_types): - bw = "scotts" if bw == "scott" else bw - bw_x = getattr(kde, "%s_factor" % bw)() * data_std[0] - bw_y = getattr(kde, "%s_factor" % bw)() * data_std[1] - elif np.isscalar(bw): - bw_x, bw_y = bw, bw - else: - msg = ("Cannot specify a different bandwidth for each dimension " - "with the scipy backend. You should install statsmodels.") - raise ValueError(msg) - x_support = _kde_support(data[:, 0], bw_x, gridsize, cut, clip[0]) - y_support = _kde_support(data[:, 1], bw_y, gridsize, cut, clip[1]) - xx, yy = np.meshgrid(x_support, y_support) - z = kde([xx.ravel(), yy.ravel()]).reshape(xx.shape) - return xx, yy, z - - -def kdeplot(data, data2=None, shade=False, vertical=False, kernel="gau", - bw="scott", gridsize=100, cut=3, clip=None, legend=True, - cumulative=False, shade_lowest=True, cbar=False, cbar_ax=None, - cbar_kws=None, ax=None, **kwargs): - """Fit and plot a univariate or bivariate kernel density estimate. - - Parameters - ---------- - data : 1d array-like - Input data. - data2: 1d array-like, optional - Second input data. If present, a bivariate KDE will be estimated. - shade : bool, optional - If True, shade in the area under the KDE curve (or draw with filled - contours when data is bivariate). - vertical : bool, optional - If True, density is on x-axis. - kernel : {'gau' | 'cos' | 'biw' | 'epa' | 'tri' | 'triw' }, optional - Code for shape of kernel to fit with. Bivariate KDE can only use - gaussian kernel. - bw : {'scott' | 'silverman' | scalar | pair of scalars }, optional - Name of reference method to determine kernel size, scalar factor, - or scalar for each dimension of the bivariate plot. 
Note that the - underlying computational libraries have different interperetations - for this parameter: ``statsmodels`` uses it directly, but ``scipy`` - treats it as a scaling factor for the standard deviation of the - data. - gridsize : int, optional - Number of discrete points in the evaluation grid. - cut : scalar, optional - Draw the estimate to cut * bw from the extreme data points. - clip : pair of scalars, or pair of pair of scalars, optional - Lower and upper bounds for datapoints used to fit KDE. Can provide - a pair of (low, high) bounds for bivariate plots. - legend : bool, optional - If True, add a legend or label the axes when possible. - cumulative : bool, optional - If True, draw the cumulative distribution estimated by the kde. - shade_lowest : bool, optional - If True, shade the lowest contour of a bivariate KDE plot. Not - relevant when drawing a univariate plot or when ``shade=False``. - Setting this to ``False`` can be useful when you want multiple - densities on the same Axes. - cbar : bool, optional - If True and drawing a bivariate KDE plot, add a colorbar. - cbar_ax : matplotlib axes, optional - Existing axes to draw the colorbar onto, otherwise space is taken - from the main axes. - cbar_kws : dict, optional - Keyword arguments for ``fig.colorbar()``. - ax : matplotlib axes, optional - Axes to plot on, otherwise uses current axes. - kwargs : key, value pairings - Other keyword arguments are passed to ``plt.plot()`` or - ``plt.contour{f}`` depending on whether a univariate or bivariate - plot is being drawn. - - Returns - ------- - ax : matplotlib Axes - Axes with plot. - - See Also - -------- - distplot: Flexibly plot a univariate distribution of observations. - jointplot: Plot a joint dataset with bivariate and marginal distributions. - - Examples - -------- - - Plot a basic univariate density: - - .. plot:: - :context: close-figs - - >>> import numpy as np; np.random.seed(10) - >>> import seaborn as sns; sns.set(color_codes=True) - >>> mean, cov = [0, 2], [(1, .5), (.5, 1)] - >>> x, y = np.random.multivariate_normal(mean, cov, size=50).T - >>> ax = sns.kdeplot(x) - - Shade under the density curve and use a different color: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, shade=True, color="r") - - Plot a bivariate density: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, y) - - Use filled contours: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, y, shade=True) - - Use more contour levels and a different color palette: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, y, n_levels=30, cmap="Purples_d") - - Use a narrower bandwith: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, bw=.15) - - Plot the density on the vertical axis: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(y, vertical=True) - - Limit the density curve within the range of the data: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, cut=0) - - Add a colorbar for the contours: - - .. plot:: - :context: close-figs - - >>> ax = sns.kdeplot(x, y, cbar=True) - - Plot two shaded bivariate densities: - - .. plot:: - :context: close-figs - - >>> iris = sns.load_dataset("iris") - >>> setosa = iris.loc[iris.species == "setosa"] - >>> virginica = iris.loc[iris.species == "virginica"] - >>> ax = sns.kdeplot(setosa.sepal_width, setosa.sepal_length, - ... cmap="Reds", shade=True, shade_lowest=False) - >>> ax = sns.kdeplot(virginica.sepal_width, virginica.sepal_length, - ... 
cmap="Blues", shade=True, shade_lowest=False) - - """ - if ax is None: - ax = plt.gca() - - if isinstance(data, list): - data = np.asarray(data) - - if len(data) == 0: - return ax - - data = data.astype(np.float64) - if data2 is not None: - if isinstance(data2, list): - data2 = np.asarray(data2) - data2 = data2.astype(np.float64) - - warn = False - bivariate = False - if isinstance(data, np.ndarray) and np.ndim(data) > 1: - warn = True - bivariate = True - x, y = data.T - elif isinstance(data, pd.DataFrame) and np.ndim(data) > 1: - warn = True - bivariate = True - x = data.iloc[:, 0].values - y = data.iloc[:, 1].values - elif data2 is not None: - bivariate = True - x = data - y = data2 - - if warn: - warn_msg = ("Passing a 2D dataset for a bivariate plot is deprecated " - "in favor of kdeplot(x, y), and it will cause an error in " - "future versions. Please update your code.") - warnings.warn(warn_msg, UserWarning) - - if bivariate and cumulative: - raise TypeError("Cumulative distribution plots are not" - "supported for bivariate distributions.") - if bivariate: - ax = _bivariate_kdeplot(x, y, shade, shade_lowest, - kernel, bw, gridsize, cut, clip, legend, - cbar, cbar_ax, cbar_kws, ax, **kwargs) - else: - ax = _univariate_kdeplot(data, shade, vertical, kernel, bw, - gridsize, cut, clip, legend, ax, - cumulative=cumulative, **kwargs) - - return ax - - -def rugplot(a, height=.05, axis="x", ax=None, **kwargs): - """Plot datapoints in an array as sticks on an axis. - - Parameters - ---------- - a : vector - 1D array of observations. - height : scalar, optional - Height of ticks as proportion of the axis. - axis : {'x' | 'y'}, optional - Axis to draw rugplot on. - ax : matplotlib axes, optional - Axes to draw plot into; otherwise grabs current axes. - kwargs : key, value pairings - Other keyword arguments are passed to ``LineCollection``. - - Returns - ------- - ax : matplotlib axes - The Axes object with the plot on it. - - """ - if ax is None: - ax = plt.gca() - a = np.asarray(a) - vertical = kwargs.pop("vertical", axis == "y") - - alias_map = dict(linewidth="lw", linestyle="ls", color="c") - for attr, alias in alias_map.items(): - if alias in kwargs: - kwargs[attr] = kwargs.pop(alias) - kwargs.setdefault("linewidth", 1) - - if vertical: - trans = tx.blended_transform_factory(ax.transAxes, ax.transData) - xy_pairs = np.column_stack([np.tile([0, height], len(a)), - np.repeat(a, 2)]) - else: - trans = tx.blended_transform_factory(ax.transData, ax.transAxes) - xy_pairs = np.column_stack([np.repeat(a, 2), - np.tile([0, height], len(a))]) - line_segs = xy_pairs.reshape([len(a), 2, 2]) - ax.add_collection(LineCollection(line_segs, transform=trans, **kwargs)) - - ax.autoscale_view(scalex=not vertical, scaley=vertical) - - return ax diff --git a/seaborn/external/docscrape.py b/seaborn/external/docscrape.py new file mode 100644 index 0000000000..d6552850a9 --- /dev/null +++ b/seaborn/external/docscrape.py @@ -0,0 +1,718 @@ +"""Extract reference documentation from the NumPy source tree. + +Copyright (C) 2008 Stefan van der Walt , Pauli Virtanen + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +""" +import inspect +import textwrap +import re +import pydoc +from warnings import warn +from collections import namedtuple +from collections.abc import Callable, Mapping +import copy +import sys + + +def strip_blank_lines(l): + "Remove leading and trailing blank lines from a list of lines" + while l and not l[0].strip(): + del l[0] + while l and not l[-1].strip(): + del l[-1] + return l + + +class Reader(object): + """A line-based string reader. + + """ + def __init__(self, data): + """ + Parameters + ---------- + data : str + String with lines separated by '\n'. + + """ + if isinstance(data, list): + self._str = data + else: + self._str = data.split('\n') # store string as list of lines + + self.reset() + + def __getitem__(self, n): + return self._str[n] + + def reset(self): + self._l = 0 # current line nr + + def read(self): + if not self.eof(): + out = self[self._l] + self._l += 1 + return out + else: + return '' + + def seek_next_non_empty_line(self): + for l in self[self._l:]: + if l.strip(): + break + else: + self._l += 1 + + def eof(self): + return self._l >= len(self._str) + + def read_to_condition(self, condition_func): + start = self._l + for line in self[start:]: + if condition_func(line): + return self[start:self._l] + self._l += 1 + if self.eof(): + return self[start:self._l+1] + return [] + + def read_to_next_empty_line(self): + self.seek_next_non_empty_line() + + def is_empty(line): + return not line.strip() + + return self.read_to_condition(is_empty) + + def read_to_next_unindented_line(self): + def is_unindented(line): + return (line.strip() and (len(line.lstrip()) == len(line))) + return self.read_to_condition(is_unindented) + + def peek(self, n=0): + if self._l + n < len(self._str): + return self[self._l + n] + else: + return '' + + def is_empty(self): + return not ''.join(self._str).strip() + + +class ParseError(Exception): + def __str__(self): + message = self.args[0] + if hasattr(self, 'docstring'): + message = "%s in %r" % (message, self.docstring) + return message + + +Parameter = namedtuple('Parameter', ['name', 'type', 'desc']) + + +class NumpyDocString(Mapping): + """Parses a numpydoc string to an abstract representation + + Instances define a mapping from section title to structured data. 
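A minimal sketch of the section-to-data mapping this class provides, assuming the module is importable as `seaborn.external.docscrape` (the path this diff creates):

```python
from seaborn.external.docscrape import NumpyDocString

text = """\
One-line summary.

Parameters
----------
x : int
    A parameter.
"""

doc = NumpyDocString(text)
print(doc["Summary"])          # ['One-line summary.']
param = doc["Parameters"][0]   # Parameter namedtuple (name, type, desc)
print(param.name, param.type)  # x int
print(param.desc)              # ['A parameter.']
```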
+ + """ + + sections = { + 'Signature': '', + 'Summary': [''], + 'Extended Summary': [], + 'Parameters': [], + 'Returns': [], + 'Yields': [], + 'Receives': [], + 'Raises': [], + 'Warns': [], + 'Other Parameters': [], + 'Attributes': [], + 'Methods': [], + 'See Also': [], + 'Notes': [], + 'Warnings': [], + 'References': '', + 'Examples': '', + 'index': {} + } + + def __init__(self, docstring, config={}): + orig_docstring = docstring + docstring = textwrap.dedent(docstring).split('\n') + + self._doc = Reader(docstring) + self._parsed_data = copy.deepcopy(self.sections) + + try: + self._parse() + except ParseError as e: + e.docstring = orig_docstring + raise + + def __getitem__(self, key): + return self._parsed_data[key] + + def __setitem__(self, key, val): + if key not in self._parsed_data: + self._error_location("Unknown section %s" % key, error=False) + else: + self._parsed_data[key] = val + + def __iter__(self): + return iter(self._parsed_data) + + def __len__(self): + return len(self._parsed_data) + + def _is_at_section(self): + self._doc.seek_next_non_empty_line() + + if self._doc.eof(): + return False + + l1 = self._doc.peek().strip() # e.g. Parameters + + if l1.startswith('.. index::'): + return True + + l2 = self._doc.peek(1).strip() # ---------- or ========== + return l2.startswith('-'*len(l1)) or l2.startswith('='*len(l1)) + + def _strip(self, doc): + i = 0 + j = 0 + for i, line in enumerate(doc): + if line.strip(): + break + + for j, line in enumerate(doc[::-1]): + if line.strip(): + break + + return doc[i:len(doc)-j] + + def _read_to_next_section(self): + section = self._doc.read_to_next_empty_line() + + while not self._is_at_section() and not self._doc.eof(): + if not self._doc.peek(-1).strip(): # previous line was empty + section += [''] + + section += self._doc.read_to_next_empty_line() + + return section + + def _read_sections(self): + while not self._doc.eof(): + data = self._read_to_next_section() + name = data[0].strip() + + if name.startswith('..'): # index section + yield name, data[1:] + elif len(data) < 2: + yield StopIteration + else: + yield name, self._strip(data[2:]) + + def _parse_param_list(self, content, single_element_is_type=False): + r = Reader(content) + params = [] + while not r.eof(): + header = r.read().strip() + if ' : ' in header: + arg_name, arg_type = header.split(' : ')[:2] + else: + if single_element_is_type: + arg_name, arg_type = '', header + else: + arg_name, arg_type = header, '' + + desc = r.read_to_next_unindented_line() + desc = dedent_lines(desc) + desc = strip_blank_lines(desc) + + params.append(Parameter(arg_name, arg_type, desc)) + + return params + + # See also supports the following formats. + # + # + # SPACE* COLON SPACE+ SPACE* + # ( COMMA SPACE+ )+ (COMMA | PERIOD)? SPACE* + # ( COMMA SPACE+ )* SPACE* COLON SPACE+ SPACE* + + # is one of + # + # COLON COLON BACKTICK BACKTICK + # where + # is a legal function name, and + # is any nonempty sequence of word characters. + # Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j` + # is a string describing the function. 
+ + _role = r":(?P\w+):" + _funcbacktick = r"`(?P(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`" + _funcplain = r"(?P[a-zA-Z0-9_\.-]+)" + _funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")" + _funcnamenext = _funcname.replace('role', 'rolenext') + _funcnamenext = _funcnamenext.replace('name', 'namenext') + _description = r"(?P\s*:(\s+(?P\S+.*))?)?\s*$" + _func_rgx = re.compile(r"^\s*" + _funcname + r"\s*") + _line_rgx = re.compile( + r"^\s*" + + r"(?P" + # group for all function names + _funcname + + r"(?P([,]\s+" + _funcnamenext + r")*)" + + r")" + # end of "allfuncs" + r"(?P[,\.])?" + # Some function lists have a trailing comma (or period) '\s*' + _description) + + # Empty elements are replaced with '..' + empty_description = '..' + + def _parse_see_also(self, content): + """ + func_name : Descriptive text + continued text + another_func_name : Descriptive text + func_name1, func_name2, :meth:`func_name`, func_name3 + + """ + + items = [] + + def parse_item_name(text): + """Match ':role:`name`' or 'name'.""" + m = self._func_rgx.match(text) + if not m: + raise ParseError("%s is not a item name" % text) + role = m.group('role') + name = m.group('name') if role else m.group('name2') + return name, role, m.end() + + rest = [] + for line in content: + if not line.strip(): + continue + + line_match = self._line_rgx.match(line) + description = None + if line_match: + description = line_match.group('desc') + if line_match.group('trailing') and description: + self._error_location( + 'Unexpected comma or period after function list at index %d of ' + 'line "%s"' % (line_match.end('trailing'), line), + error=False) + if not description and line.startswith(' '): + rest.append(line.strip()) + elif line_match: + funcs = [] + text = line_match.group('allfuncs') + while True: + if not text.strip(): + break + name, role, match_end = parse_item_name(text) + funcs.append((name, role)) + text = text[match_end:].strip() + if text and text[0] == ',': + text = text[1:].strip() + rest = list(filter(None, [description])) + items.append((funcs, rest)) + else: + raise ParseError("%s is not a item name" % line) + return items + + def _parse_index(self, section, content): + """ + .. index: default + :refguide: something, else, and more + + """ + def strip_each_in(lst): + return [s.strip() for s in lst] + + out = {} + section = section.split('::') + if len(section) > 1: + out['default'] = strip_each_in(section[1].split(','))[0] + for line in content: + line = line.split(':') + if len(line) > 2: + out[line[1]] = strip_each_in(line[2].split(',')) + return out + + def _parse_summary(self): + """Grab signature (if given) and summary""" + if self._is_at_section(): + return + + # If several signatures present, take the last one + while True: + summary = self._doc.read_to_next_empty_line() + summary_str = " ".join([s.strip() for s in summary]).strip() + compiled = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$') + if compiled.match(summary_str): + self['Signature'] = summary_str + if not self._is_at_section(): + continue + break + + if summary is not None: + self['Summary'] = summary + + if not self._is_at_section(): + self['Extended Summary'] = self._read_to_next_section() + + def _parse(self): + self._doc.reset() + self._parse_summary() + + sections = list(self._read_sections()) + section_names = set([section for section, content in sections]) + + has_returns = 'Returns' in section_names + has_yields = 'Yields' in section_names + # We could do more tests, but we are not. Arbitrarily. 
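These patterns can be exercised directly; a quick check of the See Also line grammar they implement (a sketch, with the same `seaborn.external.docscrape` import assumption as above):

```python
from seaborn.external.docscrape import NumpyDocString

m = NumpyDocString._line_rgx.match("func_a, :meth:`func_b` : the description")
assert m.group("allfuncs") == "func_a, :meth:`func_b`"
assert m.group("desc") == "the description"
assert m.group("trailing") is None  # no stray comma/period after the list
```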
+ if has_returns and has_yields: + msg = 'Docstring contains both a Returns and Yields section.' + raise ValueError(msg) + if not has_yields and 'Receives' in section_names: + msg = 'Docstring contains a Receives section but not Yields.' + raise ValueError(msg) + + for (section, content) in sections: + if not section.startswith('..'): + section = (s.capitalize() for s in section.split(' ')) + section = ' '.join(section) + if self.get(section): + self._error_location("The section %s appears twice" + % section) + + if section in ('Parameters', 'Other Parameters', 'Attributes', + 'Methods'): + self[section] = self._parse_param_list(content) + elif section in ('Returns', 'Yields', 'Raises', 'Warns', 'Receives'): + self[section] = self._parse_param_list( + content, single_element_is_type=True) + elif section.startswith('.. index::'): + self['index'] = self._parse_index(section, content) + elif section == 'See Also': + self['See Also'] = self._parse_see_also(content) + else: + self[section] = content + + def _error_location(self, msg, error=True): + if hasattr(self, '_obj'): + # we know where the docs came from: + try: + filename = inspect.getsourcefile(self._obj) + except TypeError: + filename = None + msg = msg + (" in the docstring of %s in %s." + % (self._obj, filename)) + if error: + raise ValueError(msg) + else: + warn(msg) + + # string conversion routines + + def _str_header(self, name, symbol='-'): + return [name, len(name)*symbol] + + def _str_indent(self, doc, indent=4): + out = [] + for line in doc: + out += [' '*indent + line] + return out + + def _str_signature(self): + if self['Signature']: + return [self['Signature'].replace('*', r'\*')] + [''] + else: + return [''] + + def _str_summary(self): + if self['Summary']: + return self['Summary'] + [''] + else: + return [] + + def _str_extended_summary(self): + if self['Extended Summary']: + return self['Extended Summary'] + [''] + else: + return [] + + def _str_param_list(self, name): + out = [] + if self[name]: + out += self._str_header(name) + for param in self[name]: + parts = [] + if param.name: + parts.append(param.name) + if param.type: + parts.append(param.type) + out += [' : '.join(parts)] + if param.desc and ''.join(param.desc).strip(): + out += self._str_indent(param.desc) + out += [''] + return out + + def _str_section(self, name): + out = [] + if self[name]: + out += self._str_header(name) + out += self[name] + out += [''] + return out + + def _str_see_also(self, func_role): + if not self['See Also']: + return [] + out = [] + out += self._str_header("See Also") + out += [''] + last_had_desc = True + for funcs, desc in self['See Also']: + assert isinstance(funcs, list) + links = [] + for func, role in funcs: + if role: + link = ':%s:`%s`' % (role, func) + elif func_role: + link = ':%s:`%s`' % (func_role, func) + else: + link = "`%s`_" % func + links.append(link) + link = ', '.join(links) + out += [link] + if desc: + out += self._str_indent([' '.join(desc)]) + last_had_desc = True + else: + last_had_desc = False + out += self._str_indent([self.empty_description]) + + if last_had_desc: + out += [''] + out += [''] + return out + + def _str_index(self): + idx = self['index'] + out = [] + output_index = False + default_index = idx.get('default', '') + if default_index: + output_index = True + out += ['.. 
index:: %s' % default_index] + for section, references in idx.items(): + if section == 'default': + continue + output_index = True + out += [' :%s: %s' % (section, ', '.join(references))] + if output_index: + return out + else: + return '' + + def __str__(self, func_role=''): + out = [] + out += self._str_signature() + out += self._str_summary() + out += self._str_extended_summary() + for param_list in ('Parameters', 'Returns', 'Yields', 'Receives', + 'Other Parameters', 'Raises', 'Warns'): + out += self._str_param_list(param_list) + out += self._str_section('Warnings') + out += self._str_see_also(func_role) + for s in ('Notes', 'References', 'Examples'): + out += self._str_section(s) + for param_list in ('Attributes', 'Methods'): + out += self._str_param_list(param_list) + out += self._str_index() + return '\n'.join(out) + + +def indent(str, indent=4): + indent_str = ' '*indent + if str is None: + return indent_str + lines = str.split('\n') + return '\n'.join(indent_str + l for l in lines) + + +def dedent_lines(lines): + """Deindent a list of lines maximally""" + return textwrap.dedent("\n".join(lines)).split("\n") + + +def header(text, style='-'): + return text + '\n' + style*len(text) + '\n' + + +class FunctionDoc(NumpyDocString): + def __init__(self, func, role='func', doc=None, config={}): + self._f = func + self._role = role # e.g. "func" or "meth" + + if doc is None: + if func is None: + raise ValueError("No function or docstring given") + doc = inspect.getdoc(func) or '' + NumpyDocString.__init__(self, doc, config) + + if not self['Signature'] and func is not None: + func, func_name = self.get_func() + try: + try: + signature = str(inspect.signature(func)) + except (AttributeError, ValueError): + # try to read signature, backward compat for older Python + if sys.version_info[0] >= 3: + argspec = inspect.getfullargspec(func) + else: + argspec = inspect.getargspec(func) + signature = inspect.formatargspec(*argspec) + signature = '%s%s' % (func_name, signature) + except TypeError: + signature = '%s()' % func_name + self['Signature'] = signature + + def get_func(self): + func_name = getattr(self._f, '__name__', self.__class__.__name__) + if inspect.isclass(self._f): + func = getattr(self._f, '__call__', self._f.__init__) + else: + func = self._f + return func, func_name + + def __str__(self): + out = '' + + func, func_name = self.get_func() + + roles = {'func': 'function', + 'meth': 'method'} + + if self._role: + if self._role not in roles: + print("Warning: invalid role %s" % self._role) + out += '.. %s:: %s\n \n\n' % (roles.get(self._role, ''), + func_name) + + out += super(FunctionDoc, self).__str__(func_role=self._role) + return out + + +class ClassDoc(NumpyDocString): + + extra_public_methods = ['__call__'] + + def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc, + config={}): + if not inspect.isclass(cls) and cls is not None: + raise ValueError("Expected a class or None, but got %r" % cls) + self._cls = cls + + if 'sphinx' in sys.modules: + from sphinx.ext.autodoc import ALL + else: + ALL = object() + + self.show_inherited_members = config.get( + 'show_inherited_class_members', True) + + if modulename and not modulename.endswith('.'): + modulename += '.' 
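For the class-level counterpart, a small sketch of how `ClassDoc` combines a parsed docstring with introspected members, using a hypothetical `Demo` class for illustration:

```python
from seaborn.external.docscrape import ClassDoc

class Demo:
    """Summary line.

    Attributes
    ----------
    value : int
        Stored value.
    """
    def method(self):
        """Do nothing."""

doc = ClassDoc(Demo)
assert doc["Attributes"][0].name == "value"  # parsed from the docstring
assert "method" in doc.methods               # discovered by introspection
```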
+ self._mod = modulename + + if doc is None: + if cls is None: + raise ValueError("No class or documentation string given") + doc = pydoc.getdoc(cls) + + NumpyDocString.__init__(self, doc) + + _members = config.get('members', []) + if _members is ALL: + _members = None + _exclude = config.get('exclude-members', []) + + if config.get('show_class_members', True) and _exclude is not ALL: + def splitlines_x(s): + if not s: + return [] + else: + return s.splitlines() + for field, items in [('Methods', self.methods), + ('Attributes', self.properties)]: + if not self[field]: + doc_list = [] + for name in sorted(items): + if (name in _exclude or + (_members and name not in _members)): + continue + try: + doc_item = pydoc.getdoc(getattr(self._cls, name)) + doc_list.append( + Parameter(name, '', splitlines_x(doc_item))) + except AttributeError: + pass # method doesn't exist + self[field] = doc_list + + @property + def methods(self): + if self._cls is None: + return [] + return [name for name, func in inspect.getmembers(self._cls) + if ((not name.startswith('_') + or name in self.extra_public_methods) + and isinstance(func, Callable) + and self._is_show_member(name))] + + @property + def properties(self): + if self._cls is None: + return [] + return [name for name, func in inspect.getmembers(self._cls) + if (not name.startswith('_') and + (func is None or isinstance(func, property) or + inspect.isdatadescriptor(func)) + and self._is_show_member(name))] + + def _is_show_member(self, name): + if self.show_inherited_members: + return True # show all class members + if name not in self._cls.__dict__: + return False # class member is inherited, we do not show it + return True \ No newline at end of file diff --git a/seaborn/external/kde.py b/seaborn/external/kde.py new file mode 100644 index 0000000000..1b1d04db31 --- /dev/null +++ b/seaborn/external/kde.py @@ -0,0 +1,382 @@ +""" +This module was copied from the scipy project. + +In the process of copying, some methods were removed because they depended on +other parts of scipy (especially on compiled components), allowing seaborn to +have a simple and pure Python implementation. These include: + +- integrate_gaussian +- integrate_box +- integrate_box_1d +- integrate_kde +- logpdf +- resample + +Additionally, the numpy.linalg module was subsituted for scipy.linalg, +and the examples section (with doctests) was removed from the docstring + +The original scipy license is copied below: + +Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" + +# ------------------------------------------------------------------------------- +# +# Define classes for (uni/multi)-variate kernel density estimation. +# +# Currently, only Gaussian kernels are implemented. +# +# Written by: Robert Kern +# +# Date: 2004-08-09 +# +# Modified: 2005-02-10 by Robert Kern. +# Contributed to SciPy +# 2005-10-07 by Robert Kern. +# Some fixes to match the new scipy_core +# +# Copyright 2004-2005 by Enthought, Inc. +# +# ------------------------------------------------------------------------------- + +import numpy as np +from numpy import (asarray, atleast_2d, reshape, zeros, newaxis, dot, exp, pi, + sqrt, ravel, power, atleast_1d, squeeze, sum, transpose, + ones, cov) +from numpy import linalg + + +__all__ = ['gaussian_kde'] + + +class gaussian_kde(object): + """Representation of a kernel-density estimate using Gaussian kernels. + + Kernel density estimation is a way to estimate the probability density + function (PDF) of a random variable in a non-parametric way. + `gaussian_kde` works for both uni-variate and multi-variate data. It + includes automatic bandwidth determination. The estimation works best for + a unimodal distribution; bimodal or multi-modal distributions tend to be + oversmoothed. + + Parameters + ---------- + dataset : array_like + Datapoints to estimate from. In case of univariate data this is a 1-D + array, otherwise a 2-D array with shape (# of dims, # of data). + bw_method : str, scalar or callable, optional + The method used to calculate the estimator bandwidth. This can be + 'scott', 'silverman', a scalar constant or a callable. If a scalar, + this will be used directly as `kde.factor`. If a callable, it should + take a `gaussian_kde` instance as only parameter and return a scalar. + If None (default), 'scott' is used. See Notes for more details. + weights : array_like, optional + weights of datapoints. This must be the same shape as dataset. + If None (default), the samples are assumed to be equally weighted + + Attributes + ---------- + dataset : ndarray + The dataset with which `gaussian_kde` was initialized. + d : int + Number of dimensions. + n : int + Number of datapoints. + neff : int + Effective number of datapoints. + + .. versionadded:: 1.2.0 + factor : float + The bandwidth factor, obtained from `kde.covariance_factor`, with which + the covariance matrix is multiplied. + covariance : ndarray + The covariance matrix of `dataset`, scaled by the calculated bandwidth + (`kde.factor`). + inv_cov : ndarray + The inverse of `covariance`. + + Methods + ------- + evaluate + __call__ + integrate_gaussian + integrate_box_1d + integrate_box + integrate_kde + pdf + logpdf + resample + set_bandwidth + covariance_factor + + Notes + ----- + Bandwidth selection strongly influences the estimate obtained from the KDE + (much more so than the actual shape of the kernel). Bandwidth selection + can be done by a "rule of thumb", by cross-validation, by "plug-in + methods" or by other means; see [3]_, [4]_ for reviews. 
`gaussian_kde` + uses a rule of thumb, the default is Scott's Rule. + + Scott's Rule [1]_, implemented as `scotts_factor`, is:: + + n**(-1./(d+4)), + + with ``n`` the number of data points and ``d`` the number of dimensions. + In the case of unequally weighted points, `scotts_factor` becomes:: + + neff**(-1./(d+4)), + + with ``neff`` the effective number of datapoints. + Silverman's Rule [2]_, implemented as `silverman_factor`, is:: + + (n * (d + 2) / 4.)**(-1. / (d + 4)). + + or in the case of unequally weighted points:: + + (neff * (d + 2) / 4.)**(-1. / (d + 4)). + + Good general descriptions of kernel density estimation can be found in [1]_ + and [2]_, the mathematics for this multi-dimensional implementation can be + found in [1]_. + + With a set of weighted samples, the effective number of datapoints ``neff`` + is defined by:: + + neff = sum(weights)^2 / sum(weights^2) + + as detailed in [5]_. + + References + ---------- + .. [1] D.W. Scott, "Multivariate Density Estimation: Theory, Practice, and + Visualization", John Wiley & Sons, New York, Chicester, 1992. + .. [2] B.W. Silverman, "Density Estimation for Statistics and Data + Analysis", Vol. 26, Monographs on Statistics and Applied Probability, + Chapman and Hall, London, 1986. + .. [3] B.A. Turlach, "Bandwidth Selection in Kernel Density Estimation: A + Review", CORE and Institut de Statistique, Vol. 19, pp. 1-33, 1993. + .. [4] D.M. Bashtannyk and R.J. Hyndman, "Bandwidth selection for kernel + conditional density estimation", Computational Statistics & Data + Analysis, Vol. 36, pp. 279-298, 2001. + .. [5] Gray P. G., 1969, Journal of the Royal Statistical Society. + Series A (General), 132, 272 + + """ + def __init__(self, dataset, bw_method=None, weights=None): + self.dataset = atleast_2d(asarray(dataset)) + if not self.dataset.size > 1: + raise ValueError("`dataset` input should have multiple elements.") + + self.d, self.n = self.dataset.shape + + if weights is not None: + self._weights = atleast_1d(weights).astype(float) + self._weights /= sum(self._weights) + if self.weights.ndim != 1: + raise ValueError("`weights` input should be one-dimensional.") + if len(self._weights) != self.n: + raise ValueError("`weights` input should be of length n") + self._neff = 1/sum(self._weights**2) + + self.set_bandwidth(bw_method=bw_method) + + def evaluate(self, points): + """Evaluate the estimated pdf on a set of points. + + Parameters + ---------- + points : (# of dimensions, # of points)-array + Alternatively, a (# of dimensions,) vector can be passed in and + treated as a single point. + + Returns + ------- + values : (# of points,)-array + The values at each point. + + Raises + ------ + ValueError : if the dimensionality of the input points is different than + the dimensionality of the KDE. 
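A minimal sketch of constructing and evaluating the estimator as documented above, assuming the vendored module path `seaborn.external.kde` that this diff introduces:

```python
import numpy as np
from seaborn.external.kde import gaussian_kde

data = np.random.default_rng(0).normal(size=200)

kde = gaussian_kde(data)                        # Scott's rule by default
assert np.isclose(kde.factor, 200 ** (-1 / 5))  # n**(-1./(d+4)) with d=1

grid = np.linspace(-3, 3, 101)
density = kde(grid)                             # __call__ aliases evaluate
assert density.shape == (101,)
```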
+ + """ + points = atleast_2d(asarray(points)) + + d, m = points.shape + if d != self.d: + if d == 1 and m == self.d: + # points was passed in as a row vector + points = reshape(points, (self.d, 1)) + m = 1 + else: + msg = "points have dimension %s, dataset has dimension %s" % (d, + self.d) + raise ValueError(msg) + + output_dtype = np.common_type(self.covariance, points) + result = zeros((m,), dtype=output_dtype) + + whitening = linalg.cholesky(self.inv_cov) + scaled_dataset = dot(whitening, self.dataset) + scaled_points = dot(whitening, points) + + if m >= self.n: + # there are more points than data, so loop over data + for i in range(self.n): + diff = scaled_dataset[:, i, newaxis] - scaled_points + energy = sum(diff * diff, axis=0) / 2.0 + result += self.weights[i]*exp(-energy) + else: + # loop over points + for i in range(m): + diff = scaled_dataset - scaled_points[:, i, newaxis] + energy = sum(diff * diff, axis=0) / 2.0 + result[i] = sum(exp(-energy)*self.weights, axis=0) + + result = result / self._norm_factor + + return result + + __call__ = evaluate + + def scotts_factor(self): + """Compute Scott's factor. + + Returns + ------- + s : float + Scott's factor. + """ + return power(self.neff, -1./(self.d+4)) + + def silverman_factor(self): + """Compute the Silverman factor. + + Returns + ------- + s : float + The silverman factor. + """ + return power(self.neff*(self.d+2.0)/4.0, -1./(self.d+4)) + + # Default method to calculate bandwidth, can be overwritten by subclass + covariance_factor = scotts_factor + covariance_factor.__doc__ = """Computes the coefficient (`kde.factor`) that + multiplies the data covariance matrix to obtain the kernel covariance + matrix. The default is `scotts_factor`. A subclass can overwrite this + method to provide a different method, or set it through a call to + `kde.set_bandwidth`.""" + + def set_bandwidth(self, bw_method=None): + """Compute the estimator bandwidth with given method. + + The new bandwidth calculated after a call to `set_bandwidth` is used + for subsequent evaluations of the estimated density. + + Parameters + ---------- + bw_method : str, scalar or callable, optional + The method used to calculate the estimator bandwidth. This can be + 'scott', 'silverman', a scalar constant or a callable. If a + scalar, this will be used directly as `kde.factor`. If a callable, + it should take a `gaussian_kde` instance as only parameter and + return a scalar. If None (default), nothing happens; the current + `kde.covariance_factor` method is kept. + + Notes + ----- + .. versionadded:: 0.11 + + """ + if bw_method is None: + pass + elif bw_method == 'scott': + self.covariance_factor = self.scotts_factor + elif bw_method == 'silverman': + self.covariance_factor = self.silverman_factor + elif np.isscalar(bw_method) and not isinstance(bw_method, str): + self._bw_method = 'use constant' + self.covariance_factor = lambda: bw_method + elif callable(bw_method): + self._bw_method = bw_method + self.covariance_factor = lambda: self._bw_method(self) + else: + msg = "`bw_method` should be 'scott', 'silverman', a scalar " \ + "or a callable." + raise ValueError(msg) + + self._compute_covariance() + + def _compute_covariance(self): + """Computes the covariance matrix for each Gaussian kernel using + covariance_factor(). 
+ """ + self.factor = self.covariance_factor() + # Cache covariance and inverse covariance of the data + if not hasattr(self, '_data_inv_cov'): + self._data_covariance = atleast_2d(cov(self.dataset, rowvar=1, + bias=False, + aweights=self.weights)) + self._data_inv_cov = linalg.inv(self._data_covariance) + + self.covariance = self._data_covariance * self.factor**2 + self.inv_cov = self._data_inv_cov / self.factor**2 + self._norm_factor = sqrt(linalg.det(2*pi*self.covariance)) + + def pdf(self, x): + """ + Evaluate the estimated pdf on a provided set of points. + + Notes + ----- + This is an alias for `gaussian_kde.evaluate`. See the ``evaluate`` + docstring for more details. + + """ + return self.evaluate(x) + + @property + def weights(self): + try: + return self._weights + except AttributeError: + self._weights = ones(self.n)/self.n + return self._weights + + @property + def neff(self): + try: + return self._neff + except AttributeError: + self._neff = 1/sum(self.weights**2) + return self._neff diff --git a/seaborn/external/six.py b/seaborn/external/six.py deleted file mode 100644 index c374474da1..0000000000 --- a/seaborn/external/six.py +++ /dev/null @@ -1,869 +0,0 @@ -"""Utilities for writing code that runs on Python 2 and 3""" - -# Copyright (c) 2010-2015 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.10.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). 
- class X(object): - - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - del X - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. - try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. - delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." + fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. 
- - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - get_source = get_code # same as get_code - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", "thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - 
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [ - MovedModule("winreg", "_winreg"), - ] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." + attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes - -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - "moves.urllib_parse", "moves.urllib.parse") - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - 
MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), -] -for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") - - -class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", 
"urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes - -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] - -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - def advance_iterator(it): - return it.next() -next = advance_iterator - - -try: - callable = callable -except NameError: - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - - Iterator = object -else: - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - - def next(self): - return type(self).__next__(self) - - callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = 
operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") - - -if PY3: - def b(s): - return s.encode("latin-1") - - def u(s): - return s - unichr = chr - import struct - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - StringIO = io.StringIO - BytesIO = io.BytesIO - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" -else: - def b(s): - return s - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - -else: - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, _locs_""") - - exec_("""def reraise(tp, value, tb=None): - raise tp, value, tb -""") - - -if sys.version_info[:2] == (3, 2): - exec_("""def raise_from(value, from_value): - if from_value is None: - raise value - raise value from from_value -""") -elif sys.version_info[:2] > (3, 2): - exec_("""def raise_from(value, from_value): - raise value from from_value -""") -else: - def raise_from(value, from_value): - 
raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - def wrapper(f): - f = functools.wraps(wrapped, assigned, updated)(f) - f.__wrapped__ = wrapped - return f - return wrapper -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(meta): - - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - return type.__new__(metaclass, 'temporary_class', (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - return metaclass(cls.__name__, cls.__bases__, orig_vars) - return wrapper - - -def python_2_unicode_compatible(klass): - """ - A decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') - return klass - - -# Complete the moves implementation. 
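Of the vendored helpers being deleted above, the metaclass shims (`with_metaclass`, `add_metaclass`) are the least self-explanatory. A minimal usage sketch, assuming the standalone `six` package rather than the vendored copy; `Meta`, `A`, and `B` are hypothetical names for illustration:

```python
# Sketch of six's metaclass helpers; requires the standalone `six` package.
import six


class Meta(type):
    def __new__(mcls, name, bases, namespace):
        # Stamp every class created through this metaclass
        namespace.setdefault("tag", name.lower())
        return super(Meta, mcls).__new__(mcls, name, bases, namespace)


# Equivalent to `class A(object, metaclass=Meta)` on Python 3
class A(six.with_metaclass(Meta, object)):
    pass


# Decorator form: the class is rebuilt through Meta after definition
@six.add_metaclass(Meta)
class B(object):
    pass


assert A.tag == "a" and B.tag == "b"
```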
-# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) - diff --git a/seaborn/linearmodels.py b/seaborn/linearmodels.py deleted file mode 100644 index ad5e039f57..0000000000 --- a/seaborn/linearmodels.py +++ /dev/null @@ -1,7 +0,0 @@ -import warnings -from .regression import * # noqa - -msg = ( - "The `linearmodels` module has been renamed `regression`." -) -warnings.warn(msg) diff --git a/seaborn/matrix.py b/seaborn/matrix.py index 25addbef5e..a2ac8eef1c 100644 --- a/seaborn/matrix.py +++ b/seaborn/matrix.py @@ -1,6 +1,5 @@ """Functions to visualize matrices of data.""" -from __future__ import division -import itertools +import warnings import matplotlib as mpl from matplotlib.collections import LineCollection @@ -8,14 +7,22 @@ from matplotlib import gridspec import numpy as np import pandas as pd -from scipy.cluster import hierarchy +try: + from scipy.cluster import hierarchy + _no_scipy = False +except ImportError: + _no_scipy = True from . import cm from .axisgrid import Grid -from .utils import (despine, axis_ticklabels_overlap, relative_luminance, - to_utf8) - -from .external.six import string_types +from .utils import ( + despine, + axis_ticklabels_overlap, + relative_luminance, + to_utf8, + _draw_figure, +) +from ._decorators import _deprecate_positional_args __all__ = ["heatmap", "clustermap"] @@ -39,26 +46,19 @@ def _index_to_ticklabels(index): def _convert_colors(colors): """Convert either a list of colors or nested lists of colors to RGB.""" - to_rgb = mpl.colors.colorConverter.to_rgb - - if isinstance(colors, pd.DataFrame): - # Convert dataframe - return pd.DataFrame({col: colors[col].map(to_rgb) - for col in colors}) - elif isinstance(colors, pd.Series): - return colors.map(to_rgb) - else: - try: - to_rgb(colors[0]) - # If this works, there is only one level of colors - return list(map(to_rgb, colors)) - except ValueError: - # If we get here, we have nested lists - return [list(map(to_rgb, l)) for l in colors] + to_rgb = mpl.colors.to_rgb + + try: + to_rgb(colors[0]) + # If this works, there is only one level of colors + return list(map(to_rgb, colors)) + except ValueError: + # If we get here, we have nested lists + return [list(map(to_rgb, l)) for l in colors] def _matrix_mask(data, mask): - """Ensure that data and mask are compatabile and add missing values. + """Ensure that data and mask are compatible and add missing values. Values will be plotted for cells where ``mask`` is ``False``. 
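The matrix.py hunk above turns scipy into a soft dependency: the import is attempted once at module scope, and only the features that actually need scipy raise at call time. A minimal sketch of the pattern (the `dendrogram` stub is illustrative, not the real function body):

```python
# Minimal sketch of the optional-dependency pattern used in matrix.py.
try:
    from scipy.cluster import hierarchy
    _no_scipy = False
except ImportError:
    _no_scipy = True


def dendrogram(data, **kws):
    # Fail loudly only when the scipy-backed feature is actually invoked
    if _no_scipy:
        raise RuntimeError("dendrogram requires scipy to be installed")
    return hierarchy.linkage(data)  # stand-in for the real clustering work
```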
@@ -67,7 +67,7 @@ def _matrix_mask(data, mask): """ if mask is None: - mask = np.zeros(data.shape, np.bool) + mask = np.zeros(data.shape, bool) if isinstance(mask, np.ndarray): # For array masks, ensure that shape matches data then convert @@ -77,7 +77,7 @@ def _matrix_mask(data, mask): mask = pd.DataFrame(mask, index=data.index, columns=data.columns, - dtype=np.bool) + dtype=bool) elif isinstance(mask, pd.DataFrame): # For DataFrame masks, ensure that semantic labels match data @@ -94,7 +94,7 @@ def _matrix_mask(data, mask): return mask -class _HeatMapper(object): +class _HeatMapper: """Draw a heatmap plot of a matrix with nice labels and colormaps.""" def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, @@ -133,13 +133,10 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, elif yticklabels is False: yticklabels = [] - # Get the positions and used label for the ticks - nx, ny = data.T.shape - if not len(xticklabels): self.xticks = [] self.xticklabels = [] - elif isinstance(xticklabels, string_types) and xticklabels == "auto": + elif isinstance(xticklabels, str) and xticklabels == "auto": self.xticks = "auto" self.xticklabels = _index_to_ticklabels(data.columns) else: @@ -149,7 +146,7 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, if not len(yticklabels): self.yticks = [] self.yticklabels = [] - elif isinstance(yticklabels, string_types) and yticklabels == "auto": + elif isinstance(yticklabels, str) and yticklabels == "auto": self.yticks = "auto" self.yticklabels = _index_to_ticklabels(data.index) else: @@ -167,22 +164,17 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, cmap, center, robust) # Sort out the annotations - if annot is None: + if annot is None or annot is False: annot = False annot_data = None - elif isinstance(annot, bool): - if annot: + else: + if isinstance(annot, bool): annot_data = plot_data else: - annot_data = None - else: - try: - annot_data = annot.values - except AttributeError: - annot_data = annot - if annot.shape != plot_data.shape: - raise ValueError('Data supplied to "annot" must be the same ' - 'shape as the data to plot.') + annot_data = np.asarray(annot) + if annot_data.shape != plot_data.shape: + err = "`data` and `annot` must have same shape." 
+ raise ValueError(err) annot = True # Save other attributes to the object @@ -193,19 +185,26 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, self.annot_data = annot_data self.fmt = fmt - self.annot_kws = {} if annot_kws is None else annot_kws + self.annot_kws = {} if annot_kws is None else annot_kws.copy() self.cbar = cbar - self.cbar_kws = {} if cbar_kws is None else cbar_kws - self.cbar_kws.setdefault('ticks', mpl.ticker.MaxNLocator(6)) + self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy() def _determine_cmap_params(self, plot_data, vmin, vmax, cmap, center, robust): """Use some heuristics to set good defaults for colorbar and range.""" - calc_data = plot_data.data[~np.isnan(plot_data.data)] + + # plot_data is a np.ma.array instance + calc_data = plot_data.astype(float).filled(np.nan) if vmin is None: - vmin = np.percentile(calc_data, 2) if robust else calc_data.min() + if robust: + vmin = np.nanpercentile(calc_data, 2) + else: + vmin = np.nanmin(calc_data) if vmax is None: - vmax = np.percentile(calc_data, 98) if robust else calc_data.max() + if robust: + vmax = np.nanpercentile(calc_data, 98) + else: + vmax = np.nanmax(calc_data) self.vmin, self.vmax = vmin, vmax # Choose default colormaps if not provided @@ -214,7 +213,7 @@ def _determine_cmap_params(self, plot_data, vmin, vmax, self.cmap = cm.rocket else: self.cmap = cm.icefire - elif isinstance(cmap, string_types): + elif isinstance(cmap, str): self.cmap = mpl.cm.get_cmap(cmap) elif isinstance(cmap, list): self.cmap = mpl.colors.ListedColormap(cmap) @@ -223,11 +222,29 @@ def _determine_cmap_params(self, plot_data, vmin, vmax, # Recenter a divergent colormap if center is not None: + + # Copy bad values + # in mpl<3.2 only masked values are honored with "bad" color spec + # (see https://github.com/matplotlib/matplotlib/pull/14257) + bad = self.cmap(np.ma.masked_invalid([np.nan]))[0] + + # under/over values are set for sure when cmap extremes + # do not map to the same color as +-inf + under = self.cmap(-np.inf) + over = self.cmap(np.inf) + under_set = under != self.cmap(0) + over_set = over != self.cmap(self.cmap.N - 1) + vrange = max(vmax - center, center - vmin) normlize = mpl.colors.Normalize(center - vrange, center + vrange) cmin, cmax = normlize([vmin, vmax]) cc = np.linspace(cmin, cmax, 256) self.cmap = mpl.colors.ListedColormap(self.cmap(cc)) + self.cmap.set_bad(bad) + if under_set: + self.cmap.set_under(under) + if over_set: + self.cmap.set_over(over) def _annotate_heatmap(self, ax, mesh): """Add textual labels with the value in each cell.""" @@ -279,9 +296,14 @@ def plot(self, ax, cax, kws): # Remove all the Axes spines despine(ax=ax, left=True, bottom=True) + # setting vmin/vmax in addition to norm is deprecated + # so avoid setting if norm is set + if "norm" not in kws: + kws.setdefault("vmin", self.vmin) + kws.setdefault("vmax", self.vmax) + # Draw the heatmap - mesh = ax.pcolormesh(self.plot_data, vmin=self.vmin, vmax=self.vmax, - cmap=self.cmap, **kws) + mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws) # Set the axis limits ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0])) @@ -299,12 +321,12 @@ def plot(self, ax, cax, kws): cb.solids.set_rasterized(True) # Add row and column labels - if isinstance(self.xticks, string_types) and self.xticks == "auto": + if isinstance(self.xticks, str) and self.xticks == "auto": xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0) else: xticks, xticklabels = self.xticks, self.xticklabels - if isinstance(self.yticks, 
string_types) and self.yticks == "auto": + if isinstance(self.yticks, str) and self.yticks == "auto": yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1) else: yticks, yticklabels = self.yticks, self.yticklabels @@ -314,8 +336,8 @@ def plot(self, ax, cax, kws): ytl = ax.set_yticklabels(yticklabels, rotation="vertical") # Possibly rotate them if they overlap - if hasattr(ax.figure.canvas, "get_renderer"): - ax.figure.draw(ax.figure.canvas.get_renderer()) + _draw_figure(ax.figure) + if axis_ticklabels_overlap(xtl): plt.setp(xtl, rotation="vertical") if axis_ticklabels_overlap(ytl): @@ -329,12 +351,17 @@ def plot(self, ax, cax, kws): self._annotate_heatmap(ax, mesh) -def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, - annot=None, fmt=".2g", annot_kws=None, - linewidths=0, linecolor="white", - cbar=True, cbar_kws=None, cbar_ax=None, - square=False, xticklabels="auto", yticklabels="auto", - mask=None, ax=None, **kwargs): +@_deprecate_positional_args +def heatmap( + data, *, + vmin=None, vmax=None, cmap=None, center=None, robust=False, + annot=None, fmt=".2g", annot_kws=None, + linewidths=0, linecolor="white", + cbar=True, cbar_kws=None, cbar_ax=None, + square=False, xticklabels="auto", yticklabels="auto", + mask=None, ax=None, + **kwargs +): """Plot rectangular data as a color-encoded matrix. This is an Axes-level function and will draw the heatmap into the @@ -364,23 +391,24 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot : bool or rectangular dataset, optional If True, write the data value in each cell. If an array-like with the same shape as ``data``, then use this to annotate the heatmap instead - of the raw data. - fmt : string, optional + of the data. Note that DataFrames will match on position, not index. + fmt : str, optional String formatting code to use when adding annotations. annot_kws : dict of key, value mappings, optional - Keyword arguments for ``ax.text`` when ``annot`` is True. + Keyword arguments for :meth:`matplotlib.axes.Axes.text` when ``annot`` + is True. linewidths : float, optional Width of the lines that will divide each cell. linecolor : color, optional Color of the lines that will divide each cell. - cbar : boolean, optional + cbar : bool, optional Whether to draw a colorbar. cbar_kws : dict of key, value mappings, optional - Keyword arguments for `fig.colorbar`. + Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`. cbar_ax : matplotlib Axes, optional Axes in which to draw the colorbar, otherwise take space from the main Axes. - square : boolean, optional + square : bool, optional If True, set the Axes aspect to "equal" so each cell will be square-shaped. xticklabels, yticklabels : "auto", bool, list-like, or int, optional @@ -388,21 +416,22 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, the column names. If list-like, plot these alternate labels as the xticklabels. If an integer, use the column names but plot only every n label. If "auto", try to densely plot non-overlapping labels. - mask : boolean array or DataFrame, optional + mask : bool array or DataFrame, optional If passed, data will not be shown in cells where ``mask`` is True. Cells with missing values are automatically masked. ax : matplotlib Axes, optional Axes in which to draw the plot, otherwise use the currently-active Axes. kwargs : other keyword arguments - All other keyword arguments are passed to ``ax.pcolormesh``. 
+ All other keyword arguments are passed to + :meth:`matplotlib.axes.Axes.pcolormesh`. Returns ------- ax : matplotlib Axes Axes object with the heatmap. - See also + See Also -------- clustermap : Plot a matrix using hierarchical clustering to arrange the rows and columns. @@ -416,7 +445,7 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, :context: close-figs >>> import numpy as np; np.random.seed(0) - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> uniform_data = np.random.rand(10, 12) >>> ax = sns.heatmap(uniform_data) @@ -470,7 +499,7 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, .. plot:: :context: close-figs - >>> ax = sns.heatmap(flights, center=flights.loc["January", 1955]) + >>> ax = sns.heatmap(flights, center=flights.loc["Jan", 1955]) Plot every other column label and don't plot row labels: @@ -507,9 +536,8 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, >>> mask = np.zeros_like(corr) >>> mask[np.triu_indices_from(mask)] = True >>> with sns.axes_style("white"): + ... f, ax = plt.subplots(figsize=(7, 5)) ... ax = sns.heatmap(corr, mask=mask, vmax=.3, square=True) - - """ # Initialize the plotter object plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt, @@ -596,9 +624,6 @@ def __init__(self, data, linkage, metric, method, axis, label, rotate): self.independent_coord = self.dendrogram['icoord'] def _calculate_linkage_scipy(self): - if np.product(self.shape) >= 10000: - UserWarning('This will be slow... (gentle suggestion: ' - '"pip install fastcluster")') linkage = hierarchy.linkage(self.array, method=self.method, metric=self.metric) return linkage @@ -622,10 +647,16 @@ def _calculate_linkage_fastcluster(self): @property def calculated_linkage(self): + try: return self._calculate_linkage_fastcluster() except ImportError: - return self._calculate_linkage_scipy() + if np.product(self.shape) >= 10000: + msg = ("Clustering large matrix with scipy. 
Installing " + "`fastcluster` may give better performance.") + warnings.warn(msg) + + return self._calculate_linkage_scipy() def calculate_dendrogram(self): """Calculates a dendrogram based on the linkage matrix @@ -648,7 +679,7 @@ def reordered_ind(self): """Indices of the matrix, reordered by the dendrogram""" return self.dendrogram['leaves'] - def plot(self, ax): + def plot(self, ax, tree_kws): """Plots a dendrogram of the similarities between data on the axes Parameters @@ -657,17 +688,16 @@ def plot(self, ax): Axes object upon which the dendrogram is plotted """ - line_kwargs = dict(linewidths=.5, colors='k') + tree_kws = {} if tree_kws is None else tree_kws.copy() + tree_kws.setdefault("linewidths", .5) + tree_kws.setdefault("colors", tree_kws.pop("color", (.2, .2, .2))) + if self.rotate and self.axis == 0: - lines = LineCollection([list(zip(x, y)) - for x, y in zip(self.dependent_coord, - self.independent_coord)], - **line_kwargs) + coords = zip(self.dependent_coord, self.independent_coord) else: - lines = LineCollection([list(zip(x, y)) - for x, y in zip(self.independent_coord, - self.dependent_coord)], - **line_kwargs) + coords = zip(self.independent_coord, self.dependent_coord) + lines = LineCollection([list(zip(x, y)) for x, y in coords], + **tree_kws) ax.add_collection(lines) number_of_leaves = len(self.reordered_ind) @@ -697,8 +727,8 @@ def plot(self, ax): ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical') # Force a draw of the plot to avoid matplotlib window error - if hasattr(ax.figure.canvas, "get_renderer"): - ax.figure.draw(ax.figure.canvas.get_renderer()) + _draw_figure(ax.figure) + if len(ytl) > 0 and axis_ticklabels_overlap(ytl): plt.setp(ytl, rotation="horizontal") if len(xtl) > 0 and axis_ticklabels_overlap(xtl): @@ -706,8 +736,12 @@ def plot(self, ax): return self -def dendrogram(data, linkage=None, axis=1, label=True, metric='euclidean', - method='average', rotate=False, ax=None): +@_deprecate_positional_args +def dendrogram( + data, *, + linkage=None, axis=1, label=True, metric='euclidean', + method='average', rotate=False, tree_kws=None, ax=None +): """Draw a tree diagram of relationships within a matrix Parameters @@ -728,6 +762,9 @@ def dendrogram(data, linkage=None, axis=1, label=True, metric='euclidean', rotate : bool, optional When plotting the matrix, whether to rotate it 90 degrees counter-clockwise, so the leaves face right + tree_kws : dict, optional + Keyword arguments for the ``matplotlib.collections.LineCollection`` + that is used for plotting the lines of the dendrogram tree. 
ax : matplotlib axis, optional Axis to plot on, otherwise uses current axis @@ -742,18 +779,26 @@ def dendrogram(data, linkage=None, axis=1, label=True, metric='euclidean', dendrogramplotter.reordered_ind """ + if _no_scipy: + raise RuntimeError("dendrogram requires scipy to be installed") + plotter = _DendrogramPlotter(data, linkage=linkage, axis=axis, metric=metric, method=method, label=label, rotate=rotate) if ax is None: ax = plt.gca() - return plotter.plot(ax=ax) + + return plotter.plot(ax=ax, tree_kws=tree_kws) class ClusterGrid(Grid): + def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None, - figsize=None, row_colors=None, col_colors=None, mask=None): + figsize=None, row_colors=None, col_colors=None, mask=None, + dendrogram_ratio=None, colors_ratio=None, cbar_pos=None): """Grid object for organizing clustered heatmap input on to axes""" + if _no_scipy: + raise RuntimeError("ClusterGrid requires scipy to be available") if isinstance(data, pd.DataFrame): self.data = data @@ -765,9 +810,6 @@ def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None, self.mask = _matrix_mask(self.data2d, mask) - if figsize is None: - width, height = 10, 10 - figsize = (width, height) self.fig = plt.figure(figsize=figsize) self.row_colors, self.row_color_labels = \ @@ -775,22 +817,32 @@ def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None, self.col_colors, self.col_color_labels = \ self._preprocess_colors(data, col_colors, axis=1) - width_ratios = self.dim_ratios(self.row_colors, - figsize=figsize, - axis=1) + try: + row_dendrogram_ratio, col_dendrogram_ratio = dendrogram_ratio + except TypeError: + row_dendrogram_ratio = col_dendrogram_ratio = dendrogram_ratio + try: + row_colors_ratio, col_colors_ratio = colors_ratio + except TypeError: + row_colors_ratio = col_colors_ratio = colors_ratio + + width_ratios = self.dim_ratios(self.row_colors, + row_dendrogram_ratio, + row_colors_ratio) height_ratios = self.dim_ratios(self.col_colors, - figsize=figsize, - axis=0) - nrows = 3 if self.col_colors is None else 4 - ncols = 3 if self.row_colors is None else 4 + col_dendrogram_ratio, + col_colors_ratio) - self.gs = gridspec.GridSpec(nrows, ncols, wspace=0.01, hspace=0.01, + nrows = 2 if self.col_colors is None else 3 + ncols = 2 if self.row_colors is None else 3 + + self.gs = gridspec.GridSpec(nrows, ncols, width_ratios=width_ratios, height_ratios=height_ratios) - self.ax_row_dendrogram = self.fig.add_subplot(self.gs[nrows - 1, 0:2]) - self.ax_col_dendrogram = self.fig.add_subplot(self.gs[0:2, ncols - 1]) + self.ax_row_dendrogram = self.fig.add_subplot(self.gs[-1, 0]) + self.ax_col_dendrogram = self.fig.add_subplot(self.gs[0, -1]) self.ax_row_dendrogram.set_axis_off() self.ax_col_dendrogram.set_axis_off() @@ -799,15 +851,20 @@ def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None, if self.row_colors is not None: self.ax_row_colors = self.fig.add_subplot( - self.gs[nrows - 1, ncols - 2]) + self.gs[-1, 1]) if self.col_colors is not None: self.ax_col_colors = self.fig.add_subplot( - self.gs[nrows - 2, ncols - 1]) - - self.ax_heatmap = self.fig.add_subplot(self.gs[nrows - 1, ncols - 1]) + self.gs[1, -1]) - # colorbar for scale to left corner - self.cax = self.fig.add_subplot(self.gs[0, 0]) + self.ax_heatmap = self.fig.add_subplot(self.gs[-1, -1]) + if cbar_pos is None: + self.ax_cbar = self.cax = None + else: + # Initialize the colorbar axes in the gridspec so that tight_layout + # works. We will move it where it belongs later. 
This is a hack. + self.ax_cbar = self.fig.add_subplot(self.gs[0, 0]) + self.cax = self.ax_cbar # Backwards compatibility + self.cbar_pos = cbar_pos self.dendrogram_row = None self.dendrogram_col = None @@ -818,15 +875,26 @@ def _preprocess_colors(self, data, colors, axis): if colors is not None: if isinstance(colors, (pd.DataFrame, pd.Series)): + + # If data is unindexed, raise + if (not hasattr(data, "index") and axis == 0) or ( + not hasattr(data, "columns") and axis == 1 + ): + axis_name = "col" if axis else "row" + msg = (f"{axis_name}_colors indices can't be matched with data " + f"indices. Provide {axis_name}_colors as a non-indexed " + "datatype, e.g. by using `.to_numpy()`") + raise TypeError(msg) + # Ensure colors match data indices if axis == 0: colors = colors.reindex(data.index) else: colors = colors.reindex(data.columns) - # Replace na's with background color + # Replace na's with white color # TODO We should set these to transparent instead - colors = colors.fillna('white') + colors = colors.astype(object).fillna('white') # Extract color values and labels from frame/series if isinstance(colors, pd.DataFrame): @@ -904,9 +972,6 @@ def standard_scale(data2d, axis=1): axis : int Which axis to normalize across. If 0, normalize across rows, if 1, normalize across columns. - vmin : int - If 0, then subtract the minimum of the data before dividing by - the range. Returns ------- @@ -930,29 +995,21 @@ else: return standardized.T - def dim_ratios(self, side_colors, axis, figsize, side_colors_ratio=0.05): - """Get the proportions of the figure taken up by each axes - """ - figdim = figsize[axis] - # Get resizing proportion of this figure for the dendrogram and - # colorbar, so only the heatmap gets bigger but the dendrogram stays - # the same size. - dendrogram = min(2. / figdim, .2) - - # add the colorbar - colorbar_width = .8 * dendrogram - colorbar_height = .2 * dendrogram - if axis == 0: - ratios = [colorbar_width, colorbar_height] - else: - ratios = [colorbar_height, colorbar_width] + def dim_ratios(self, colors, dendrogram_ratio, colors_ratio): + """Get the proportions of the figure taken up by each axes.""" + ratios = [dendrogram_ratio] + + if colors is not None: + # Colors are encoded as rgb, so there is an extra dimension + if np.ndim(colors) > 2: + n_colors = len(colors) + else: + n_colors = 1 - if side_colors is not None: - # Add room for the colors - ratios += [side_colors_ratio] + ratios += [n_colors * colors_ratio] # Add the ratio for the heatmap itself - ratios += [.8] + ratios.append(1 - sum(ratios)) return ratios @@ -976,35 +1033,36 @@ def color_list_to_matrix_and_cmap(colors, ind, axis=0): Returns ------- matrix : numpy.array - A numpy array of integer values, where each corresponds to a color - from the originally provided list of colors + A numpy array of integer values, where each indexes into the cmap cmap : matplotlib.colors.ListedColormap """ - # check for nested lists/color palettes. 
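The reworked `dim_ratios` above derives the gridspec proportions from the new ratio arguments rather than from the figure size. A standalone sketch of the arithmetic, extracted and simplified for illustration:

```python
# Sketch of the new dim_ratios logic: one slot for the dendrogram,
# one per level of side colors, and the remainder for the heatmap.
import numpy as np


def dim_ratios(colors, dendrogram_ratio=0.2, colors_ratio=0.03):
    ratios = [dendrogram_ratio]
    if colors is not None:
        # RGB-encoded color lists carry a trailing channel dimension
        n_colors = len(colors) if np.ndim(colors) > 2 else 1
        ratios.append(n_colors * colors_ratio)
    ratios.append(1 - sum(ratios))  # the heatmap takes what is left
    return ratios


print(dim_ratios(None))                      # -> [0.2, 0.8]
print(dim_ratios([[(1., 0., 0.)] * 4] * 2))  # -> [0.2, 0.06, 0.74]
```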
- # Will fail if matplotlib color is list not tuple - if any(issubclass(type(x), list) for x in colors): - all_colors = set(itertools.chain(*colors)) - n = len(colors) - m = len(colors[0]) + try: + mpl.colors.to_rgb(colors[0]) + except ValueError: + # We have a 2D color structure + m, n = len(colors), len(colors[0]) + if not all(len(c) == n for c in colors[1:]): + raise ValueError("Multiple side color vectors must have same size") else: - all_colors = set(colors) - n = 1 - m = len(colors) + # We have one vector of colors + m, n = 1, len(colors) colors = [colors] - color_to_value = dict((col, i) for i, col in enumerate(all_colors)) - matrix = np.array([color_to_value[c] - for color in colors for c in color]) + # Map from unique colors to colormap index value + unique_colors = {} + matrix = np.zeros((m, n), int) + for i, inner in enumerate(colors): + for j, color in enumerate(inner): + idx = unique_colors.setdefault(color, len(unique_colors)) + matrix[i, j] = idx - shape = (n, m) - matrix = matrix.reshape(shape) + # Reorder for clustering and transpose for axis matrix = matrix[:, ind] if axis == 0: - # row-side: matrix = matrix.T - cmap = mpl.colors.ListedColormap(all_colors) + cmap = mpl.colors.ListedColormap(list(unique_colors)) return matrix, cmap def savefig(self, *args, **kwargs): @@ -1013,12 +1071,14 @@ def savefig(self, *args, **kwargs): self.fig.savefig(*args, **kwargs) def plot_dendrograms(self, row_cluster, col_cluster, metric, method, - row_linkage, col_linkage): + row_linkage, col_linkage, tree_kws): # Plot the row dendrogram if row_cluster: self.dendrogram_row = dendrogram( self.data2d, metric=metric, method=method, label=False, axis=0, - ax=self.ax_row_dendrogram, rotate=True, linkage=row_linkage) + ax=self.ax_row_dendrogram, rotate=True, linkage=row_linkage, + tree_kws=tree_kws + ) else: self.ax_row_dendrogram.set_xticks([]) self.ax_row_dendrogram.set_yticks([]) @@ -1026,7 +1086,9 @@ def plot_dendrograms(self, row_cluster, col_cluster, metric, method, if col_cluster: self.dendrogram_col = dendrogram( self.data2d, metric=metric, method=method, label=False, - axis=1, ax=self.ax_col_dendrogram, linkage=col_linkage) + axis=1, ax=self.ax_col_dendrogram, linkage=col_linkage, + tree_kws=tree_kws + ) else: self.ax_col_dendrogram.set_xticks([]) self.ax_col_dendrogram.set_yticks([]) @@ -1114,9 +1176,26 @@ def plot_matrix(self, colorbar_kws, xind, yind, **kws): except (TypeError, IndexError): pass - heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.cax, + # Reorganize the annotations to match the heatmap + annot = kws.pop("annot", None) + if annot is None or annot is False: + pass + else: + if isinstance(annot, bool): + annot_data = self.data2d + else: + annot_data = np.asarray(annot) + if annot_data.shape != self.data2d.shape: + err = "`data` and `annot` must have same shape." 
+ raise ValueError(err) + annot_data = annot_data[yind][:, xind] + annot = annot_data + + # Setting ax_cbar=None in clustermap call implies no colorbar + kws.setdefault("cbar", self.ax_cbar is not None) + heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.ax_cbar, cbar_kws=colorbar_kws, mask=self.mask, - xticklabels=xtl, yticklabels=ytl, **kws) + xticklabels=xtl, yticklabels=ytl, annot=annot, **kws) ytl = self.ax_heatmap.get_yticklabels() ytl_rot = None if not ytl else ytl[0].get_rotation() @@ -1126,11 +1205,33 @@ def plot_matrix(self, colorbar_kws, xind, yind, **kws): ytl = self.ax_heatmap.get_yticklabels() plt.setp(ytl, rotation=ytl_rot) + tight_params = dict(h_pad=.02, w_pad=.02) + if self.ax_cbar is None: + self.fig.tight_layout(**tight_params) + else: + # Turn the colorbar axes off for tight layout so that its + # ticks don't interfere with the rest of the plot layout. + # Then move it. + self.ax_cbar.set_axis_off() + self.fig.tight_layout(**tight_params) + self.ax_cbar.set_axis_on() + self.ax_cbar.set_position(self.cbar_pos) + def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster, - row_linkage, col_linkage, **kws): + row_linkage, col_linkage, tree_kws, **kws): + + # heatmap square=True sets the aspect ratio on the axes, but that is + # not compatible with the multi-axes layout of clustergrid + if kws.get("square", False): + msg = "``square=True`` ignored in clustermap" + warnings.warn(msg) + kws.pop("square") + colorbar_kws = {} if colorbar_kws is None else colorbar_kws + self.plot_dendrograms(row_cluster, col_cluster, metric, method, - row_linkage=row_linkage, col_linkage=col_linkage) + row_linkage=row_linkage, col_linkage=col_linkage, + tree_kws=tree_kws) try: xind = self.dendrogram_col.reordered_ind except AttributeError: @@ -1145,31 +1246,40 @@ def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster, return self -def clustermap(data, pivot_kws=None, method='average', metric='euclidean', - z_score=None, standard_scale=None, figsize=None, cbar_kws=None, - row_cluster=True, col_cluster=True, - row_linkage=None, col_linkage=None, - row_colors=None, col_colors=None, mask=None, **kwargs): - """Plot a matrix dataset as a hierarchically-clustered heatmap. +@_deprecate_positional_args +def clustermap( + data, *, + pivot_kws=None, method='average', metric='euclidean', + z_score=None, standard_scale=None, figsize=(10, 10), + cbar_kws=None, row_cluster=True, col_cluster=True, + row_linkage=None, col_linkage=None, + row_colors=None, col_colors=None, mask=None, + dendrogram_ratio=.2, colors_ratio=0.03, + cbar_pos=(.02, .8, .05, .18), tree_kws=None, + **kwargs +): + """ + Plot a matrix dataset as a hierarchically-clustered heatmap. + + This function requires scipy to be available. Parameters ---------- - data: 2D array-like + data : 2D array-like Rectangular data for clustering. Cannot contain NAs. pivot_kws : dict, optional If `data` is a tidy dataframe, can provide keyword arguments for pivot to create a rectangular dataframe. method : str, optional - Linkage method to use for calculating clusters. - See scipy.cluster.hierarchy.linkage documentation for more information: - https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html + Linkage method to use for calculating clusters. See + :func:`scipy.cluster.hierarchy.linkage` documentation for more + information. metric : str, optional Distance metric to use for the data. 
See - scipy.spatial.distance.pdist documentation for more options - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html + :func:`scipy.spatial.distance.pdist` documentation for more options. To use different metrics (or methods) for rows and columns, you may construct each linkage matrix yourself and provide them as - {row,col}_linkage. + `{row,col}_linkage`. z_score : int or None, optional Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores for the rows or the columns. Z scores are: z = (x - mean)/std, so @@ -1180,35 +1290,48 @@ def clustermap(data, pivot_kws=None, method='average', metric='euclidean', Either 0 (rows) or 1 (columns). Whether or not to standardize that dimension, meaning for each row or column, subtract the minimum and divide each by its maximum. - figsize: tuple of two ints, optional - Size of the figure to create. + figsize : tuple of (width, height), optional + Overall size of the figure. cbar_kws : dict, optional - Keyword arguments to pass to ``cbar_kws`` in ``heatmap``, e.g. to + Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to add a label to the colorbar. {row,col}_cluster : bool, optional - If True, cluster the {rows, columns}. - {row,col}_linkage : numpy.array, optional + If ``True``, cluster the {rows, columns}. + {row,col}_linkage : :class:`numpy.ndarray`, optional Precomputed linkage matrix for the rows or columns. See - scipy.cluster.hierarchy.linkage for specific formats. + :func:`scipy.cluster.hierarchy.linkage` for specific formats. {row,col}_colors : list-like or pandas DataFrame/Series, optional - List of colors to label for either the rows or columns. Useful to - evaluate whether samples within a group are clustered together. Can - use nested lists or DataFrame for multiple color levels of labeling. - If given as a DataFrame or Series, labels for the colors are extracted - from the DataFrames column names or from the name of the Series. - DataFrame/Series colors are also matched to the data by their - index, ensuring colors are drawn in the correct order. - mask : boolean array or DataFrame, optional - If passed, data will not be shown in cells where ``mask`` is True. + List of colors to label for either the rows or columns. Useful to evaluate + whether samples within a group are clustered together. Can use nested lists or + DataFrame for multiple color levels of labeling. If given as a + :class:`pandas.DataFrame` or :class:`pandas.Series`, labels for the colors are + extracted from the DataFrames column names or from the name of the Series. + DataFrame/Series colors are also matched to the data by their index, ensuring + colors are drawn in the correct order. + mask : bool array or DataFrame, optional + If passed, data will not be shown in cells where `mask` is True. Cells with missing values are automatically masked. Only used for visualizing, not for calculating. + {dendrogram,colors}_ratio : float, or pair of floats, optional + Proportion of the figure size devoted to the two marginal elements. If + a pair is given, they correspond to (row, col) ratios. + cbar_pos : tuple of (left, bottom, width, height), optional + Position of the colorbar axes in the figure. Setting to ``None`` will + disable the colorbar. + tree_kws : dict, optional + Parameters for the :class:`matplotlib.collections.LineCollection` + that is used to plot the lines of the dendrogram tree. 
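The layout parameters documented above replace the old fixed gridspec (the remaining parameter entries continue below). A hypothetical call exercising them, assuming a seaborn build with this patch and the bundled `iris` dataset:

```python
# Hypothetical call exercising the new layout parameters.
import seaborn as sns

iris = sns.load_dataset("iris")
iris.pop("species")
g = sns.clustermap(
    iris,
    figsize=(7, 5),
    dendrogram_ratio=(.1, .2),     # (row, col) proportions of the figure
    colors_ratio=0.02,             # thickness per side-color level
    cbar_pos=(.02, .8, .03, .18),  # (left, bottom, width, height); None hides it
)
```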
kwargs : other keyword arguments - All other keyword arguments are passed to ``sns.heatmap`` + All other keyword arguments are passed to :func:`heatmap`. Returns ------- - clustergrid : ClusterGrid - A ClusterGrid instance. + :class:`ClusterGrid` + A :class:`ClusterGrid` instance. + + See Also + -------- + heatmap : Plot rectangular data as a color-encoded matrix. Notes ----- @@ -1229,54 +1352,51 @@ def clustermap(data, pivot_kws=None, method='average', metric='euclidean', .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set(color_codes=True) + >>> import seaborn as sns; sns.set_theme(color_codes=True) >>> iris = sns.load_dataset("iris") >>> species = iris.pop("species") >>> g = sns.clustermap(iris) - Use a different similarity metric: + Change the size and layout of the figure: .. plot:: :context: close-figs - >>> g = sns.clustermap(iris, metric="correlation") + >>> g = sns.clustermap(iris, + ... figsize=(7, 5), + ... row_cluster=False, + ... dendrogram_ratio=(.1, .2), + ... cbar_pos=(0, .2, .03, .4)) - Use a different clustering method: + Add colored labels to identify observations: .. plot:: :context: close-figs - >>> g = sns.clustermap(iris, method="single") - - Use a different colormap and ignore outliers in colormap limits: - - .. plot:: - :context: close-figs - - >>> g = sns.clustermap(iris, cmap="mako", robust=True) + >>> lut = dict(zip(species.unique(), "rbg")) + >>> row_colors = species.map(lut) + >>> g = sns.clustermap(iris, row_colors=row_colors) - Change the size of the figure: + Use a different colormap and adjust the limits of the color range: .. plot:: :context: close-figs - >>> g = sns.clustermap(iris, figsize=(6, 7)) + >>> g = sns.clustermap(iris, cmap="mako", vmin=0, vmax=10) - Plot one of the axes in its original organization: + Use a different similarity metric: .. plot:: :context: close-figs - >>> g = sns.clustermap(iris, col_cluster=False) + >>> g = sns.clustermap(iris, metric="correlation") - Add colored labels: + Use a different clustering method: .. plot:: :context: close-figs - >>> lut = dict(zip(species.unique(), "rbg")) - >>> row_colors = species.map(lut) - >>> g = sns.clustermap(iris, row_colors=row_colors) + >>> g = sns.clustermap(iris, method="single") Standardize the data within the columns: @@ -1290,17 +1410,19 @@ def clustermap(data, pivot_kws=None, method='average', metric='euclidean', .. 
plot:: :context: close-figs - >>> g = sns.clustermap(iris, z_score=0) - - + >>> g = sns.clustermap(iris, z_score=0, cmap="vlag") """ + if _no_scipy: + raise RuntimeError("clustermap requires scipy to be available") + plotter = ClusterGrid(data, pivot_kws=pivot_kws, figsize=figsize, row_colors=row_colors, col_colors=col_colors, z_score=z_score, standard_scale=standard_scale, - mask=mask) + mask=mask, dendrogram_ratio=dendrogram_ratio, + colors_ratio=colors_ratio, cbar_pos=cbar_pos) return plotter.plot(metric=metric, method=method, colorbar_kws=cbar_kws, row_cluster=row_cluster, col_cluster=col_cluster, row_linkage=row_linkage, col_linkage=col_linkage, - **kwargs) + tree_kws=tree_kws, **kwargs) diff --git a/seaborn/miscplot.py b/seaborn/miscplot.py index 0e62296682..717c0ac40b 100644 --- a/seaborn/miscplot.py +++ b/seaborn/miscplot.py @@ -1,8 +1,7 @@ -from __future__ import division import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt - +import matplotlib.ticker as ticker __all__ = ["palplot", "dogplot"] @@ -25,11 +24,13 @@ def palplot(pal, size=1): interpolation="nearest", aspect="auto") ax.set_xticks(np.arange(n) - .5) ax.set_yticks([-.5, .5]) - ax.set_xticklabels([]) - ax.set_yticklabels([]) + # Ensure nice border between colors + ax.set_xticklabels(["" for _ in range(n)]) + # The proper way to set no ticks + ax.yaxis.set_major_locator(ticker.NullLocator()) -def dogplot(): +def dogplot(*_, **__): """Who's a good boy?""" try: from urllib.request import urlopen @@ -37,8 +38,9 @@ def dogplot(): from urllib2 import urlopen from io import BytesIO - url = "https://github.com/mwaskom/seaborn-data/raw/master/png/img1.png" - data = BytesIO(urlopen(url).read()) + url = "https://github.com/mwaskom/seaborn-data/raw/master/png/img{}.png" + pic = np.random.randint(2, 7) + data = BytesIO(urlopen(url.format(pic)).read()) img = plt.imread(data) f, ax = plt.subplots(figsize=(5, 5), dpi=100) f.subplots_adjust(0, 0, 1, 1) diff --git a/seaborn/palettes.py b/seaborn/palettes.py index 4b2ce7b43a..b33280e27d 100644 --- a/seaborn/palettes.py +++ b/seaborn/palettes.py @@ -1,4 +1,3 @@ -from __future__ import division import colorsys from itertools import cycle @@ -6,10 +5,8 @@ import matplotlib as mpl from .external import husl -from .external.six import string_types -from .external.six.moves import range -from .utils import desaturate, set_hls_values, get_color_cycle +from .utils import desaturate, get_color_cycle from .colors import xkcd_rgb, crayons @@ -44,7 +41,7 @@ "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"], colorblind6=["#0173B2", "#029E73", "#D55E00", "#CC78BC", "#ECE133", "#56B4E9"] - ) +) MPL_QUAL_PALS = { @@ -57,6 +54,7 @@ QUAL_PALETTE_SIZES = MPL_QUAL_PALS.copy() QUAL_PALETTE_SIZES.update({k: len(v) for k, v in SEABORN_PALETTES.items()}) +QUAL_PALETTES = list(QUAL_PALETTE_SIZES.keys()) class _ColorPalette(list): @@ -78,25 +76,34 @@ def as_hex(self): hex = [mpl.colors.rgb2hex(rgb) for rgb in self] return _ColorPalette(hex) - -def color_palette(palette=None, n_colors=None, desat=None): - """Return a list of colors defining a color palette. 
-
-    Available seaborn palette names:
-        deep, muted, bright, pastel, dark, colorblind
-
-    Other options:
-        name of matplotlib cmap, 'ch:<cubehelix arguments>', 'hls', 'husl',
-        or a list of colors in any format matplotlib accepts
+    def _repr_html_(self):
+        """Rich display of the color palette in an HTML frontend."""
+        s = 55
+        n = len(self)
+        html = f'<svg width="{n * s}" height="{s}">'
+        for i, c in enumerate(self.as_hex()):
+            html += (
+                f'<rect x="{i * s}" y="0" width="{s}" height="{s}" style="fill:{c};'
+                'stroke-width:2;stroke:rgb(255,255,255)"/>'
+            )
+        html += '</svg>'
+        return html
+
+
+def color_palette(palette=None, n_colors=None, desat=None, as_cmap=False):
+    """Return a list of colors or continuous colormap defining a palette.
+
+    Possible ``palette`` values include:
+        - Name of a seaborn palette (deep, muted, bright, pastel, dark, colorblind)
+        - Name of matplotlib colormap
+        - 'husl' or 'hls'
+        - 'ch:<cubehelix arguments>'
+        - 'light:<color>', 'dark:<color>', 'blend:<color>,<color>',
+        - A sequence of colors in any format matplotlib accepts

    Calling this function with ``palette=None`` will return the current
    matplotlib color cycle.

-    Matplotlib palettes can be specified as reversed palettes by appending
-    "_r" to the name or as "dark" palettes by appending "_d" to the name.
-    (These options are mutually exclusive, but the resulting list of colors
-    can also be reversed).
-
    This function can also be used in a ``with`` statement to temporarily
    set the color cycle for a plot or set of plots.

@@ -104,7 +111,7 @@ def color_palette(palette=None, n_colors=None, desat=None):

    Parameters
    ----------
-    palette: None, string, or sequence, optional
+    palette : None, string, or sequence, optional
        Name of palette or None to return current palette. If a sequence, input
        colors are used but possibly cycled and desaturated.
    n_colors : int, optional
@@ -112,16 +119,16 @@
        on how ``palette`` is specified. Named palettes default to 6 colors,
        but grabbing the current palette or passing in a list of colors will
        not change the number of colors unless this is specified. Asking for
-        more colors than exist in the palette will cause it to cycle.
+        more colors than exist in the palette will cause it to cycle. Ignored
+        when ``as_cmap`` is True.
    desat : float, optional
        Proportion to desaturate each color by.
+    as_cmap : bool
+        If True, return a :class:`matplotlib.colors.Colormap`.

    Returns
    -------
-    palette : list of RGB tuples.
-        Color palette. Behaves like a list, but can be used as a context
-        manager and possesses an ``as_hex`` method to convert to hex color
-        codes.
+    list of RGB tuples or :class:`matplotlib.colors.Colormap`

    See Also
    --------
@@ -132,62 +139,7 @@

    Examples
    --------

-    Calling with no arguments returns all colors from the current default
-    color cycle:
-
-    .. plot::
-        :context: close-figs
-
-        >>> import seaborn as sns; sns.set()
-        >>> sns.palplot(sns.color_palette())
-
-    Show one of the other "seaborn palettes", which have the same basic order
-    of hues as the default matplotlib color cycle but more attractive colors.
-    Calling with the name of a palette will return 6 colors by default:
-
-    .. plot::
-        :context: close-figs
-
-        >>> sns.palplot(sns.color_palette("muted"))
-
-    Use discrete values from one of the built-in matplotlib colormaps:
-
-    .. plot::
-        :context: close-figs
-
-        >>> sns.palplot(sns.color_palette("RdBu", n_colors=7))
-
-    Make a customized cubehelix color palette:
-
-    .. plot::
-        :context: close-figs
-
-        >>> sns.palplot(sns.color_palette("ch:2.5,-.2,dark=.3"))
-
-    Use a categorical matplotlib palette and add some desaturation:
-
-    .. plot::
-        :context: close-figs
-
-        >>> sns.palplot(sns.color_palette("Set1", n_colors=8, desat=.5))
-
-    Make a "dark" matplotlib sequential palette variant. (This can be good
-    when coloring multiple lines or points that correspond to an ordered
-    variable, where you don't want the lightest lines to be invisible):
-
-    .. plot::
-        :context: close-figs
-
-        >>> sns.palplot(sns.color_palette("Blues_d"))
-
-    Use as a context manager:
-
-    .. plot::
-        :context: close-figs
-
-        >>> import numpy as np, matplotlib.pyplot as plt
-        >>> with sns.color_palette("husl", 8):
-        ...     _ = plt.plot(np.c_[np.zeros(8), np.arange(8)].T)
+    .. include:: ../docstrings/color_palette.rst

    """
    if palette is None:
@@ -195,7 +147,7 @@ def color_palette(palette=None, n_colors=None, desat=None):
        if n_colors is None:
            n_colors = len(palette)

-    elif not isinstance(palette, string_types):
+    elif not isinstance(palette, str):
        palette = palette
        if n_colors is None:
            n_colors = len(palette)
@@ -206,16 +158,16 @@ def color_palette(palette=None, n_colors=None, desat=None):
            n_colors = QUAL_PALETTE_SIZES.get(palette, 6)

        if palette in SEABORN_PALETTES:
-            # Named "seaborn variant" of old matplotlib default palette
+            # Named "seaborn variant" of matplotlib default color cycle
            palette = SEABORN_PALETTES[palette]

        elif palette == "hls":
            # Evenly spaced colors in cylindrical RGB space
-            palette = hls_palette(n_colors)
+            palette = hls_palette(n_colors, as_cmap=as_cmap)

        elif palette == "husl":
            # Evenly spaced colors in cylindrical Lab space
-            palette = husl_palette(n_colors)
+            palette = husl_palette(n_colors, as_cmap=as_cmap)

        elif palette.lower() == "jet":
            # Paternalism
@@ -224,33 +176,57 @@
        elif palette.startswith("ch:"):
            # Cubehelix palette with params specified in string
            args, kwargs = _parse_cubehelix_args(palette)
-            palette = cubehelix_palette(n_colors, *args, **kwargs)
+            palette = cubehelix_palette(n_colors, *args, **kwargs, as_cmap=as_cmap)
+
+        elif palette.startswith("light:"):
+            # light palette to color specified in string
+            _, color = palette.split(":")
+            reverse = color.endswith("_r")
+            if reverse:
+                color = color[:-2]
+            palette = light_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)
+
+        elif palette.startswith("dark:"):
+            # dark palette to color specified in string
+            _, color = palette.split(":")
+            reverse = color.endswith("_r")
+            if reverse:
+                color = color[:-2]
+            palette = dark_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)
+
+        elif palette.startswith("blend:"):
+            # blend palette between colors specified in string
+            _, colors = palette.split(":")
+            colors = colors.split(",")
+            palette = blend_palette(colors, n_colors, as_cmap=as_cmap)

        else:
            try:
                # Perhaps a named matplotlib colormap?
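+                # (mpl_palette also resolves "_r" reversed and "_d" dark-variant names)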
- palette = mpl_palette(palette, n_colors) + palette = mpl_palette(palette, n_colors, as_cmap=as_cmap) except ValueError: raise ValueError("%s is not a valid palette name" % palette) if desat is not None: palette = [desaturate(c, desat) for c in palette] - # Always return as many colors as we asked for - pal_cycle = cycle(palette) - palette = [next(pal_cycle) for _ in range(n_colors)] + if not as_cmap: - # Always return in r, g, b tuple format - try: - palette = map(mpl.colors.colorConverter.to_rgb, palette) - palette = _ColorPalette(palette) - except ValueError: - raise ValueError("Could not generate a palette for %s" % str(palette)) + # Always return as many colors as we asked for + pal_cycle = cycle(palette) + palette = [next(pal_cycle) for _ in range(n_colors)] + + # Always return in r, g, b tuple format + try: + palette = map(mpl.colors.colorConverter.to_rgb, palette) + palette = _ColorPalette(palette) + except ValueError: + raise ValueError(f"Could not generate a palette for {palette}") return palette -def hls_palette(n_colors=6, h=.01, l=.6, s=.65): # noqa +def hls_palette(n_colors=6, h=.01, l=.6, s=.65, as_cmap=False): # noqa """Get a set of evenly spaced colors in HLS hue space. h, l, and s should be between 0 and 1 @@ -269,13 +245,11 @@ def hls_palette(n_colors=6, h=.01, l=.6, s=.65): # noqa Returns ------- - palette : seaborn color palette - List-like object of colors as RGB tuples. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- - husl_palette : Make a palette using evently spaced circular hues in the - HUSL system. + husl_palette : Make a palette using evenly spaced hues in the HUSL system. Examples -------- @@ -285,7 +259,7 @@ def hls_palette(n_colors=6, h=.01, l=.6, s=.65): # noqa .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.hls_palette(10)) Create a palette of 10 colors that begins at a different hue value: @@ -310,15 +284,20 @@ def hls_palette(n_colors=6, h=.01, l=.6, s=.65): # noqa >>> sns.palplot(sns.hls_palette(10, s=.4)) """ - hues = np.linspace(0, 1, n_colors + 1)[:-1] + if as_cmap: + n_colors = 256 + hues = np.linspace(0, 1, int(n_colors) + 1)[:-1] hues += h hues %= 1 hues -= hues.astype(int) palette = [colorsys.hls_to_rgb(h_i, l, s) for h_i in hues] - return _ColorPalette(palette) + if as_cmap: + return mpl.colors.ListedColormap(palette, "hls") + else: + return _ColorPalette(palette) -def husl_palette(n_colors=6, h=.01, s=.9, l=.65): # noqa +def husl_palette(n_colors=6, h=.01, s=.9, l=.65, as_cmap=False): # noqa """Get a set of evenly spaced colors in HUSL hue space. h, s, and l should be between 0 and 1 @@ -337,8 +316,7 @@ def husl_palette(n_colors=6, h=.01, s=.9, l=.65): # noqa Returns ------- - palette : seaborn color palette - List-like object of colors as RGB tuples. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- @@ -353,7 +331,7 @@ def husl_palette(n_colors=6, h=.01, s=.9, l=.65): # noqa .. 
plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.husl_palette(10)) Create a palette of 10 colors that begins at a different hue value: @@ -378,17 +356,22 @@ def husl_palette(n_colors=6, h=.01, s=.9, l=.65): # noqa >>> sns.palplot(sns.husl_palette(10, s=.4)) """ - hues = np.linspace(0, 1, n_colors + 1)[:-1] + if as_cmap: + n_colors = 256 + hues = np.linspace(0, 1, int(n_colors) + 1)[:-1] hues += h hues %= 1 hues *= 359 s *= 99 l *= 99 # noqa - palette = [husl.husl_to_rgb(h_i, s, l) for h_i in hues] - return _ColorPalette(palette) + palette = [_color_to_rgb((h_i, s, l), input="husl") for h_i in hues] + if as_cmap: + return mpl.colors.ListedColormap(palette, "hsl") + else: + return _ColorPalette(palette) -def mpl_palette(name, n_colors=6): +def mpl_palette(name, n_colors=6, as_cmap=False): """Return discrete colors from a matplotlib palette. Note that this handles the qualitative colorbrewer palettes @@ -410,10 +393,7 @@ def mpl_palette(name, n_colors=6): Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. + list of RGB tuples or :class:`matplotlib.colors.Colormap` Examples -------- @@ -423,7 +403,7 @@ def mpl_palette(name, n_colors=6): .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.mpl_palette("Set2", 8)) Create a sequential colorbrewer palette: @@ -449,20 +429,29 @@ def mpl_palette(name, n_colors=6): """ if name.endswith("_d"): - pal = ["#333333"] - pal.extend(color_palette(name.replace("_d", "_r"), 2)) + sub_name = name[:-2] + if sub_name.endswith("_r"): + reverse = True + sub_name = sub_name[:-2] + else: + reverse = False + pal = color_palette(sub_name, 2) + ["#333333"] + if reverse: + pal = pal[::-1] cmap = blend_palette(pal, n_colors, as_cmap=True) else: cmap = mpl.cm.get_cmap(name) - if cmap is None: - raise ValueError("{} is not a valid colormap".format(name)) + if name in MPL_QUAL_PALS: bins = np.linspace(0, 1, MPL_QUAL_PALS[name])[:n_colors] else: - bins = np.linspace(0, 1, n_colors + 2)[1:-1] + bins = np.linspace(0, 1, int(n_colors) + 2)[1:-1] palette = list(map(tuple, cmap(bins)[:, :3])) - return _ColorPalette(palette) + if as_cmap: + return cmap + else: + return _ColorPalette(palette) def _color_to_rgb(color, input): @@ -471,9 +460,11 @@ def _color_to_rgb(color, input): color = colorsys.hls_to_rgb(*color) elif input == "husl": color = husl.husl_to_rgb(*color) + color = tuple(np.clip(color, 0, 1)) elif input == "xkcd": color = xkcd_rgb[color] - return color + + return mpl.colors.to_rgb(color) def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"): @@ -499,17 +490,14 @@ def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"): reverse : bool, optional if True, reverse the direction of the blend as_cmap : bool, optional - if True, return as a matplotlib colormap instead of list + If True, return a :class:`matplotlib.colors.Colormap`. input : {'rgb', 'hls', 'husl', xkcd'} Color space to interpret the input color. The first three options apply to tuple inputs and the latter applies to string inputs. 
Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- @@ -524,7 +512,7 @@ def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"): .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.dark_palette("purple")) Generate a palette that decreases in lightness: @@ -552,14 +540,15 @@ def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"): >>> ax = sns.heatmap(x, cmap=cmap) """ - color = _color_to_rgb(color, input) - gray = "#222222" - colors = [color, gray] if reverse else [gray, color] + rgb = _color_to_rgb(color, input) + h, s, l = husl.rgb_to_husl(*rgb) + gray_s, gray_l = .15 * s, 15 + gray = _color_to_rgb((h, gray_s, gray_l), input="husl") + colors = [rgb, gray] if reverse else [gray, rgb] return blend_palette(colors, n_colors, as_cmap) -def light_palette(color, n_colors=6, reverse=False, as_cmap=False, - input="rgb"): +def light_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"): """Make a sequential palette that blends from light to ``color``. This kind of palette is good for data that range between relatively @@ -582,17 +571,14 @@ def light_palette(color, n_colors=6, reverse=False, as_cmap=False, reverse : bool, optional if True, reverse the direction of the blend as_cmap : bool, optional - if True, return as a matplotlib colormap instead of list + If True, return a :class:`matplotlib.colors.Colormap`. input : {'rgb', 'hls', 'husl', xkcd'} Color space to interpret the input color. The first three options apply to tuple inputs and the latter applies to string inputs. Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- @@ -607,7 +593,7 @@ def light_palette(color, n_colors=6, reverse=False, as_cmap=False, .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.light_palette("purple")) Generate a palette that increases in lightness: @@ -635,40 +621,15 @@ def light_palette(color, n_colors=6, reverse=False, as_cmap=False, >>> ax = sns.heatmap(x, cmap=cmap) """ - color = _color_to_rgb(color, input) - light = set_hls_values(color, l=.95) # noqa - colors = [color, light] if reverse else [light, color] - return blend_palette(colors, n_colors, as_cmap) - - -def _flat_palette(color, n_colors=6, reverse=False, as_cmap=False, - input="rgb"): - """Make a sequential palette that blends from gray to ``color``. - - Parameters - ---------- - color : matplotlib color - hex, rgb-tuple, or html color name - n_colors : int, optional - number of colors in the palette - reverse : bool, optional - if True, reverse the direction of the blend - as_cmap : bool, optional - if True, return as a matplotlib colormap instead of list - - Returns - ------- - palette : list or colormap - dark_palette : Create a sequential palette with dark low values. 
- - """ - color = _color_to_rgb(color, input) - flat = desaturate(color, 0) - colors = [color, flat] if reverse else [flat, color] + rgb = _color_to_rgb(color, input) + h, s, l = husl.rgb_to_husl(*rgb) + gray_s, gray_l = .15 * s, 95 + gray = _color_to_rgb((h, gray_s, gray_l), input="husl") + colors = [rgb, gray] if reverse else [gray, rgb] return blend_palette(colors, n_colors, as_cmap) -def diverging_palette(h_neg, h_pos, s=75, l=50, sep=10, n=6, # noqa +def diverging_palette(h_neg, h_pos, s=75, l=50, sep=1, n=6, # noqa center="light", as_cmap=False): """Make a diverging palette between two HUSL colors. @@ -683,20 +644,18 @@ def diverging_palette(h_neg, h_pos, s=75, l=50, sep=10, n=6, # noqa Anchor saturation for both extents of the map. l : float in [0, 100], optional Anchor lightness for both extents of the map. + sep : int, optional + Size of the intermediate region. n : int, optional Number of colors in the palette (if not returning a cmap) center : {"light", "dark"}, optional Whether the center of the palette is light or dark as_cmap : bool, optional - If true, return a matplotlib colormap object rather than a - list of colors. + If True, return a :class:`matplotlib.colors.Colormap`. Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- @@ -711,7 +670,7 @@ def diverging_palette(h_neg, h_pos, s=75, l=50, sep=10, n=6, # noqa .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.diverging_palette(240, 10, n=9)) Generate a brighter green-white-purple palette: @@ -736,17 +695,17 @@ def diverging_palette(h_neg, h_pos, s=75, l=50, sep=10, n=6, # noqa >>> from numpy import arange >>> x = arange(25).reshape(5, 5) - >>> cmap = sns.diverging_palette(220, 20, sep=20, as_cmap=True) + >>> cmap = sns.diverging_palette(220, 20, as_cmap=True) >>> ax = sns.heatmap(x, cmap=cmap) """ - palfunc = dark_palette if center == "dark" else light_palette - neg = palfunc((h_neg, s, l), 128 - (sep // 2), reverse=True, input="husl") - pos = palfunc((h_pos, s, l), 128 - (sep // 2), input="husl") - midpoint = dict(light=[(.95, .95, .95, 1.)], - dark=[(.133, .133, .133, 1.)])[center] + palfunc = dict(dark=dark_palette, light=light_palette)[center] + n_half = int(128 - (sep // 2)) + neg = palfunc((h_neg, s, l), n_half, reverse=True, input="husl") + pos = palfunc((h_pos, s, l), n_half, input="husl") + midpoint = dict(light=[(.95, .95, .95)], dark=[(.133, .133, .133)])[center] mid = midpoint * sep - pal = blend_palette(np.concatenate([neg, mid, pos]), n, as_cmap=as_cmap) + pal = blend_palette(np.concatenate([neg, mid, pos]), n, as_cmap=as_cmap) return pal @@ -760,21 +719,19 @@ def blend_palette(colors, n_colors=6, as_cmap=False, input="rgb"): n_colors : int, optional Number of colors in the palette. as_cmap : bool, optional - If True, return as a matplotlib colormap instead of list. + If True, return a :class:`matplotlib.colors.Colormap`. Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. 
+ list of RGB tuples or :class:`matplotlib.colors.Colormap` """ colors = [_color_to_rgb(color, input) for color in colors] name = "blend" pal = mpl.colors.LinearSegmentedColormap.from_list(name, colors) if not as_cmap: - pal = _ColorPalette(pal(np.linspace(0, 1, int(n_colors)))) + rgb_array = pal(np.linspace(0, 1, int(n_colors)))[:, :3] # no alpha + pal = _ColorPalette(map(tuple, rgb_array)) return pal @@ -868,14 +825,11 @@ def cubehelix_palette(n_colors=6, start=0, rot=.4, gamma=1.0, hue=0.8, reverse : bool If True, the palette will go from dark to light. as_cmap : bool - If True, return a matplotlib colormap instead of a list of colors. + If True, return a :class:`matplotlib.colors.Colormap`. Returns ------- - palette or cmap : seaborn color palette or matplotlib colormap - List-like object of colors as RGB tuples, or colormap object that - can map continuous values to colors, depending on the value of the - ``as_cmap`` parameter. + list of RGB tuples or :class:`matplotlib.colors.Colormap` See Also -------- @@ -898,7 +852,7 @@ def cubehelix_palette(n_colors=6, start=0, rot=.4, gamma=1.0, hue=0.8, .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.palplot(sns.cubehelix_palette()) Rotate backwards from the same starting location: @@ -965,14 +919,14 @@ def color(x): return color cdict = { - "red": get_color_function(-0.14861, 1.78277), - "green": get_color_function(-0.29227, -0.90649), - "blue": get_color_function(1.97294, 0.0), + "red": get_color_function(-0.14861, 1.78277), + "green": get_color_function(-0.29227, -0.90649), + "blue": get_color_function(1.97294, 0.0), } cmap = mpl.colors.LinearSegmentedColormap("cubehelix", cdict) - x = np.linspace(light, dark, n_colors) + x = np.linspace(light, dark, int(n_colors)) pal = cmap(x)[:, :3].tolist() if reverse: pal = pal[::-1] @@ -1050,7 +1004,7 @@ def set_color_codes(palette="deep"): :context: close-figs >>> import matplotlib.pyplot as plt - >>> import seaborn as sns; sns.set() + >>> import seaborn as sns; sns.set_theme() >>> sns.set_color_codes() >>> _ = plt.plot([0, 1], color="r") @@ -1065,9 +1019,9 @@ def set_color_codes(palette="deep"): """ if palette == "reset": - colors = [(0., 0., 1.), (0., .5, 0.), (1., 0., 0.), (.75, .75, 0.), + colors = [(0., 0., 1.), (0., .5, 0.), (1., 0., 0.), (.75, 0., .75), (.75, .75, 0.), (0., .75, .75), (0., 0., 0.)] - elif not isinstance(palette, string_types): + elif not isinstance(palette, str): err = "set_color_codes requires a named seaborn palette" raise TypeError(err) elif palette in SEABORN_PALETTES: diff --git a/seaborn/rcmod.py b/seaborn/rcmod.py index 45cea6ca32..08ff8e876d 100644 --- a/seaborn/rcmod.py +++ b/seaborn/rcmod.py @@ -1,16 +1,12 @@ """Control plot style and scaling using the matplotlib rcParams interface.""" -from distutils.version import LooseVersion +import warnings import functools import matplotlib as mpl -import warnings -from . import palettes, _orig_rc_params - - -mpl_ge_150 = LooseVersion(mpl.__version__) >= '1.5.0' -mpl_ge_2 = LooseVersion(mpl.__version__) >= '2.0' +from cycler import cycler +from . 
import palettes -__all__ = ["set", "reset_defaults", "reset_orig", +__all__ = ["set_theme", "set", "reset_defaults", "reset_orig", "axes_style", "set_style", "plotting_context", "set_context", "set_palette"] @@ -37,30 +33,23 @@ "lines.solid_capstyle", "patch.edgecolor", + "patch.force_edgecolor", "image.cmap", "font.family", "font.sans-serif", - ] - -if mpl_ge_2: - - _style_keys.extend([ - - "patch.force_edgecolor", - - "xtick.bottom", - "xtick.top", - "ytick.left", - "ytick.right", + "xtick.bottom", + "xtick.top", + "ytick.left", + "ytick.right", - "axes.spines.left", - "axes.spines.bottom", - "axes.spines.right", - "axes.spines.top", + "axes.spines.left", + "axes.spines.bottom", + "axes.spines.right", + "axes.spines.top", - ]) +] _context_keys = [ @@ -70,6 +59,7 @@ "xtick.labelsize", "ytick.labelsize", "legend.fontsize", + "legend.title_fontsize", "axes.linewidth", "grid.linewidth", @@ -87,12 +77,12 @@ "xtick.minor.size", "ytick.minor.size", - ] +] -def set(context="notebook", style="darkgrid", palette="deep", - font="sans-serif", font_scale=1, color_codes=True, rc=None): - """Set aesthetic parameters in one step. +def set_theme(context="notebook", style="darkgrid", palette="deep", + font="sans-serif", font_scale=1, color_codes=True, rc=None): + """Set multiple theme parameters in one step. Each set of parameters can be set directly or temporarily, see the referenced functions below for more information. @@ -100,11 +90,11 @@ def set(context="notebook", style="darkgrid", palette="deep", Parameters ---------- context : string or dict - Plotting context parameters, see :func:`plotting_context` + Plotting context parameters, see :func:`plotting_context`. style : string or dict - Axes style parameters, see :func:`axes_style` + Axes style parameters, see :func:`axes_style`. palette : string or sequence - Color palette, see :func:`color_palette` + Color palette, see :func:`color_palette`. font : string Font family, see matplotlib font manager. font_scale : float, optional @@ -124,6 +114,11 @@ def set(context="notebook", style="darkgrid", palette="deep", mpl.rcParams.update(rc) +def set(*args, **kwargs): + """Alias for :func:`set_theme`, which is the preferred interface.""" + set_theme(*args, **kwargs) + + def reset_defaults(): """Restore all RC params to default settings.""" mpl.rcParams.update(mpl.rcParamsDefault) @@ -131,6 +126,7 @@ def reset_defaults(): def reset_orig(): """Restore all RC params to original settings (respects custom rc).""" + from . 
import _orig_rc_params with warnings.catch_warnings(): warnings.simplefilter('ignore', mpl.cbook.MatplotlibDeprecationWarning) mpl.rcParams.update(_orig_rc_params) @@ -217,17 +213,17 @@ def axes_style(style=None, rc=None): "xtick.top": False, "ytick.right": False, - } + } # Set grid on or off if "grid" in style: style_dict.update({ "axes.grid": True, - }) + }) else: style_dict.update({ "axes.grid": False, - }) + }) # Set the color of the background, spines, and grids if style.startswith("dark"): @@ -242,7 +238,7 @@ def axes_style(style=None, rc=None): "axes.spines.right": True, "axes.spines.top": True, - }) + }) elif style == "whitegrid": style_dict.update({ @@ -256,7 +252,7 @@ def axes_style(style=None, rc=None): "axes.spines.right": True, "axes.spines.top": True, - }) + }) elif style in ["white", "ticks"]: style_dict.update({ @@ -270,19 +266,19 @@ def axes_style(style=None, rc=None): "axes.spines.right": True, "axes.spines.top": True, - }) + }) # Show or hide the axes ticks if style == "ticks": style_dict.update({ "xtick.bottom": True, "ytick.left": True, - }) + }) else: style_dict.update({ "xtick.bottom": False, "ytick.left": False, - }) + }) # Remove entries that are not defined in the base list of valid keys # This lets us handle matplotlib <=/> 2.0 @@ -389,7 +385,7 @@ def plotting_context(context=None, font_scale=1, rc=None): raise ValueError("context must be in %s" % ", ".join(contexts)) # Set up dictionary of default parameters - base_context = { + texts_base_context = { "font.size": 12, "axes.labelsize": 12, @@ -397,6 +393,11 @@ def plotting_context(context=None, font_scale=1, rc=None): "xtick.labelsize": 11, "ytick.labelsize": 11, "legend.fontsize": 11, + "legend.title_fontsize": 12, + + } + + base_context = { "axes.linewidth": 1.25, "grid.linewidth": 1, @@ -414,15 +415,15 @@ def plotting_context(context=None, font_scale=1, rc=None): "xtick.minor.size": 4, "ytick.minor.size": 4, - } + } + base_context.update(texts_base_context) # Scale all the parameters by the same factor depending on the context scaling = dict(paper=.8, notebook=1, talk=1.5, poster=2)[context] context_dict = {k: v * scaling for k, v in base_context.items()} # Now independently scale the fonts - font_keys = ["axes.labelsize", "axes.titlesize", "legend.fontsize", - "xtick.labelsize", "ytick.labelsize", "font.size"] + font_keys = texts_base_context.keys() font_dict = {k: context_dict[k] * font_scale for k in font_keys} context_dict.update(font_dict) @@ -540,12 +541,8 @@ def set_palette(palette, n_colors=None, desat=None, color_codes=False): """ colors = palettes.color_palette(palette, n_colors, desat) - if mpl_ge_150: - from cycler import cycler - cyl = cycler('color', colors) - mpl.rcParams['axes.prop_cycle'] = cyl - else: - mpl.rcParams["axes.color_cycle"] = list(colors) + cyl = cycler('color', colors) + mpl.rcParams['axes.prop_cycle'] = cyl mpl.rcParams["patch.facecolor"] = colors[0] if color_codes: try: diff --git a/seaborn/regression.py b/seaborn/regression.py index a1fda49666..e302149bd1 100644 --- a/seaborn/regression.py +++ b/seaborn/regression.py @@ -1,11 +1,9 @@ """Plotting functions for linear models (broadly construed).""" -from __future__ import division import copy from textwrap import dedent import warnings import numpy as np import pandas as pd -from scipy.spatial import distance import matplotlib as mpl import matplotlib.pyplot as plt @@ -16,11 +14,10 @@ except ImportError: _has_statsmodels = False -from .external.six import string_types - from . import utils from . 
import algorithms as algo from .axisgrid import FacetGrid, _facet_docs +from ._decorators import _deprecate_positional_args __all__ = ["lmplot", "regplot", "residplot"] @@ -38,18 +35,24 @@ def establish_variables(self, data, **kws): self.data = data # Validate the inputs - any_strings = any([isinstance(v, string_types) for v in kws.values()]) + any_strings = any([isinstance(v, str) for v in kws.values()]) if any_strings and data is None: raise ValueError("Must pass `data` if using named variables.") # Set the variables for var, val in kws.items(): - if isinstance(val, string_types): - setattr(self, var, data[val]) + if isinstance(val, str): + vector = data[val] elif isinstance(val, list): - setattr(self, var, np.asarray(val)) + vector = np.asarray(val) else: - setattr(self, var, val) + vector = val + if vector is not None and vector.shape != (1,): + vector = np.squeeze(vector) + if np.ndim(vector) > 1: + err = "regplot inputs must be 1d" + raise ValueError(err) + setattr(self, var, vector) def dropna(self, *vars): """Remove observations with missing data.""" @@ -73,7 +76,7 @@ class _RegressionPlotter(_LinearPlotter): """ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, - units=None, order=1, logistic=False, lowess=False, + units=None, seed=None, order=1, logistic=False, lowess=False, robust=False, logx=False, x_partial=None, y_partial=None, truncate=False, dropna=True, x_jitter=None, y_jitter=None, color=None, label=None): @@ -83,6 +86,7 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, self.ci = ci self.x_ci = ci if x_ci == "ci" else x_ci self.n_boot = n_boot + self.seed = seed self.scatter = scatter self.fit_reg = fit_reg self.order = order @@ -122,8 +126,13 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, else: self.x_discrete = self.x + # Disable regression in case of singleton inputs + if len(self.x) <= 1: + self.fit_reg = False + # Save the range of the x variable for the grid later - self.x_range = self.x.min(), self.x.max() + if self.fit_reg: + self.x_range = self.x.min(), self.x.max() @property def scatter_data(self): @@ -167,8 +176,11 @@ def estimate_data(self): else: if self.units is not None: units = self.units[x == val] - boots = algo.bootstrap(_y, func=self.x_estimator, - n_boot=self.n_boot, units=units) + boots = algo.bootstrap(_y, + func=self.x_estimator, + n_boot=self.n_boot, + units=units, + seed=self.seed) _ci = utils.ci(boots, self.x_ci) cis.append(_ci) @@ -226,8 +238,11 @@ def reg_func(_x, _y): if self.ci is None: return yhat, None - beta_boots = algo.bootstrap(X, y, func=reg_func, - n_boot=self.n_boot, units=self.units).T + beta_boots = algo.bootstrap(X, y, + func=reg_func, + n_boot=self.n_boot, + units=self.units, + seed=self.seed).T yhat_boots = grid.dot(beta_boots).T return yhat, yhat_boots @@ -241,8 +256,11 @@ def reg_func(_x, _y): if self.ci is None: return yhat, None - yhat_boots = algo.bootstrap(x, y, func=reg_func, - n_boot=self.n_boot, units=self.units) + yhat_boots = algo.bootstrap(x, y, + func=reg_func, + n_boot=self.n_boot, + units=self.units, + seed=self.seed) return yhat, yhat_boots def fit_statsmodels(self, grid, model, **kwargs): @@ -263,8 +281,11 @@ def reg_func(_x, _y): if self.ci is None: return yhat, None - yhat_boots = algo.bootstrap(X, y, func=reg_func, - n_boot=self.n_boot, units=self.units) + yhat_boots = algo.bootstrap(X, y, + func=reg_func, + n_boot=self.n_boot, + units=self.units, + seed=self.seed) return yhat, yhat_boots def 
fit_lowess(self): @@ -286,24 +307,27 @@ def reg_func(_x, _y): if self.ci is None: return yhat, None - beta_boots = algo.bootstrap(X, y, func=reg_func, - n_boot=self.n_boot, units=self.units).T + beta_boots = algo.bootstrap(X, y, + func=reg_func, + n_boot=self.n_boot, + units=self.units, + seed=self.seed).T yhat_boots = grid.dot(beta_boots).T return yhat, yhat_boots def bin_predictor(self, bins): """Discretize a predictor by assigning value to closest bin.""" - x = self.x + x = np.asarray(self.x) if np.isscalar(bins): percentiles = np.linspace(0, 100, bins + 2)[1:-1] - bins = np.c_[utils.percentiles(x, percentiles)] + bins = np.percentile(x, percentiles) else: - bins = np.c_[np.ravel(bins)] + bins = np.ravel(bins) - dist = distance.cdist(np.c_[x], bins) + dist = np.abs(np.subtract.outer(x, bins)) x_binned = bins[np.argmin(dist, axis=1)].ravel() - return x_binned, bins.ravel() + return x_binned, bins def regress_out(self, a, b): """Regress b from a keeping a's original mean.""" @@ -312,7 +336,7 @@ def regress_out(self, a, b): b = b - b.mean() b = np.c_[b] a_prime = a - b.dot(np.linalg.pinv(b).dot(a)) - return (a_prime + a_mean).reshape(a.shape) + return np.asarray(a_prime + a_mean).reshape(a.shape) def plot(self, ax, scatter_kws, line_kws): """Draw the full plot.""" @@ -324,13 +348,13 @@ def plot(self, ax, scatter_kws, line_kws): # Use the current color cycle state as a default if self.color is None: - lines, = ax.plot(self.x.mean(), self.y.mean()) + lines, = ax.plot([], []) color = lines.get_color() lines.remove() else: color = self.color - # Ensure that color is hex to avoid matplotlib weidness + # Ensure that color is hex to avoid matplotlib weirdness color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color)) # Let color in keyword arguments override overall plot color @@ -340,6 +364,7 @@ def plot(self, ax, scatter_kws, line_kws): # Draw the constituent plots if self.scatter: self.scatterplot(ax, scatter_kws) + if self.fit_reg: self.lineplot(ax, line_kws) @@ -383,10 +408,9 @@ def scatterplot(self, ax, kws): def lineplot(self, ax, kws): """Draw the model.""" - xlim = ax.get_xlim() - # Fit the regression model grid, yhat, err_bands = self.fit_regression(ax) + edges = grid[0], grid[-1] # Get set default aesthetics fill_color = kws["color"] @@ -394,10 +418,10 @@ def lineplot(self, ax, kws): kws.setdefault("linewidth", lw) # Draw the regression line and confidence interval - ax.plot(grid, yhat, **kws) + line, = ax.plot(grid, yhat, **kws) + line.sticky_edges.x[:] = edges # Prevent mpl from adding margin if err_bands is not None: ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15) - ax.set_xlim(*xlim, auto=None) _regression_docs = dict( @@ -468,6 +492,10 @@ def lineplot(self, ax, kws): that resamples both units and observations (within unit). This does not otherwise influence how the regression is estimated or drawn.\ """), + seed=dedent("""\ + seed : int, numpy.random.Generator, or numpy.random.RandomState, optional + Seed or random number generator for reproducible bootstrapping.\ + """), order=dedent("""\ order : int, optional If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a @@ -508,9 +536,8 @@ def lineplot(self, ax, kws): """), truncate=dedent("""\ truncate : bool, optional - By default, the regression line is drawn to fill the x axis limits - after the scatterplot is drawn. If ``truncate`` is ``True``, it will - instead by bounded by the data limits.\ + If ``True``, the regression line is bounded by the data limits. 
If + ``False``, it extends to the ``x`` axis limits. """), xy_jitter=dedent("""\ {x,y}_jitter : floats, optional @@ -524,19 +551,25 @@ def lineplot(self, ax, kws): Additional keyword arguments to pass to ``plt.scatter`` and ``plt.plot``.\ """), - ) +) _regression_docs.update(_facet_docs) -def lmplot(x, y, data, hue=None, col=None, row=None, palette=None, - col_wrap=None, height=5, aspect=1, markers="o", sharex=True, - sharey=True, hue_order=None, col_order=None, row_order=None, - legend=True, legend_out=True, x_estimator=None, x_bins=None, - x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, - units=None, order=1, logistic=False, lowess=False, robust=False, - logx=False, x_partial=None, y_partial=None, truncate=False, - x_jitter=None, y_jitter=None, scatter_kws=None, line_kws=None, - size=None): +@_deprecate_positional_args +def lmplot( + *, + x=None, y=None, + data=None, + hue=None, col=None, row=None, # TODO move before data once * is enforced + palette=None, col_wrap=None, height=5, aspect=1, markers="o", + sharex=True, sharey=True, hue_order=None, col_order=None, row_order=None, + legend=True, legend_out=True, x_estimator=None, x_bins=None, + x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, + units=None, seed=None, order=1, logistic=False, lowess=False, + robust=False, logx=False, x_partial=None, y_partial=None, + truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None, + line_kws=None, size=None +): # Handle deprecations if size is not None: @@ -545,17 +578,22 @@ def lmplot(x, y, data, hue=None, col=None, row=None, palette=None, "please update your code.") warnings.warn(msg, UserWarning) + if data is None: + raise TypeError("Missing required keyword argument `data`.") + # Reduce the dataframe to only needed columns need_cols = [x, y, hue, col, row, units, x_partial, y_partial] cols = np.unique([a for a in need_cols if a is not None]).tolist() data = data[cols] # Initialize the grid - facets = FacetGrid(data, row, col, hue, palette=palette, - row_order=row_order, col_order=col_order, - hue_order=hue_order, height=height, aspect=aspect, - col_wrap=col_wrap, sharex=sharex, sharey=sharey, - legend_out=legend_out) + facets = FacetGrid( + data, row=row, col=col, hue=hue, + palette=palette, + row_order=row_order, col_order=col_order, hue_order=hue_order, + height=height, aspect=aspect, col_wrap=col_wrap, + sharex=sharex, sharey=sharey, legend_out=legend_out + ) # Add the markers here as FacetGrid has figured out how many levels of the # hue variable are needed and we don't want to duplicate that process @@ -581,12 +619,15 @@ def lmplot(x, y, data, hue=None, col=None, row=None, palette=None, regplot_kws = dict( x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci, scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units, - order=order, logistic=logistic, lowess=lowess, robust=robust, - logx=logx, x_partial=x_partial, y_partial=y_partial, truncate=truncate, - x_jitter=x_jitter, y_jitter=y_jitter, + seed=seed, order=order, logistic=logistic, lowess=lowess, + robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial, + truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter, scatter_kws=scatter_kws, line_kws=line_kws, - ) - facets.map_dataframe(regplot, x, y, **regplot_kws) + ) + facets.map_dataframe(regplot, x=x, y=y, **regplot_kws) + + # TODO this will need to change when we relax string requirement + facets.set_axis_labels(x, y) # Add a legend if legend and (hue is not None) and (hue not in [col, row]): @@ -645,6 +686,7 @@ def lmplot(x, y, data, 
hue=None, col=None, row=None, palette=None, {ci} {n_boot} {units} + {seed} {order} {logistic} {lowess} @@ -681,7 +723,7 @@ def lmplot(x, y, data, hue=None, col=None, row=None, palette=None, .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set(color_codes=True) + >>> import seaborn as sns; sns.set_theme(color_codes=True) >>> tips = sns.load_dataset("tips") >>> g = sns.lmplot(x="total_bill", y="tip", data=tips) @@ -766,16 +808,22 @@ def lmplot(x, y, data, hue=None, col=None, row=None, palette=None, """).format(**_regression_docs) -def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", - scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, - order=1, logistic=False, lowess=False, robust=False, - logx=False, x_partial=None, y_partial=None, - truncate=False, dropna=True, x_jitter=None, y_jitter=None, - label=None, color=None, marker="o", - scatter_kws=None, line_kws=None, ax=None): +@_deprecate_positional_args +def regplot( + *, + x=None, y=None, + data=None, + x_estimator=None, x_bins=None, x_ci="ci", + scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, + seed=None, order=1, logistic=False, lowess=False, robust=False, + logx=False, x_partial=None, y_partial=None, + truncate=True, dropna=True, x_jitter=None, y_jitter=None, + label=None, color=None, marker="o", + scatter_kws=None, line_kws=None, ax=None +): plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci, - scatter, fit_reg, ci, n_boot, units, + scatter, fit_reg, ci, n_boot, units, seed, order, logistic, lowess, robust, logx, x_partial, y_partial, truncate, dropna, x_jitter, y_jitter, color, label) @@ -810,6 +858,7 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", {ci} {n_boot} {units} + {seed} {order} {logistic} {lowess} @@ -864,7 +913,7 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", .. plot:: :context: close-figs - >>> import seaborn as sns; sns.set(color_codes=True) + >>> import seaborn as sns; sns.set_theme(color_codes=True) >>> tips = sns.load_dataset("tips") >>> ax = sns.regplot(x="total_bill", y="tip", data=tips) @@ -888,12 +937,12 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", >>> ax = sns.regplot(x=x, y=y, marker="+") Use a 68% confidence interval, which corresponds with the standard error - of the estimate: + of the estimate, and extend the regression line to the axis limits: .. plot:: :context: close-figs - >>> ax = sns.regplot(x=x, y=y, ci=68) + >>> ax = sns.regplot(x=x, y=y, ci=68, truncate=False) Plot with a discrete ``x`` variable and add some jitter: @@ -918,7 +967,7 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", >>> ax = sns.regplot(x=x, y=y, x_bins=4) - Fit a higher-order polynomial regression and truncate the model prediction: + Fit a higher-order polynomial regression: .. plot:: :context: close-figs @@ -926,7 +975,7 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", >>> ans = sns.load_dataset("anscombe") >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "II"], ... scatter_kws={{"s": 80}}, - ... order=2, ci=None, truncate=True) + ... order=2, ci=None) Fit a robust regression and don't plot a confidence interval: @@ -947,20 +996,26 @@ def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", >>> ax = sns.regplot(x="total_bill", y="big_tip", data=tips, ... logistic=True, n_boot=500, y_jitter=.03) - Fit the regression model using log(x) and truncate the model prediction: + Fit the regression model using log(x): .. 
plot:: :context: close-figs >>> ax = sns.regplot(x="size", y="total_bill", data=tips, - ... x_estimator=np.mean, logx=True, truncate=True) + ... x_estimator=np.mean, logx=True) """).format(**_regression_docs) -def residplot(x, y, data=None, lowess=False, x_partial=None, y_partial=None, - order=1, robust=False, dropna=True, label=None, color=None, - scatter_kws=None, line_kws=None, ax=None): +@_deprecate_positional_args +def residplot( + *, + x=None, y=None, + data=None, + lowess=False, x_partial=None, y_partial=None, + order=1, robust=False, dropna=True, label=None, color=None, + scatter_kws=None, line_kws=None, ax=None +): """Plot the residuals of a linear regression. This function will regress y on x (possibly as a robust or polynomial @@ -1008,8 +1063,8 @@ def residplot(x, y, data=None, lowess=False, x_partial=None, y_partial=None, See Also -------- regplot : Plot a simple linear regression model. - jointplot (with kind="resid"): Draw a residplot with univariate - marginal distrbutions. + jointplot : Draw a :func:`residplot` with univariate marginal distributions + (when used with ``kind="resid"``). """ plotter = _RegressionPlotter(x, y, data, ci=None, @@ -1034,7 +1089,7 @@ def residplot(x, y, data=None, lowess=False, x_partial=None, y_partial=None, ax.axhline(0, ls=":", c=".2") # Draw the scatterplot - scatter_kws = {} if scatter_kws is None else scatter_kws - line_kws = {} if line_kws is None else line_kws + scatter_kws = {} if scatter_kws is None else scatter_kws.copy() + line_kws = {} if line_kws is None else line_kws.copy() plotter.plot(ax, scatter_kws, line_kws) return ax diff --git a/seaborn/relational.py b/seaborn/relational.py index 77bebe30b0..c9fc15adde 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -1,561 +1,227 @@ -from __future__ import division -from itertools import product -from textwrap import dedent -from distutils.version import LooseVersion +import warnings import numpy as np -import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt -from .external.six import string_types - -from . 
import utils -from .utils import (categorical_order, get_color_cycle, ci_to_errsize, sort_df, - remove_na) -from .algorithms import bootstrap -from .palettes import color_palette, cubehelix_palette, _parse_cubehelix_args +from ._core import ( + VectorPlotter, +) +from .utils import ( + locator_to_legend_entries, + adjust_legend_subtitles, + _default_color, + _deprecate_ci, +) +from ._statistics import EstimateAggregator from .axisgrid import FacetGrid, _facet_docs +from ._decorators import _deprecate_positional_args +from ._docstrings import ( + DocstringComponents, + _core_docs, +) __all__ = ["relplot", "scatterplot", "lineplot"] -class _RelationalPlotter(object): - - if LooseVersion(mpl.__version__) >= "2.0": - default_markers = ["o", "X", "s", "P", "D", "^", "v", "p"] - else: - default_markers = ["o", "s", "D", "^", "v", "p"] - default_dashes = ["", (4, 1.5), (1, 1), - (3, 1, 1.5, 1), (5, 1, 1, 1), - (5, 1, 2, 1, 2, 1)] - - def establish_variables(self, x=None, y=None, - hue=None, size=None, style=None, - units=None, data=None): - """Parse the inputs to define data for plotting.""" - # Initialize label variables - x_label = y_label = hue_label = size_label = style_label = None - - # Option 1: - # We have a wide-form datast - # -------------------------- - - if x is None and y is None: - - self.input_format = "wide" - - # Option 1a: - # The input data is a Pandas DataFrame - # ------------------------------------ - # We will assign the index to x, the values to y, - # and the columns names to both hue and style - - # TODO accept a dict and try to coerce to a dataframe? - - if isinstance(data, pd.DataFrame): - - # Enforce numeric values - try: - data.astype(np.float) - except ValueError: - err = "A wide-form input must have only numeric values." - raise ValueError(err) - - plot_data = data.copy() - plot_data.loc[:, "x"] = data.index - plot_data = pd.melt(plot_data, "x", - var_name="hue", value_name="y") - plot_data["style"] = plot_data["hue"] - - x_label = getattr(data.index, "name", None) - hue_label = style_label = getattr(plot_data.columns, - "name", None) - - # Option 1b: - # The input data is an array or list - # ---------------------------------- - - else: - - if not len(data): - - plot_data = pd.DataFrame(columns=["x", "y"]) - - elif np.isscalar(np.asarray(data)[0]): - - # The input data is a flat list(like): - # We assign a numeric index for x and use the values for y - - x = getattr(data, "index", np.arange(len(data))) - plot_data = pd.DataFrame(dict(x=x, y=data)) - - elif hasattr(data, "shape"): - - # The input data is an array(like): - # We either use the index or assign a numeric index to x, - # the values to y, and id keys to both hue and style - - plot_data = pd.DataFrame(data) - plot_data.loc[:, "x"] = plot_data.index - plot_data = pd.melt(plot_data, "x", - var_name="hue", - value_name="y") - plot_data["style"] = plot_data["hue"] - - else: - - # The input data is a nested list: We will either use the - # index or assign a numeric index for x, use the values - # for y, and use numeric hue/style identifiers. 
- - plot_data = [] - for i, data_i in enumerate(data): - x = getattr(data_i, "index", np.arange(len(data_i))) - n = getattr(data_i, "name", i) - data_i = dict(x=x, y=data_i, hue=n, style=n, size=None) - plot_data.append(pd.DataFrame(data_i)) - plot_data = pd.concat(plot_data) - - # Option 2: - # We have long-form data - # ---------------------- - - elif x is not None and y is not None: - - self.input_format = "long" - - # Use variables as from the dataframe if specified - if data is not None: - x = data.get(x, x) - y = data.get(y, y) - hue = data.get(hue, hue) - size = data.get(size, size) - style = data.get(style, style) - units = data.get(units, units) - - # Validate the inputs - for var in [x, y, hue, size, style, units]: - if isinstance(var, string_types): - err = "Could not interpret input '{}'".format(var) - raise ValueError(err) - - # Extract variable names - x_label = getattr(x, "name", None) - y_label = getattr(y, "name", None) - hue_label = getattr(hue, "name", None) - size_label = getattr(size, "name", None) - style_label = getattr(style, "name", None) - - # Reassemble into a DataFrame - plot_data = dict( - x=x, y=y, - hue=hue, style=style, size=size, - units=units - ) - plot_data = pd.DataFrame(plot_data) - - # Option 3: - # Only one variable argument - # -------------------------- - - else: - err = ("Either both or neither of `x` and `y` must be specified " - "(but try passing to `data`, which is more flexible).") - raise ValueError(err) - - # ---- Post-processing - - # Assign default values for missing attribute variables - for attr in ["hue", "style", "size", "units"]: - if attr not in plot_data: - plot_data[attr] = None - - # Determine which semantics have (some) data - plot_valid = plot_data.notnull().any() - semantics = ["x", "y"] + [ - name for name in ["hue", "size", "style"] - if plot_valid[name] - ] - - self.x_label = x_label - self.y_label = y_label - self.hue_label = hue_label - self.size_label = size_label - self.style_label = style_label - self.plot_data = plot_data - self.semantics = semantics - - return plot_data - - def categorical_to_palette(self, data, order, palette): - """Determine colors when the hue variable is qualitative.""" - # -- Identify the order and name of the levels - - if order is None: - levels = categorical_order(data) - else: - levels = order - n_colors = len(levels) - - # -- Identify the set of colors to use - - if isinstance(palette, dict): - - missing = set(levels) - set(palette) - if any(missing): - err = "The palette dictionary is missing keys: {}" - raise ValueError(err.format(missing)) - - else: - - if palette is None: - if n_colors <= len(get_color_cycle()): - colors = color_palette(None, n_colors) - else: - colors = color_palette("husl", n_colors) - elif isinstance(palette, list): - if len(palette) != n_colors: - err = "The palette list has the wrong number of colors." - raise ValueError(err) - colors = palette - else: - colors = color_palette(palette, n_colors) - - palette = dict(zip(levels, colors)) - - return levels, palette - - def numeric_to_palette(self, data, order, palette, norm): - """Determine colors when the hue variable is quantitative.""" - levels = list(np.sort(remove_na(data.unique()))) - - # TODO do we want to do something complicated to ensure contrast - # at the extremes of the colormap against the background? 
- - # Identify the colormap to use - palette = "ch:" if palette is None else palette - if isinstance(palette, mpl.colors.Colormap): - cmap = palette - elif str(palette).startswith("ch:"): - args, kwargs = _parse_cubehelix_args(palette) - cmap = cubehelix_palette(0, *args, as_cmap=True, **kwargs) - else: - try: - cmap = mpl.cm.get_cmap(palette) - except (ValueError, TypeError): - err = "Palette {} not understood" - raise ValueError(err) - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "``hue_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) - - if not norm.scaled(): - norm(np.asarray(data.dropna())) - - # TODO this should also use color_lookup, but that needs the - # class attributes that get set after using this function... - palette = dict(zip(levels, cmap(norm(levels)))) - # palette = {l: cmap(norm([l, 1]))[0] for l in levels} - - return levels, palette, cmap, norm - - def color_lookup(self, key): - """Return the color corresponding to the hue level.""" - if self.hue_type == "numeric": - normed = self.hue_norm(key) - if np.ma.is_masked(normed): - normed = np.nan - return self.cmap(normed) - elif self.hue_type == "categorical": - return self.palette[key] - - def size_lookup(self, key): - """Return the size corresponding to the size level.""" - if self.size_type == "numeric": - min_size, max_size = self.size_range - val = self.size_norm(key) - if np.ma.is_masked(val): - return 0 - return min_size + val * (max_size - min_size) - elif self.size_type == "categorical": - return self.sizes[key] - - def style_to_attributes(self, levels, style, defaults, name): - """Convert a style argument to a dict of matplotlib attributes.""" - if style is True: - attrdict = dict(zip(levels, defaults)) - elif style and isinstance(style, dict): - attrdict = style - elif style: - attrdict = dict(zip(levels, style)) - else: - attrdict = {} - - if attrdict: - missing_levels = set(levels) - set(attrdict) - if any(missing_levels): - err = "These `style` levels are missing {}: {}" - raise ValueError(err.format(name, missing_levels)) - - return attrdict - - def subset_data(self): - """Return (x, y) data for each subset defined by semantics.""" - data = self.plot_data - all_true = pd.Series(True, data.index) - - iter_levels = product(self.hue_levels, - self.size_levels, - self.style_levels) - - for hue, size, style in iter_levels: - - hue_rows = all_true if hue is None else data["hue"] == hue - size_rows = all_true if size is None else data["size"] == size - style_rows = all_true if style is None else data["style"] == style - - rows = hue_rows & size_rows & style_rows - data["units"] = data.units.fillna("") - subset_data = data.loc[rows, ["units", "x", "y"]].dropna() - - if not len(subset_data): - continue - - if self.sort: - subset_data = sort_df(subset_data, ["units", "x", "y"]) - - if self.units is None: - subset_data = subset_data.drop("units", axis=1) - - yield (hue, size, style), subset_data - - def parse_hue(self, data, palette, order, norm): - """Determine what colors to use given data characteristics.""" - if self._empty_data(data): - - # Set default values when not using a hue mapping - levels = [None] - limits = None - norm = None - palette = {} - var_type = None - cmap = None - - else: - - # Determine what kind of hue mapping we want - var_type = self._semantic_type(data) - - # Override depending on the type of the palette argument - if 
isinstance(palette, (dict, list)): - var_type = "categorical" - - # -- Option 1: categorical color palette - - if var_type == "categorical": - - cmap = None - limits = None - levels, palette = self.categorical_to_palette( - data, order, palette - ) - - # -- Option 2: sequential color palette - - elif var_type == "numeric": - - levels, palette, cmap, norm = self.numeric_to_palette( - data, order, palette, norm - ) - limits = norm.vmin, norm.vmax - - self.hue_levels = levels - self.hue_norm = norm - self.hue_limits = limits - self.hue_type = var_type - self.palette = palette - self.cmap = cmap - - def parse_size(self, data, sizes, order, norm): - """Determine the linewidths given data characteristics.""" - - # TODO could break out two options like parse_hue does for clarity - - if self._empty_data(data): - levels = [None] - limits = None - norm = None - sizes = {} - var_type = None - width_range = None +_relational_narrative = DocstringComponents(dict( - else: - - var_type = self._semantic_type(data) - - # TODO override for list/dict like in parse_hue? - - if var_type == "categorical": - levels = categorical_order(data, order) - numbers = np.arange(1, 1 + len(levels))[::-1] - elif var_type == "numeric": - levels = numbers = np.sort(remove_na(data.unique())) - - if isinstance(sizes, (dict, list)): - - # Use literal size values - if isinstance(sizes, list): - if len(sizes) != len(levels): - err = "The `sizes` list has wrong number of levels" - raise ValueError(err) - sizes = dict(zip(levels, sizes)) - - missing = set(levels) - set(sizes) - if any(missing): - err = "Missing sizes for the following levels: {}" - raise ValueError(err.format(missing)) - - width_range = min(sizes.values()), max(sizes.values()) - try: - limits = min(sizes.keys()), max(sizes.keys()) - except TypeError: - limits = None - - else: - - # Infer the range of sizes to use - if sizes is None: - min_width, max_width = self._default_size_range - else: - try: - min_width, max_width = sizes - except (TypeError, ValueError): - err = "sizes argument {} not understood".format(sizes) - raise ValueError(err) - width_range = min_width, max_width - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = ("``size_norm`` must be None, tuple, " - "or Normalize object.") - raise ValueError(err) - - norm.clip = True - if not norm.scaled(): - norm(np.asarray(numbers)) - limits = norm.vmin, norm.vmax - - scl = norm(numbers) - widths = np.asarray(min_width + scl * (max_width - min_width)) - if scl.mask.any(): - widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) - # sizes = {l: min_width + norm(n) * (max_width - min_width) - # for l, n in zip(levels, numbers)} - - self.sizes = sizes - self.size_type = var_type - self.size_levels = levels - self.size_norm = norm - self.size_limits = limits - self.size_range = width_range + # --- Introductory prose + main_api=""" +The relationship between ``x`` and ``y`` can be shown for different subsets +of the data using the ``hue``, ``size``, and ``style`` parameters. These +parameters control what visual semantics are used to identify the different +subsets. It is possible to show up to three dimensions independently by +using all three semantic types, but this style of plot can be hard to +interpret and is often ineffective. Using redundant semantics (i.e. both +``hue`` and ``style`` for the same variable) can be helpful for making +graphics more accessible. 
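+
+For example, mapping one variable to both ``hue`` and ``style`` (a minimal
+sketch using the bundled ``tips`` example dataset):
+
+.. plot::
+    :context: close-figs
+
+    >>> import seaborn as sns
+    >>> tips = sns.load_dataset("tips")
+    >>> g = sns.relplot(x="total_bill", y="tip", data=tips,
+    ...                 hue="smoker", style="smoker")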
+ +See the :ref:`tutorial ` for more information. + """, + + relational_semantic=""" +The default treatment of the ``hue`` (and to a lesser extent, ``size``) +semantic, if present, depends on whether the variable is inferred to +represent "numeric" or "categorical" data. In particular, numeric variables +are represented with a sequential colormap by default, and the legend +entries show regular "ticks" with values that may or may not exist in the +data. This behavior can be controlled through various parameters, as +described and illustrated below. + """, +)) - def parse_style(self, data, markers, dashes, order): - """Determine the markers and line dashes.""" +_relational_docs = dict( - if self._empty_data(data): + # --- Shared function parameters + data_vars=""" +x, y : names of variables in ``data`` or vector data + Input data variables; must be numeric. Can pass data directly or + reference columns in ``data``. + """, + data=""" +data : DataFrame, array, or list of arrays + Input data structure. If ``x`` and ``y`` are specified as names, this + should be a "long-form" DataFrame containing those columns. Otherwise + it is treated as "wide-form" data and grouping variables are ignored. + See the examples for the various ways this parameter can be specified + and the different effects of each. + """, + palette=""" +palette : string, list, dict, or matplotlib colormap + An object that determines how colors are chosen when ``hue`` is used. + It can be the name of a seaborn palette or matplotlib colormap, a list + of colors (anything matplotlib understands), a dict mapping levels + of the ``hue`` variable to colors, or a matplotlib colormap object. + """, + hue_order=""" +hue_order : list + Specified order for the appearance of the ``hue`` variable levels, + otherwise they are determined from the data. Not relevant when the + ``hue`` variable is numeric. + """, + hue_norm=""" +hue_norm : tuple or :class:`matplotlib.colors.Normalize` object + Normalization in data units for colormap applied to the ``hue`` + variable when it is numeric. Not relevant if it is categorical. + """, + sizes=""" +sizes : list, dict, or tuple + An object that determines how sizes are chosen when ``size`` is used. + It can always be a list of size values or a dict mapping levels of the + ``size`` variable to sizes. When ``size`` is numeric, it can also be + a tuple specifying the minimum and maximum size to use such that other + values are normalized within this range. + """, + size_order=""" +size_order : list + Specified order for appearance of the ``size`` variable levels, + otherwise they are determined from the data. Not relevant when the + ``size`` variable is numeric. + """, + size_norm=""" +size_norm : tuple or Normalize object + Normalization in data units for scaling plot objects when the + ``size`` variable is numeric. + """, + dashes=""" +dashes : boolean, list, or dictionary + Object determining how to draw the lines for different levels of the + ``style`` variable. Setting to ``True`` will use default dash codes, or + you can pass a list of dash codes or a dictionary mapping levels of the + ``style`` variable to dash codes. Setting to ``False`` will use solid + lines for all subsets. Dashes are specified as in matplotlib: a tuple + of ``(segment, gap)`` lengths, or an empty string to draw a solid line. + """, + markers=""" +markers : boolean, list, or dictionary + Object determining how to draw the markers for different levels of the + ``style`` variable. 
Setting to ``True`` will use default markers, or + you can pass a list of markers or a dictionary mapping levels of the + ``style`` variable to markers. Setting to ``False`` will draw + marker-less lines. Markers are specified as in matplotlib. + """, + style_order=""" +style_order : list + Specified order for appearance of the ``style`` variable levels, + otherwise they are determined from the data. Not relevant when the + ``style`` variable is numeric. + """, + units=""" +units : vector or key in ``data`` + Grouping variable identifying sampling units. When used, a separate + line will be drawn for each unit with appropriate semantics, but no + legend entry will be added. Useful for showing distribution of + experimental replicates when exact identities are not needed. + """, + estimator=""" +estimator : name of pandas method or callable or None + Method for aggregating across multiple observations of the ``y`` + variable at the same ``x`` level. If ``None``, all observations will + be drawn. + """, + ci=""" +ci : int or "sd" or None + Size of the confidence interval to draw when aggregating. + + .. deprecated:: 0.12.0 + Use the new `errorbar` parameter for more flexibility. + + """, + n_boot=""" +n_boot : int + Number of bootstraps to use for computing the confidence interval. + """, + seed=""" +seed : int, numpy.random.Generator, or numpy.random.RandomState + Seed or random number generator for reproducible bootstrapping. + """, + legend=""" +legend : "auto", "brief", "full", or False + How to draw the legend. If "brief", numeric ``hue`` and ``size`` + variables will be represented with a sample of evenly spaced values. + If "full", every group will get an entry in the legend. If "auto", + choose between brief or full representation based on number of levels. + If ``False``, no legend data is added and no legend is drawn. + """, + ax_in=""" +ax : matplotlib Axes + Axes object to draw the plot onto, otherwise uses the current Axes. + """, + ax_out=""" +ax : matplotlib Axes + Returns the Axes object with the plot drawn onto it. + """, - levels = [None] - dashes = {} - markers = {} +) - else: - if order is None: - levels = categorical_order(data) - else: - levels = order +_param_docs = DocstringComponents.from_nested_components( + core=_core_docs["params"], + facets=DocstringComponents(_facet_docs), + rel=DocstringComponents(_relational_docs), + stat=DocstringComponents.from_function_params(EstimateAggregator.__init__), +) - markers = self.style_to_attributes( - levels, markers, self.default_markers, "markers" - ) - dashes = self.style_to_attributes( - levels, dashes, self.default_dashes, "dashes" - ) +class _RelationalPlotter(VectorPlotter): - paths = {} - filled_markers = [] - for k, m in markers.items(): - if not isinstance(m, mpl.markers.MarkerStyle): - m = mpl.markers.MarkerStyle(m) - paths[k] = m.get_path().transformed(m.get_transform()) - filled_markers.append(m.is_filled()) - - # Mixture of filled and unfilled markers will show line art markers - # in the edge color, which defaults to white. This can be handled, - # but there would be additional complexity with specifying the - # weight of the line art markers without overwhelming the filled - # ones with the edges. So for now, we will disallow mixtures.
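As a hedged sketch of the all-or-none rule the comment above describes (the helper name and the marker choices are illustrative, not part of the library):

```python
import matplotlib as mpl

def check_marker_mixture(markers):
    # Mirrors the validation that follows: a style mapping may use
    # all filled markers or all line-art markers, but not a mixture.
    filled = [mpl.markers.MarkerStyle(m).is_filled() for m in markers]
    if any(filled) and not all(filled):
        raise ValueError("Filled and line art markers cannot be mixed")

check_marker_mixture(["o", "s"])    # ok: both filled
check_marker_mixture(["+", "x"])    # ok: both line art
# check_marker_mixture(["o", "+"])  # would raise ValueError
```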
- if any(filled_markers) and not all(filled_markers): - err = "Filled and line art markers cannot be mixed" - raise ValueError(err) + wide_structure = { + "x": "@index", "y": "@values", "hue": "@columns", "style": "@columns", + } - self.style_levels = levels - self.dashes = dashes - self.markers = markers - self.paths = paths - - def _empty_data(self, data): - """Test if a series is completely missing.""" - return data.isnull().all() - - def _semantic_type(self, data): - """Determine if data should considered numeric or categorical.""" - if self.input_format == "wide": - return "categorical" - elif isinstance(data, pd.Series) and data.dtype.name == "category": - return "categorical" - else: - try: - float_data = data.astype(np.float) - values = np.unique(float_data.dropna()) - if np.array_equal(values, np.array([0., 1.])): - return "categorical" - return "numeric" - except (ValueError, TypeError): - return "categorical" - - def label_axes(self, ax): - """Set x and y labels with visibility that matches the ticklabels.""" - if self.x_label is not None: - x_visible = any(t.get_visible() for t in ax.get_xticklabels()) - ax.set_xlabel(self.x_label, visible=x_visible) - if self.y_label is not None: - y_visible = any(t.get_visible() for t in ax.get_yticklabels()) - ax.set_ylabel(self.y_label, visible=y_visible) + # TODO where best to define default parameters? + sort = True def add_legend_data(self, ax): """Add labeled artists to represent the different plot semantics.""" verbosity = self.legend - if verbosity not in ["brief", "full"]: - err = "`legend` must be 'brief', 'full', or False" + if isinstance(verbosity, str) and verbosity not in ["auto", "brief", "full"]: + err = "`legend` must be 'auto', 'brief', 'full', or a boolean." raise ValueError(err) + elif verbosity is True: + verbosity = "auto" legend_kwargs = {} keys = [] - title_kws = dict(color="w", s=0, linewidth=0, marker="", dashes="") + # Assign a legend title if there is only going to be one sub-legend, + # otherwise, subtitles will be inserted into the texts list with an + # invisible handle (which is a hack) + titles = { + title for title in + (self.variables.get(v, None) for v in ["hue", "size", "style"]) + if title is not None + } + if len(titles) == 1: + legend_title = titles.pop() + else: + legend_title = "" + + title_kws = dict( + visible=False, color="w", s=0, linewidth=0, marker="", dashes="" + ) def update(var_name, val_name, **kws): @@ -567,62 +233,94 @@ def update(var_name, val_name, **kws): legend_kwargs[key] = dict(**kws) - # -- Add a legend for hue semantics + # Define the maximum number of ticks to use for "brief" legends + brief_ticks = 6 - if verbosity == "brief" and self.hue_type == "numeric": - if isinstance(self.hue_norm, mpl.colors.LogNorm): - ticker = mpl.ticker.LogLocator(numticks=3) + # -- Add a legend for hue semantics + brief_hue = self._hue_map.map_type == "numeric" and ( + verbosity == "brief" + or (verbosity == "auto" and len(self._hue_map.levels) > brief_ticks) + ) + if brief_hue: + if isinstance(self._hue_map.norm, mpl.colors.LogNorm): + locator = mpl.ticker.LogLocator(numticks=brief_ticks) else: - ticker = mpl.ticker.MaxNLocator(nbins=3) - hue_levels = (ticker.tick_values(*self.hue_limits) - .astype(self.plot_data["hue"].dtype)) + locator = mpl.ticker.MaxNLocator(nbins=brief_ticks) + limits = min(self._hue_map.levels), max(self._hue_map.levels) + hue_levels, hue_formatted_levels = locator_to_legend_entries( + locator, limits, self.plot_data["hue"].infer_objects().dtype + ) + elif self._hue_map.levels 
is None: + hue_levels = hue_formatted_levels = [] else: - hue_levels = self.hue_levels + hue_levels = hue_formatted_levels = self._hue_map.levels # Add the hue semantic subtitle - if self.hue_label is not None: - update((self.hue_label, "title"), self.hue_label, **title_kws) + if not legend_title and self.variables.get("hue", None) is not None: + update((self.variables["hue"], "title"), + self.variables["hue"], **title_kws) # Add the hue semantic labels - for level in hue_levels: + for level, formatted_level in zip(hue_levels, hue_formatted_levels): if level is not None: - color = self.color_lookup(level) - update(self.hue_label, level, color=color) + color = self._hue_map(level) + update(self.variables["hue"], formatted_level, color=color) # -- Add a legend for size semantics - - if verbosity == "brief" and self.size_type == "numeric": - if isinstance(self.size_norm, mpl.colors.LogNorm): - ticker = mpl.ticker.LogLocator(numticks=3) + brief_size = self._size_map.map_type == "numeric" and ( + verbosity == "brief" + or (verbosity == "auto" and len(self._size_map.levels) > brief_ticks) + ) + if brief_size: + # Define how ticks will interpolate between the min/max data values + if isinstance(self._size_map.norm, mpl.colors.LogNorm): + locator = mpl.ticker.LogLocator(numticks=brief_ticks) else: - ticker = mpl.ticker.MaxNLocator(nbins=3) - size_levels = (ticker.tick_values(*self.size_limits) - .astype(self.plot_data["size"].dtype)) + locator = mpl.ticker.MaxNLocator(nbins=brief_ticks) + # Define the min/max data values + limits = min(self._size_map.levels), max(self._size_map.levels) + size_levels, size_formatted_levels = locator_to_legend_entries( + locator, limits, self.plot_data["size"].infer_objects().dtype + ) + elif self._size_map.levels is None: + size_levels = size_formatted_levels = [] else: - size_levels = self.size_levels + size_levels = size_formatted_levels = self._size_map.levels # Add the size semantic subtitle - if self.size_label is not None: - update((self.size_label, "title"), self.size_label, **title_kws) + if not legend_title and self.variables.get("size", None) is not None: + update((self.variables["size"], "title"), + self.variables["size"], **title_kws) # Add the size semantic labels - for level in size_levels: + for level, formatted_level in zip(size_levels, size_formatted_levels): if level is not None: - size = self.size_lookup(level) - update(self.size_label, level, linewidth=size, s=size) + size = self._size_map(level) + update( + self.variables["size"], + formatted_level, + linewidth=size, + s=size, + ) # -- Add a legend for style semantics # Add the style semantic title - if self.style_label is not None: - update((self.style_label, "title"), self.style_label, **title_kws) + if not legend_title and self.variables.get("style", None) is not None: + update((self.variables["style"], "title"), + self.variables["style"], **title_kws) # Add the style semantic labels - for level in self.style_levels: - if level is not None: - update(self.style_label, level, - marker=self.markers.get(level, ""), - dashes=self.dashes.get(level, "")) + if self._style_map.levels is not None: + for level in self._style_map.levels: + if level is not None: + attrs = self._style_map(level) + update( + self.variables["style"], + level, + marker=attrs.get("marker", ""), + dashes=attrs.get("dashes", ""), + ) func = getattr(ax, self._legend_func) @@ -641,9 +339,10 @@ def update(var_name, val_name, **kws): artist = func([], [], label=label, **use_kws) if self._legend_func == "plot": artist = artist[0] - 
legend_data[label] = artist - legend_order.append(label) + legend_data[key] = artist + legend_order.append(key) + self.legend_title = legend_title self.legend_data = legend_data self.legend_order = legend_order @@ -653,80 +352,34 @@ class _LinePlotter(_RelationalPlotter): _legend_attributes = ["color", "linewidth", "marker", "dashes"] _legend_func = "plot" - def __init__(self, - x=None, y=None, hue=None, size=None, style=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - dashes=None, markers=None, style_order=None, - units=None, estimator=None, ci=None, n_boot=None, - sort=True, err_style=None, err_kws=None, legend=None): - - plot_data = self.establish_variables( - x, y, hue, size, style, units, data - ) - + def __init__( + self, *, + data=None, variables={}, + estimator=None, ci=None, n_boot=None, seed=None, + sort=True, err_style=None, err_kws=None, legend=None, + errorbar=None, + ): + + # TODO this is messy, we want the mapping to be agnostic about + # the kind of plot to draw, but for the time being we need to set + # this information so the SizeMapping can use it self._default_size_range = ( np.r_[.5, 2] * mpl.rcParams["lines.linewidth"] ) - self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm) - self.parse_size(plot_data["size"], sizes, size_order, size_norm) - self.parse_style(plot_data["style"], markers, dashes, style_order) + super().__init__(data=data, variables=variables) - self.units = units self.estimator = estimator + self.errorbar = errorbar self.ci = ci self.n_boot = n_boot + self.seed = seed self.sort = sort self.err_style = err_style self.err_kws = {} if err_kws is None else err_kws self.legend = legend - def aggregate(self, vals, grouper, units=None): - """Compute an estimate and confidence interval using grouper.""" - func = self.estimator - ci = self.ci - n_boot = self.n_boot - - # Define a "null" CI for when we only have one value - null_ci = pd.Series(index=["low", "high"], dtype=np.float) - - # Function to bootstrap in the context of a pandas group by - def bootstrapped_cis(vals): - - if len(vals) == 1: - return null_ci - - boots = bootstrap(vals, func=func, n_boot=n_boot) - cis = utils.ci(boots, ci) - return pd.Series(cis, ["low", "high"]) - - # Group and get the aggregation estimate - grouped = vals.groupby(grouper, sort=self.sort) - est = grouped.agg(func) - - # Exit early if we don't want a confidence interval - if ci is None: - return est.index, est, None - - # Compute the error bar extents - if ci == "sd": - sd = grouped.std() - cis = pd.DataFrame(np.c_[est - sd, est + sd], - index=est.index, - columns=["low", "high"]).stack() - else: - cis = grouped.apply(bootstrapped_cis) - - # Unpack the CIs into "wide" format for plotting - if cis.notnull().any(): - cis = cis.unstack().reindex(est.index) - else: - cis = None - - return est.index, est, cis - def plot(self, ax, kws): """Draw the plot onto an axes, passing matplotlib kwargs.""" @@ -739,20 +392,9 @@ def plot(self, ax, kws): # gotten from the corresponding matplotlib function, and calling the # function will advance the axes property cycle. 
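The cycle behavior this comment describes can be seen with a short sketch (the `fmri` dataset and columns follow the examples elsewhere in this patch):

```python
import matplotlib.pyplot as plt
import seaborn as sns

fmri = sns.load_dataset("fmri")
fig, ax = plt.subplots()

# Without hue, each call draws in a single color, but that color is
# taken from (and advances) the axes property cycle, exactly as a
# bare ax.plot() call would.
sns.lineplot(x="timepoint", y="signal", data=fmri, ax=ax)  # color C0
sns.lineplot(x="timepoint", y="signal", data=fmri, ax=ax)  # color C1
```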
- scout, = ax.plot([], [], **kws) - - orig_color = kws.pop("color", scout.get_color()) - orig_marker = kws.pop("marker", scout.get_marker()) - orig_linewidth = kws.pop("linewidth", - kws.pop("lw", scout.get_linewidth())) - - orig_dashes = kws.pop("dashes", "") - kws.setdefault("markeredgewidth", kws.pop("mew", .75)) kws.setdefault("markeredgecolor", kws.pop("mec", "w")) - scout.remove() - # Set default error kwargs err_kws = self.err_kws.copy() if self.err_style == "band": @@ -763,76 +405,112 @@ def plot(self, ax, kws): err = "`err_style` must be 'band' or 'bars', not {}" raise ValueError(err.format(self.err_style)) - # Loop over the semantic subsets and draw a line for each + # Initialize the aggregation object + agg = EstimateAggregator( + self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed, + ) + + # TODO abstract variable to aggregate over here-ish. Better name? + agg_var = "y" + grouper = ["x"] - for semantics, data in self.subset_data(): + # TODO How to handle NA? We don't want NA to propagate through to the + # estimate/CI when some values are present, but we would also like + # matplotlib to show "gaps" in the line when all values are missing. + # This is straightforward absent aggregation, but complicated with it. + # If we want to use nas, we need to conditionalize dropna in iter_data. - hue, size, style = semantics - x, y, units = data["x"], data["y"], data.get("units", None) + # Loop over the semantic subsets and add to the plot + grouping_vars = "hue", "size", "style" + for sub_vars, sub_data in self.iter_data(grouping_vars, from_comp_data=True): + + if self.sort: + sort_vars = ["units", "x", "y"] + sort_cols = [var for var in sort_vars if var in self.variables] + sub_data = sub_data.sort_values(sort_cols) if self.estimator is not None: - if self.units is not None: + if "units" in self.variables: + # TODO eventually relax this constraint err = "estimator must be None when specifying units" raise ValueError(err) - x, y, y_ci = self.aggregate(y, x, units) + grouped = sub_data.groupby(grouper, sort=self.sort) + # Could pass as_index=False instead of reset_index, + # but that fails on a corner case with older pandas. + sub_data = grouped.apply(agg, agg_var).reset_index() + + # TODO this is pretty ad hoc ; see GH2409 + for var in "xy": + if self._log_scaled(var): + for col in sub_data.filter(regex=f"^{var}"): + sub_data[col] = np.power(10, sub_data[col]) + + # --- Draw the main line(s) + + if "units" in self.variables: # XXX why not add to grouping variables? 
+ lines = [] + for _, unit_data in sub_data.groupby("units"): + lines.extend(ax.plot(unit_data["x"], unit_data["y"], **kws)) else: - y_ci = None - - kws["color"] = self.palette.get(hue, orig_color) - kws["dashes"] = self.dashes.get(style, orig_dashes) - kws["marker"] = self.markers.get(style, orig_marker) - kws["linewidth"] = self.sizes.get(size, orig_linewidth) + lines = ax.plot(sub_data["x"], sub_data["y"], **kws) - line, = ax.plot([], [], **kws) - line_color = line.get_color() - line_alpha = line.get_alpha() - line_capstyle = line.get_solid_capstyle() - line.remove() + for line in lines: - # --- Draw the main line + if "hue" in sub_vars: + line.set_color(self._hue_map(sub_vars["hue"])) - x, y = np.asarray(x), np.asarray(y) + if "size" in sub_vars: + line.set_linewidth(self._size_map(sub_vars["size"])) - if self.units is None: - line, = ax.plot(x, y, **kws) + if "style" in sub_vars: + attributes = self._style_map(sub_vars["style"]) + if "dashes" in attributes: + line.set_dashes(attributes["dashes"]) + if "marker" in attributes: + line.set_marker(attributes["marker"]) - else: - for u in units.unique(): - rows = np.asarray(units == u) - ax.plot(x[rows], y[rows], **kws) + line_color = line.get_color() + line_alpha = line.get_alpha() + line_capstyle = line.get_solid_capstyle() # --- Draw the confidence intervals - if y_ci is not None: + if self.estimator is not None and self.errorbar is not None: - low, high = np.asarray(y_ci["low"]), np.asarray(y_ci["high"]) + # TODO handling of orientation will need to happen here if self.err_style == "band": - ax.fill_between(x, low, high, color=line_color, **err_kws) + ax.fill_between( + sub_data["x"], sub_data["ymin"], sub_data["ymax"], + color=line_color, **err_kws + ) elif self.err_style == "bars": - y_err = ci_to_errsize((low, high), y) - ebars = ax.errorbar(x, y, y_err, linestyle="", - color=line_color, alpha=line_alpha, - **err_kws) + error_deltas = ( + sub_data["y"] - sub_data["ymin"], + sub_data["ymax"] - sub_data["y"], + ) + ebars = ax.errorbar( + sub_data["x"], sub_data["y"], error_deltas, + linestyle="", color=line_color, alpha=line_alpha, + **err_kws + ) # Set the capstyle properly on the error bars for obj in ebars.get_children(): - try: + if isinstance(obj, mpl.collections.LineCollection): obj.set_capstyle(line_capstyle) - except AttributeError: - # Does not exist on mpl < 2.2 - pass # Finalize the axes details - self.label_axes(ax) + self._add_axis_labels(ax) if self.legend: self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() if handles: - ax.legend() + legend = ax.legend(title=self.legend_title) + adjust_legend_subtitles(legend) class _ScatterPlotter(_RelationalPlotter): @@ -840,718 +518,350 @@ class _ScatterPlotter(_RelationalPlotter): _legend_attributes = ["color", "s", "marker"] _legend_func = "scatter" - def __init__(self, - x=None, y=None, hue=None, size=None, style=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - dashes=None, markers=None, style_order=None, - x_bins=None, y_bins=None, - units=None, estimator=None, ci=None, n_boot=None, - alpha=None, x_jitter=None, y_jitter=None, - legend=None): - - plot_data = self.establish_variables( - x, y, hue, size, style, units, data - ) - + def __init__( + self, *, + data=None, variables={}, + x_bins=None, y_bins=None, + estimator=None, ci=None, n_boot=None, + alpha=None, x_jitter=None, y_jitter=None, + legend=None + ): + + # TODO this is messy, we want the mapping to be agnostic about + # the kind of plot to
draw, but for the time being we need to set + # this information so the SizeMapping can use it self._default_size_range = ( np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"]) ) - self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm) - self.parse_size(plot_data["size"], sizes, size_order, size_norm) - self.parse_style(plot_data["style"], markers, None, style_order) - self.units = units + super().__init__(data=data, variables=variables) self.alpha = alpha - self.legend = legend def plot(self, ax, kws): - # Draw a test plot, using the passed in kwargs. The goal here is to - # honor both (a) the current state of the plot cycler and (b) the - # specified kwargs on all the lines we will draw, overriding when - # relevant with the data semantics. Note that we won't cycle - # internally; in other words, if ``hue`` is not used, all elements will - # have the same color, but they will have the color that you would have - # gotten from the corresponding matplotlib function, and calling the - # function will advance the axes property cycle. + # --- Determine the visual attributes of the plot - scout = ax.scatter([], [], **kws) - s = kws.pop("s", scout.get_sizes()) - c = kws.pop("c", scout.get_facecolors()) - scout.remove() + data = self.plot_data.dropna() + if data.empty: + return - kws.pop("color", None) # TODO is this optimal? + # Define the vectors of x and y positions + empty = np.full(len(data), np.nan) + x = data.get("x", empty) + y = data.get("y", empty) - kws.setdefault("linewidth", .75) # TODO scale with marker size? + # Set defaults for other visual attributes kws.setdefault("edgecolor", "w") - if self.markers: + if "style" in self.variables: # Use a representative marker so scatter sets the edgecolor # properly for line art markers. We currently enforce either # all or none line art so this works. - example_marker = list(self.markers.values())[0] + example_level = self._style_map.levels[0] + example_marker = self._style_map(example_level, "marker") kws.setdefault("marker", example_marker) # TODO this makes it impossible to vary alpha with hue which might # otherwise be useful? Should we just pass None? kws["alpha"] = 1 if self.alpha == "auto" else self.alpha - # Assign arguments for plt.scatter and draw the plot + # Draw the scatter plot + points = ax.scatter(x=x, y=y, **kws) - data = self.plot_data[self.semantics].dropna() - if not data.size: - return + # Apply the mapping from semantic variables to artist attributes - x = data["x"] - y = data["y"] + if "hue" in self.variables: + points.set_facecolors(self._hue_map(data["hue"])) - if self.palette: - c = [self.palette.get(val) for val in data["hue"]] + if "size" in self.variables: + points.set_sizes(self._size_map(data["size"])) - if self.sizes: - s = [self.sizes.get(val) for val in data["size"]] - - args = np.asarray(x), np.asarray(y), np.asarray(s), np.asarray(c) - points = ax.scatter(*args, **kws) + if "style" in self.variables: + p = [self._style_map(val, "path") for val in data["style"]] + points.set_paths(p) - # Update the paths to get different marker shapes. This has to be - # done here because plt.scatter allows varying sizes and colors - # but only a single marker shape per call. 
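A minimal sketch of the ``set_paths`` technique that comment refers to: ``plt.scatter`` accepts only one marker, so per-point shapes are applied after the fact by swapping each point's path (the data and marker list here are illustrative):

```python
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle

fig, ax = plt.subplots()
points = ax.scatter([1, 2, 3], [1, 1, 1], s=100)

# Vary the marker shape per point by replacing each Path, which a
# single scatter() call cannot do on its own.
paths = []
for m in ["o", "s", "^"]:
    style = MarkerStyle(m)
    paths.append(style.get_path().transformed(style.get_transform()))
points.set_paths(paths)
```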
+ # Apply dependent default attributes - if self.paths: - p = [self.paths.get(val) for val in data["style"]] - points.set_paths(p) + if "linewidth" not in kws: + sizes = points.get_sizes() + points.set_linewidths(.08 * np.sqrt(np.percentile(sizes, 10))) # Finalize the axes details - self.label_axes(ax) + self._add_axis_labels(ax) if self.legend: self.add_legend_data(ax) handles, _ = ax.get_legend_handles_labels() if handles: - ax.legend() - - -_relational_docs = dict( - - # --- Introductory prose - main_api_narrative=dedent("""\ - The relationship between ``x`` and ``y`` can be shown for different subsets - of the data using the ``hue``, ``size``, and ``style`` parameters. These - parameters control what visual semantics are used to identify the different - subsets. It is possible to show up to three dimensions independently by - using all three semantic types, but this style of plot can be hard to - interpret and is often ineffective. Using redundant semantics (i.e. both - ``hue`` and ``style`` for the same variable) can be helpful for making - graphics more accessible. - - See the :ref:`tutorial ` for more information.\ - """), - - # --- Shared function parameters - data_vars=dedent("""\ - x, y : names of variables in ``data`` or vector data, optional - Input data variables; must be numeric. Can pass data directly or - reference columns in ``data``.\ - """), - data=dedent("""\ - data : DataFrame, array, or list of arrays, optional - Input data structure. If ``x`` and ``y`` are specified as names, this - should be a "long-form" DataFrame containing those columns. Otherwise - it is treated as "wide-form" data and grouping variables are ignored. - See the examples for the various ways this parameter can be specified - and the different effects of each.\ - """), - palette=dedent("""\ - palette : string, list, dict, or matplotlib colormap - An object that determines how colors are chosen when ``hue`` is used. - It can be the name of a seaborn palette or matplotlib colormap, a list - of colors (anything matplotlib understands), a dict mapping levels - of the ``hue`` variable to colors, or a matplotlib colormap object.\ - """), - hue_order=dedent("""\ - hue_order : list, optional - Specified order for the appearance of the ``hue`` variable levels, - otherwise they are determined from the data. Not relevant when the - ``hue`` variable is numeric.\ - """), - hue_norm=dedent("""\ - hue_norm : tuple or Normalize object, optional - Normalization in data units for colormap applied to the ``hue`` - variable when it is numeric. Not relevant if it is categorical.\ - """), - sizes=dedent("""\ - sizes : list, dict, or tuple, optional - An object that determines how sizes are chosen when ``size`` is used. - It can always be a list of size values or a dict mapping levels of the - ``size`` variable to sizes. When ``size`` is numeric, it can also be - a tuple specifying the minimum and maximum size to use such that other - values are normalized within this range.\ - """), - size_order=dedent("""\ - size_order : list, optional - Specified order for appearance of the ``size`` variable levels, - otherwise they are determined from the data.
Not relevant when the - ``size`` variable is numeric.\ - """), - size_norm=dedent("""\ - size_norm : tuple or Normalize object, optional - Normalization in data units for scaling plot objects when the - ``size`` variable is numeric.\ - """), - markers=dedent("""\ - markers : boolean, list, or dictionary, optional - Object determining how to draw the markers for different levels of the - ``style`` variable. Setting to ``True`` will use default markers, or - you can pass a list of markers or a dictionary mapping levels of the - ``style`` variable to markers. Setting to ``False`` will draw - marker-less lines. Markers are specified as in matplotlib.\ - """), - style_order=dedent("""\ - style_order : list, optional - Specified order for appearance of the ``style`` variable levels - otherwise they are determined from the data. Not relevant when the - ``style`` variable is numeric.\ - """), - units=dedent("""\ - units : {long_form_var} - Grouping variable identifying sampling units. When used, a separate - line will be drawn for each unit with appropriate semantics, but no - legend entry will be added. Useful for showing distribution of - experimental replicates when exact identities are not needed. - """), - estimator=dedent("""\ - estimator : name of pandas method or callable or None, optional - Method for aggregating across multiple observations of the ``y`` - variable at the same ``x`` level. If ``None``, all observations will - be drawn.\ - """), - ci=dedent("""\ - ci : int or "sd" or None, optional - Size of the confidence interval to draw when aggregating with an - estimator. "sd" means to draw the standard deviation of the data. - Setting to ``None`` will skip bootstrapping.\ - """), - n_boot=dedent("""\ - n_boot : int, optional - Number of bootstraps to use for computing the confidence interval.\ - """), - legend=dedent("""\ - legend : "brief", "full", or False, optional - How to draw the legend. If "brief", numeric ``hue`` and ``size`` - variables will be represented with a sample of evenly spaced values. - If "full", every group will get an entry in the legend. 
If ``False``, - no legend data is added and no legend is drawn.\ - """), - ax_in=dedent("""\ - ax : matplotlib Axes, optional - Axes object to draw the plot onto, otherwise uses the current Axes.\ - """), - ax_out=dedent("""\ - ax : matplotlib Axes - Returns the Axes object with the plot drawn onto it.\ - """), - - # --- Repeated phrases - long_form_var="name of variables in ``data`` or vector data, optional", - - -) - -_relational_docs.update(_facet_docs) - - -def lineplot(x=None, y=None, hue=None, size=None, style=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - dashes=True, markers=None, style_order=None, - units=None, estimator="mean", ci=95, n_boot=1000, - sort=True, err_style="band", err_kws=None, - legend="brief", ax=None, **kwargs): - + legend = ax.legend(title=self.legend_title) + adjust_legend_subtitles(legend) + + +@_deprecate_positional_args +def lineplot( + *, + x=None, y=None, + hue=None, size=None, style=None, + data=None, + palette=None, hue_order=None, hue_norm=None, + sizes=None, size_order=None, size_norm=None, + dashes=True, markers=None, style_order=None, + units=None, estimator="mean", ci="deprecated", n_boot=1000, seed=None, + sort=True, err_style="band", err_kws=None, + legend="auto", + errorbar=("ci", 95), + ax=None, **kwargs +): + + # Handle deprecation of ci parameter + errorbar = _deprecate_ci(errorbar, ci) + + variables = _LinePlotter.get_semantics(locals()) p = _LinePlotter( - x=x, y=y, hue=hue, size=size, style=style, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - dashes=dashes, markers=markers, style_order=style_order, - units=units, estimator=estimator, ci=ci, n_boot=n_boot, + data=data, variables=variables, + estimator=estimator, ci=ci, n_boot=n_boot, seed=seed, sort=sort, err_style=err_style, err_kws=err_kws, legend=legend, + errorbar=errorbar, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, dashes=dashes, order=style_order) + if ax is None: ax = plt.gca() - p.plot(ax, kwargs) - - return ax - - -lineplot.__doc__ = dedent("""\ - Draw a line plot with possibility of several semantic groupings. - - {main_api_narrative} - - By default, the plot aggregates over multiple ``y`` values at each value of - ``x`` and shows an estimate of the central tendency and a confidence - interval for that estimate. - - Parameters - ---------- - {data_vars} - hue : {long_form_var} - Grouping variable that will produce lines with different colors. - Can be either categorical or numeric, although color mapping will - behave differently in latter case. - size : {long_form_var} - Grouping variable that will produce lines with different widths. - Can be either categorical or numeric, although size mapping will - behave differently in latter case. - style : {long_form_var} - Grouping variable that will produce lines with different dashes - and/or markers. Can have a numeric dtype but will always be treated - as categorical. - {data} - {palette} - {hue_order} - {hue_norm} - {sizes} - {size_order} - {size_norm} - dashes : boolean, list, or dictionary, optional - Object determining how to draw the lines for different levels of the - ``style`` variable. Setting to ``True`` will use default dash codes, or - you can pass a list of dash codes or a dictionary mapping levels of the - ``style`` variable to dash codes. 
Setting to ``False`` will use solid - lines for all subsets. Dashes are specified as in matplotlib: a tuple - of ``(segment, gap)`` lengths, or an empty string to draw a solid line. - {markers} - {style_order} - {units} - {estimator} - {ci} - {n_boot} - sort : boolean, optional - If True, the data will be sorted by the x and y variables, otherwise - lines will connect points in the order they appear in the dataset. - err_style : "band" or "bars", optional - Whether to draw the confidence intervals with translucent error bands - or discrete error bars. - err_kws : dict of keyword arguments - Additional paramters to control the aesthetics of the error bars. The - kwargs are passed either to ``ax.fill_between`` or ``ax.errorbar``, - depending on the ``err_style``. - {legend} - {ax_in} - kwargs : key, value mappings - Other keyword arguments are passed down to ``plt.plot`` at draw time. - - Returns - ------- - {ax_out} - - See Also - -------- - scatterplot : Show the relationship between two variables without - emphasizing continuity of the ``x`` variable. - pointplot : Show the relationship between two variables when one is - categorical. - - Examples - -------- - - Draw a single line plot with error bands showing a confidence interval: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set() - >>> import matplotlib.pyplot as plt - >>> fmri = sns.load_dataset("fmri") - >>> ax = sns.lineplot(x="timepoint", y="signal", data=fmri) - - Group by another variable and show the groups with different colors: - - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event", - ... data=fmri) - - Show the grouping variable with both color and line dashing: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", - ... hue="event", style="event", data=fmri) - - Use color and line dashing to represent two different grouping variables: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", - ... hue="region", style="event", data=fmri) - - Use markers instead of the dashes to identify groups: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", - ... hue="event", style="event", - ... markers=True, dashes=False, data=fmri) - - Show error bars instead of error bands and plot the standard error: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event", - ... err_style="bars", ci=68, data=fmri) - - Show experimental replicates instead of aggregating: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event", - ... units="subject", estimator=None, lw=1, - ... data=fmri.query("region == 'frontal'")) - - Use a quantitative color mapping: - - .. plot:: - :context: close-figs - - >>> dots = sns.load_dataset("dots").query("align == 'dots'") - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... hue="coherence", style="choice", - ... data=dots) - - Use a different normalization for the colormap: - - .. plot:: - :context: close-figs - - >>> from matplotlib.colors import LogNorm - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... hue="coherence", style="choice", - ... hue_norm=LogNorm(), data=dots) - - Use a different color palette: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... hue="coherence", style="choice", - ... 
palette="ch:2.5,.25", data=dots) - - Use specific color values, treating the hue variable as categorical: + if style is None and not {"ls", "linestyle"} & set(kwargs): # XXX + kwargs["dashes"] = "" if dashes is None or isinstance(dashes, bool) else dashes - .. plot:: - :context: close-figs + if not p.has_xy_data: + return ax - >>> palette = sns.color_palette("mako_r", 6) - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... hue="coherence", style="choice", - ... palette=palette, data=dots) + p._attach(ax) - Change the width of the lines with a quantitative variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... size="coherence", hue="choice", - ... legend="full", data=dots) - - Change the range of line widths used to normalize the size variable: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(x="time", y="firing_rate", - ... size="coherence", hue="choice", - ... sizes=(.25, 2.5), data=dots) - - Plot from a wide-form DataFrame: - - .. plot:: - :context: close-figs - - >>> import numpy as np, pandas as pd; plt.close("all") - >>> index = pd.date_range("1 1 2000", periods=100, - ... freq="m", name="date") - >>> data = np.random.randn(100, 4).cumsum(axis=0) - >>> wide_df = pd.DataFrame(data, index, ["a", "b", "c", "d"]) - >>> ax = sns.lineplot(data=wide_df) - - Plot from a list of Series: - - .. plot:: - :context: close-figs - - >>> list_data = [wide_df.loc[:"2005", "a"], wide_df.loc["2003":, "b"]] - >>> ax = sns.lineplot(data=list_data) - - Plot a single Series, pass kwargs to ``plt.plot``: - - .. plot:: - :context: close-figs - - >>> ax = sns.lineplot(data=wide_df["a"], color="coral", label="line") - - Draw lines at points as they appear in the dataset: - - .. plot:: - :context: close-figs - - >>> x, y = np.random.randn(2, 5000).cumsum(axis=1) - >>> ax = sns.lineplot(x=x, y=y, sort=False, lw=1) + # Other functions have color as an explicit param, + # and we should probably do that here too + color = kwargs.pop("color", kwargs.pop("c", None)) + kwargs["color"] = _default_color(ax.plot, hue, color, kwargs) + p.plot(ax, kwargs) + return ax - """).format(**_relational_docs) +lineplot.__doc__ = """\ +Draw a line plot with possibility of several semantic groupings. + +{narrative.main_api} + +{narrative.relational_semantic} + +By default, the plot aggregates over multiple ``y`` values at each value of +``x`` and shows an estimate of the central tendency and a confidence +interval for that estimate. + +Parameters +---------- +{params.core.xy} +hue : vector or key in ``data`` + Grouping variable that will produce lines with different colors. + Can be either categorical or numeric, although color mapping will + behave differently in latter case. +size : vector or key in ``data`` + Grouping variable that will produce lines with different widths. + Can be either categorical or numeric, although size mapping will + behave differently in latter case. +style : vector or key in ``data`` + Grouping variable that will produce lines with different dashes + and/or markers. Can have a numeric dtype but will always be treated + as categorical. 
+{params.core.data} +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} +{params.rel.sizes} +{params.rel.size_order} +{params.rel.size_norm} +{params.rel.dashes} +{params.rel.markers} +{params.rel.style_order} +{params.rel.units} +{params.rel.estimator} +{params.rel.ci} +{params.rel.n_boot} +{params.rel.seed} +sort : boolean + If True, the data will be sorted by the x and y variables, otherwise + lines will connect points in the order they appear in the dataset. +err_style : "band" or "bars" + Whether to draw the confidence intervals with translucent error bands + or discrete error bars. +err_kws : dict of keyword arguments + Additional parameters to control the aesthetics of the error bars. The + kwargs are passed either to :meth:`matplotlib.axes.Axes.fill_between` + or :meth:`matplotlib.axes.Axes.errorbar`, depending on ``err_style``. +{params.rel.legend} +{params.stat.errorbar} +{params.core.ax} +kwargs : key, value mappings + Other keyword arguments are passed down to + :meth:`matplotlib.axes.Axes.plot`. + +Returns +------- +{returns.ax} + +See Also +-------- +{seealso.scatterplot} +{seealso.pointplot} + +Examples +-------- + +.. include:: ../docstrings/lineplot.rst + +""".format( + narrative=_relational_narrative, + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) -def scatterplot(x=None, y=None, hue=None, style=None, size=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - markers=True, style_order=None, - x_bins=None, y_bins=None, - units=None, estimator=None, ci=95, n_boot=1000, - alpha="auto", x_jitter=None, y_jitter=None, - legend="brief", ax=None, **kwargs): +@_deprecate_positional_args +def scatterplot( + *, + x=None, y=None, + hue=None, style=None, size=None, data=None, + palette=None, hue_order=None, hue_norm=None, + sizes=None, size_order=None, size_norm=None, + markers=True, style_order=None, + x_bins=None, y_bins=None, + units=None, estimator=None, ci=95, n_boot=1000, + alpha=None, x_jitter=None, y_jitter=None, + legend="auto", ax=None, + **kwargs +): + + variables = _ScatterPlotter.get_semantics(locals()) p = _ScatterPlotter( - x=x, y=y, hue=hue, style=style, size=size, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - markers=markers, style_order=style_order, + data=data, variables=variables, x_bins=x_bins, y_bins=y_bins, estimator=estimator, ci=ci, n_boot=n_boot, alpha=alpha, x_jitter=x_jitter, y_jitter=y_jitter, legend=legend, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, order=style_order) + if ax is None: ax = plt.gca() - p.plot(ax, kwargs) - - return ax - - -scatterplot.__doc__ = dedent("""\ - Draw a scatter plot with possibility of several semantic groupings. - - {main_api_narrative} - - Parameters - ---------- - {data_vars} - hue : {long_form_var} - Grouping variable that will produce points with different colors. - Can be either categorical or numeric, although color mapping will - behave differently in latter case. - size : {long_form_var} - Grouping variable that will produce points with different sizes. - Can be either categorical or numeric, although size mapping will - behave differently in latter case. - style : {long_form_var} - Grouping variable that will produce points with different markers.
- Can have a numeric dtype but will always be treated as categorical. - {data} - {palette} - {hue_order} - {hue_norm} - {sizes} - {size_order} - {size_norm} - {markers} - {style_order} - {{x,y}}_bins : lists or arrays or functions - *Currently non-functional.* - {units} - *Currently non-functional.* - {estimator} - *Currently non-functional.* - {ci} - *Currently non-functional.* - {n_boot} - *Currently non-functional.* - alpha : float - Proportional opacity of the points. - {{x,y}}_jitter : booleans or floats - *Currently non-functional.* - {legend} - {ax_in} - kwargs : key, value mappings - Other keyword arguments are passed down to ``plt.scatter`` at draw - time. - - Returns - ------- - {ax_out} - - See Also - -------- - lineplot : Show the relationship between two variables connected with - lines to emphasize continuity. - swarmplot : Draw a scatter plot with one categorical variable, arranging - the points to show the distribution of values. - - Examples - -------- - - Draw a simple scatter plot between two variables: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set() - >>> import matplotlib.pyplot as plt - >>> tips = sns.load_dataset("tips") - >>> ax = sns.scatterplot(x="total_bill", y="tip", data=tips) - - Group by another variable and show the groups with different colors: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", hue="time", - ... data=tips) - - Show the grouping variable by varying both color and marker: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="time", style="time", data=tips) - - Vary colors and markers to show two different grouping variables: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="day", style="time", data=tips) - - Show a quantitative variable by varying the size of the points: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", size="size", - ... data=tips) - - Also show the quantitative variable by also using continuous colors: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="size", size="size", - ... data=tips) + if not p.has_xy_data: + return ax - Use a different continuous color map: + p._attach(ax) - .. plot:: - :context: close-figs + # Other functions have color as an explicit param, + # and we should probably do that here too + color = kwargs.pop("color", None) + kwargs["color"] = _default_color(ax.scatter, hue, color, kwargs) - >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True) - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="size", size="size", - ... palette=cmap, - ... data=tips) - - Change the minimum and maximum point size and show all sizes in legend: - - .. plot:: - :context: close-figs - - >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True) - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="size", size="size", - ... sizes=(20, 200), palette=cmap, - ... legend="full", data=tips) - - Use a narrower range of color map intensities: - - .. plot:: - :context: close-figs - - >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True) - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="size", size="size", - ... sizes=(20, 200), hue_norm=(0, 7), - ... legend="full", data=tips) - - Vary the size with a categorical variable, and use a different palette: - - .. 
plot:: - :context: close-figs - - >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True) - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... hue="day", size="smoker", - ... palette="Set2", - ... data=tips) - - Use a specific set of markers: - - .. plot:: - :context: close-figs - - >>> markers = {{"Lunch": "s", "Dinner": "X"}} - >>> ax = sns.scatterplot(x="total_bill", y="tip", style="time", - ... markers=markers, - ... data=tips) - - Control plot attributes using matplotlib parameters: - - .. plot:: - :context: close-figs - - >>> ax = sns.scatterplot(x="total_bill", y="tip", - ... s=100, color=".2", marker="+", - ... data=tips) - - Pass data vectors instead of names in a data frame: - - .. plot:: - :context: close-figs - - >>> iris = sns.load_dataset("iris") - >>> ax = sns.scatterplot(x=iris.sepal_length, y=iris.sepal_width, - ... hue=iris.species, style=iris.species) - - Pass a wide-form dataset and plot against its index: + p.plot(ax, kwargs) - .. plot:: - :context: close-figs + return ax - >>> import numpy as np, pandas as pd; plt.close("all") - >>> index = pd.date_range("1 1 2000", periods=100, - ... freq="m", name="date") - >>> data = np.random.randn(100, 4).cumsum(axis=0) - >>> wide_df = pd.DataFrame(data, index, ["a", "b", "c", "d"]) - >>> ax = sns.scatterplot(data=wide_df) - """).format(**_relational_docs) +scatterplot.__doc__ = """\ +Draw a scatter plot with possibility of several semantic groupings. + +{narrative.main_api} + +{narrative.relational_semantic} + +Parameters +---------- +{params.core.xy} +hue : vector or key in ``data`` + Grouping variable that will produce points with different colors. + Can be either categorical or numeric, although color mapping will + behave differently in latter case. +size : vector or key in ``data`` + Grouping variable that will produce points with different sizes. + Can be either categorical or numeric, although size mapping will + behave differently in latter case. +style : vector or key in ``data`` + Grouping variable that will produce points with different markers. + Can have a numeric dtype but will always be treated as categorical. +{params.core.data} +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} +{params.rel.sizes} +{params.rel.size_order} +{params.rel.size_norm} +{params.rel.markers} +{params.rel.style_order} +{{x,y}}_bins : lists or arrays or functions + *Currently non-functional.* +{params.rel.units} + *Currently non-functional.* +{params.rel.estimator} + *Currently non-functional.* +{params.rel.ci} + *Currently non-functional.* +{params.rel.n_boot} + *Currently non-functional.* +alpha : float + Proportional opacity of the points. +{{x,y}}_jitter : booleans or floats + *Currently non-functional.* +{params.rel.legend} +{params.core.ax} +kwargs : key, value mappings + Other keyword arguments are passed down to + :meth:`matplotlib.axes.Axes.scatter`. + +Returns +------- +{returns.ax} + +See Also +-------- +{seealso.lineplot} +{seealso.stripplot} +{seealso.swarmplot} + +Examples +-------- + +.. 
include:: ../docstrings/scatterplot.rst + +""".format( + narrative=_relational_narrative, + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +) -def relplot(x=None, y=None, hue=None, size=None, style=None, data=None, - row=None, col=None, col_wrap=None, row_order=None, col_order=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - markers=None, dashes=None, style_order=None, - legend="brief", kind="scatter", - height=5, aspect=1, facet_kws=None, **kwargs): +@_deprecate_positional_args +def relplot( + *, + x=None, y=None, + hue=None, size=None, style=None, data=None, + row=None, col=None, + col_wrap=None, row_order=None, col_order=None, + palette=None, hue_order=None, hue_norm=None, + sizes=None, size_order=None, size_norm=None, + markers=None, dashes=None, style_order=None, + legend="auto", kind="scatter", + height=5, aspect=1, facet_kws=None, + units=None, + **kwargs +): if kind == "scatter": @@ -1569,30 +879,60 @@ def relplot(x=None, y=None, hue=None, size=None, style=None, data=None, err = "Plot kind {} not recognized".format(kind) raise ValueError(err) - # Use the full dataset to establish how to draw the semantics + # Check for attempt to plot onto specific axes and warn + if "ax" in kwargs: + msg = ( + "relplot is a figure-level function and does not accept " + "the ax= parameter. You may wish to try {}".format(kind + "plot") + ) + warnings.warn(msg, UserWarning) + kwargs.pop("ax") + + # Use the full dataset to map the semantics p = plotter( - x=x, y=y, hue=hue, size=size, style=style, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - markers=markers, dashes=dashes, style_order=style_order, + data=data, + variables=plotter.get_semantics(locals()), legend=legend, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, dashes=dashes, order=style_order) + + # Extract the semantic mappings + if "hue" in p.variables: + palette = p._hue_map.lookup_table + hue_order = p._hue_map.levels + hue_norm = p._hue_map.norm + else: + palette = hue_order = hue_norm = None - palette = p.palette if p.palette else None - hue_order = p.hue_levels if any(p.hue_levels) else None - hue_norm = p.hue_norm if p.hue_norm is not None else None + if "size" in p.variables: + sizes = p._size_map.lookup_table + size_order = p._size_map.levels + size_norm = p._size_map.norm - sizes = p.sizes if p.sizes else None - size_order = p.size_levels if any(p.size_levels) else None - size_norm = p.size_norm if p.size_norm is not None else None + if "style" in p.variables: + style_order = p._style_map.levels + if markers: + markers = {k: p._style_map(k, "marker") for k in style_order} + else: + markers = None + if dashes: + dashes = {k: p._style_map(k, "dashes") for k in style_order} + else: + dashes = None + else: + markers = dashes = style_order = None - markers = p.markers if p.markers else None - dashes = p.dashes if p.dashes else None - style_order = p.style_levels if any(p.style_levels) else None + # Now extract the data that would be used to draw a single plot + variables = p.variables + plot_data = p.plot_data + plot_semantics = p.semantics + # Define the common plotting parameters plot_kws = dict( - palette=palette, hue_order=hue_order,
hue_norm=hue_norm, + sizes=sizes, size_order=size_order, size_norm=size_norm, markers=markers, dashes=dashes, style_order=style_order, legend=False, ) @@ -1600,162 +940,139 @@ def relplot(x=None, y=None, hue=None, size=None, style=None, data=None, if kind == "scatter": plot_kws.pop("dashes") + # Define the named variables for plotting on each facet + plot_variables = {key: key for key in p.variables} + plot_kws.update(plot_variables) + + # Add the grid semantics onto the plotter + grid_semantics = "row", "col" + p.semantics = plot_semantics + grid_semantics + p.assign_variables( + data=data, + variables=dict( + x=x, y=y, + hue=hue, size=size, style=style, units=units, + row=row, col=col, + ), + ) + + # Pass the row/col variables to FacetGrid with their original + # names so that the axes titles render correctly + grid_kws = {v: p.variables.get(v, None) for v in grid_semantics} + full_data = p.plot_data.rename(columns=grid_kws) + # Set up the FacetGrid object - facet_kws = {} if facet_kws is None else facet_kws + facet_kws = {} if facet_kws is None else facet_kws.copy() + facet_kws.update(grid_kws) g = FacetGrid( - data=data, row=row, col=col, col_wrap=col_wrap, - row_order=row_order, col_order=col_order, + data=full_data, + col_wrap=col_wrap, row_order=row_order, col_order=col_order, height=height, aspect=aspect, dropna=False, **facet_kws ) # Draw the plot - g.map_dataframe(func, x, y, - hue=hue, size=size, style=style, - **plot_kws) + g.map_dataframe(func, **plot_kws) + + # Label the axes + g.set_axis_labels( + variables.get("x", None), variables.get("y", None) + ) # Show the legend if legend: + # Replace the original plot data so the legend uses + # numeric data with the correct type + p.plot_data = plot_data p.add_legend_data(g.axes.flat[0]) if p.legend_data: g.add_legend(legend_data=p.legend_data, - label_order=p.legend_order) + label_order=p.legend_order, + title=p.legend_title, + adjust_subtitles=True) return g -relplot.__doc__ = dedent("""\ - Figure-level interface for drawing relational plots onto a FacetGrid. - - This function provides access to several different axes-level functions - that show the relationship between two variables with semantic mappings - of subsets. The ``kind`` parameter selects the underlying axes-level - function to use: - - - :func:`scatterplot` (with ``kind="scatter"``; the default) - - :func:`lineplot` (with ``kind="line"``) - - Extra keyword arguments are passed to the underlying function, so you - should refer to the documentation for each to see kind-specific options. - - {main_api_narrative} - - After plotting, the :class:`FacetGrid` with the plot is returned and can - be used directly to tweak supporting plot details or add other layers. - - Note that, unlike when using the underlying plotting functions directly, - data must be passed in a long-form DataFrame with variables specified by - passing strings to ``x``, ``y``, and other parameters. - - Parameters - ---------- - x, y : names of variables in ``data`` - Input data variables; must be numeric. - hue : name in ``data``, optional - Grouping variable that will produce elements with different colors. - Can be either categorical or numeric, although color mapping will - behave differently in latter case. - size : name in ``data``, optional - Grouping variable that will produce elements with different sizes. - Can be either categorical or numeric, although size mapping will - behave differently in latter case. 
- style : name in ``data``, optional - Grouping variable that will produce elements with different styles. - Can have a numeric dtype but will always be treated as categorical. - {data} - row, col : names of variables in ``data``, optional - Categorical variables that will determine the faceting of the grid. - {col_wrap} - row_order, col_order : lists of strings, optional - Order to organize the rows and/or columns of the grid in, otherwise the - orders are inferred from the data objects. - {palette} - {hue_order} - {hue_norm} - {sizes} - {size_order} - {size_norm} - {legend} - kind : string, optional - Kind of plot to draw, corresponding to a seaborn relational plot. - Options are {{``scatter`` and ``line``}}. - {height} - {aspect} - facet_kws : dict, optional - Dictionary of other keyword arguments to pass to :class:`FacetGrid`. - kwargs : key, value pairings - Other keyword arguments are passed through to the underlying plotting - function. - - Returns - ------- - g : :class:`FacetGrid` - Returns the :class:`FacetGrid` object with the plot on it for further - tweaking. - - Examples - -------- - - Draw a single facet to use the :class:`FacetGrid` legend placement: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns - >>> sns.set(style="ticks") - >>> tips = sns.load_dataset("tips") - >>> g = sns.relplot(x="total_bill", y="tip", hue="day", data=tips) - - Facet on the columns with another variable: - - .. plot:: - :context: close-figs - - >>> g = sns.relplot(x="total_bill", y="tip", - ... hue="day", col="time", data=tips) - - Facet on the columns and rows: - - .. plot:: - :context: close-figs - - >>> g = sns.relplot(x="total_bill", y="tip", hue="day", - ... col="time", row="sex", data=tips) - - "Wrap" many column facets into multiple rows: - - .. plot:: - :context: close-figs - - >>> g = sns.relplot(x="total_bill", y="tip", hue="time", - ... col="day", col_wrap=2, data=tips) - - Use multiple semantic variables on each facet with specified attributes: - - .. plot:: - :context: close-figs - - >>> g = sns.relplot(x="total_bill", y="tip", hue="time", size="size", - ... palette=["b", "r"], sizes=(10, 100), - ... col="time", data=tips) - - Use a different kind of plot: - - .. plot:: - :context: close-figs - - >>> fmri = sns.load_dataset("fmri") - >>> g = sns.relplot(x="timepoint", y="signal", - ... hue="event", style="event", col="region", - ... kind="line", data=fmri) - - Change the size of each facet: - - .. plot:: - :context: close-figs - - >>> g = sns.relplot(x="timepoint", y="signal", - ... hue="event", style="event", col="region", - ... height=5, aspect=.7, kind="line", data=fmri) - - """).format(**_relational_docs) +relplot.__doc__ = """\ +Figure-level interface for drawing relational plots onto a FacetGrid. + +This function provides access to several different axes-level functions +that show the relationship between two variables with semantic mappings +of subsets. The ``kind`` parameter selects the underlying axes-level +function to use: + +- :func:`scatterplot` (with ``kind="scatter"``; the default) +- :func:`lineplot` (with ``kind="line"``) + +Extra keyword arguments are passed to the underlying function, so you +should refer to the documentation for each to see kind-specific options. + +{narrative.main_api} + +{narrative.relational_semantic} + +After plotting, the :class:`FacetGrid` with the plot is returned and can +be used directly to tweak supporting plot details or add other layers. 
+ +Note that, unlike when using the underlying plotting functions directly, +data must be passed in a long-form DataFrame with variables specified by +passing strings to ``x``, ``y``, and other parameters. + +Parameters +---------- +{params.core.xy} +hue : vector or key in ``data`` + Grouping variable that will produce elements with different colors. + Can be either categorical or numeric, although color mapping will + behave differently in the latter case. +size : vector or key in ``data`` + Grouping variable that will produce elements with different sizes. + Can be either categorical or numeric, although size mapping will + behave differently in the latter case. +style : vector or key in ``data`` + Grouping variable that will produce elements with different styles. + Can have a numeric dtype but will always be treated as categorical. +{params.core.data} +{params.facets.rowcol} +{params.facets.col_wrap} +row_order, col_order : lists of strings + Order to organize the rows and/or columns of the grid in, otherwise the + orders are inferred from the data objects. +{params.core.palette} +{params.core.hue_order} +{params.core.hue_norm} +{params.rel.sizes} +{params.rel.size_order} +{params.rel.size_norm} +{params.rel.style_order} +{params.rel.dashes} +{params.rel.markers} +{params.rel.legend} +kind : string + Kind of plot to draw, corresponding to a seaborn relational plot. + Options are {{``scatter`` and ``line``}}. +{params.facets.height} +{params.facets.aspect} +facet_kws : dict + Dictionary of other keyword arguments to pass to :class:`FacetGrid`. +{params.rel.units} +kwargs : key, value pairings + Other keyword arguments are passed through to the underlying plotting + function. + +Returns +------- +{returns.facetgrid} + +Examples +-------- + +.. include:: ../docstrings/relplot.rst + +""".format( + narrative=_relational_narrative, + params=_param_docs, + returns=_core_docs["returns"], + seealso=_core_docs["seealso"], +)
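A minimal sketch of a post-patch call, for orientation (this snippet is not part of the diff; it assumes the tips demo dataset used by the old docstring examples): with the signature change above, every argument to relplot is now keyword-only.

>>> import seaborn as sns
>>> tips = sns.load_dataset("tips")
>>> g = sns.relplot(data=tips, x="total_bill", y="tip", hue="day", col="time")

Because the function is wrapped in _deprecate_positional_args, a positional call such as sns.relplot("total_bill", "tip", data=tips) should emit a deprecation warning rather than silently misassign arguments.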
diff --git a/seaborn/tests/test_algorithms.py b/seaborn/tests/test_algorithms.py index 1beede992c..e1ae1ffb7b 100644 --- a/seaborn/tests/test_algorithms.py +++ b/seaborn/tests/test_algorithms.py @@ -1,8 +1,9 @@ import numpy as np -from ..external.six.moves import range +import numpy.random as npr -from numpy.testing import assert_array_equal import pytest +from numpy.testing import assert_array_equal +from distutils.version import LooseVersion from .. import algorithms as algo @@ -67,24 +68,15 @@ def test_bootstrap_axis(random): assert out_axis.shape, (n_boot, x.shape[1]) -def test_bootstrap_random_seed(random): +def test_bootstrap_seed(random): """Test that we can get reproducible resamples by seeding the RNG.""" data = np.random.randn(50) seed = 42 - boots1 = algo.bootstrap(data, random_seed=seed) - boots2 = algo.bootstrap(data, random_seed=seed) + boots1 = algo.bootstrap(data, seed=seed) + boots2 = algo.bootstrap(data, seed=seed) assert_array_equal(boots1, boots2) -def test_smooth_bootstrap(random): - """Test smooth bootstrap.""" - x = np.random.randn(15) - n_boot = 100 - out_smooth = algo.bootstrap(x, n_boot=n_boot, - smooth=True, func=np.median) - assert not np.median(out_smooth) in x - - def test_bootstrap_ols(random): """Test bootstrap of OLS model fit.""" def ols_fit(X, y): @@ -118,8 +110,8 @@ def test_bootstrap_units(random): data_rm = data + bwerr seed = 77 - boots_orig = algo.bootstrap(data_rm, random_seed=seed) - boots_rm = algo.bootstrap(data_rm, units=ids, random_seed=seed) + boots_orig = algo.bootstrap(data_rm, seed=seed) + boots_rm = algo.bootstrap(data_rm, units=ids, seed=seed) assert boots_rm.std() > boots_orig.std() @@ -133,13 +125,95 @@ def test_bootstrap_string_func(): """Test that named numpy methods are the same as the numpy function.""" x = np.random.randn(100) - res_a = algo.bootstrap(x, func="mean", random_seed=0) - res_b = algo.bootstrap(x, func=np.mean, random_seed=0) + res_a = algo.bootstrap(x, func="mean", seed=0) + res_b = algo.bootstrap(x, func=np.mean, seed=0) assert np.array_equal(res_a, res_b) - res_a = algo.bootstrap(x, func="std", random_seed=0) - res_b = algo.bootstrap(x, func=np.std, random_seed=0) + res_a = algo.bootstrap(x, func="std", seed=0) + res_b = algo.bootstrap(x, func=np.std, seed=0) assert np.array_equal(res_a, res_b) with pytest.raises(AttributeError): algo.bootstrap(x, func="not_a_method_name") + + +def test_bootstrap_reproducibility(random): + """Test that bootstrapping uses the internal random state.""" + data = np.random.randn(50) + boots1 = algo.bootstrap(data, seed=100) + boots2 = algo.bootstrap(data, seed=100) + assert_array_equal(boots1, boots2) + + with pytest.warns(UserWarning): + # Deprecated, remove when removing random_seed + boots1 = algo.bootstrap(data, random_seed=100) + boots2 = algo.bootstrap(data, random_seed=100) + assert_array_equal(boots1, boots2) + + +@pytest.mark.skipif(LooseVersion(np.__version__) < "1.17", + reason="Tests new numpy random functionality") +def test_seed_new(): + + # Can't use pytest parametrize because tests will fail where the new + # Generator object and related function are not defined + + test_bank = [ + (None, None, npr.Generator, False), + (npr.RandomState(0), npr.RandomState(0), npr.RandomState, True), + (npr.RandomState(0), npr.RandomState(1), npr.RandomState, False), + (npr.default_rng(1), npr.default_rng(1), npr.Generator, True), + (npr.default_rng(1), npr.default_rng(2), npr.Generator, False), + (npr.SeedSequence(10), npr.SeedSequence(10), npr.Generator, True), + (npr.SeedSequence(10), npr.SeedSequence(20), npr.Generator, False), + (100, 100, npr.Generator, True), + (100, 200, npr.Generator, False), + ] + for seed1, seed2, rng_class, match in test_bank: + rng1 = algo._handle_random_seed(seed1) + rng2 = algo._handle_random_seed(seed2) + assert isinstance(rng1, rng_class) + assert isinstance(rng2, rng_class) + assert (rng1.uniform() == rng2.uniform()) == match + + +@pytest.mark.skipif(LooseVersion(np.__version__) >= "1.17", +
reason="Tests old numpy random functionality") +@pytest.mark.parametrize("seed1, seed2, match", [ + (None, None, False), + (npr.RandomState(0), npr.RandomState(0), True), + (npr.RandomState(0), npr.RandomState(1), False), + (100, 100, True), + (100, 200, False), +]) +def test_seed_old(seed1, seed2, match): + rng1 = algo._handle_random_seed(seed1) + rng2 = algo._handle_random_seed(seed2) + assert isinstance(rng1, np.random.RandomState) + assert isinstance(rng2, np.random.RandomState) + assert (rng1.uniform() == rng2.uniform()) == match + + +@pytest.mark.skipif(LooseVersion(np.__version__) >= "1.17", + reason="Tests old numpy random functionality") +def test_bad_seed_old(): + + with pytest.raises(ValueError): + algo._handle_random_seed("not_a_random_seed") + + +def test_nanaware_func_auto(random): + + x = np.random.normal(size=10) + x[0] = np.nan + boots = algo.bootstrap(x, func="mean") + assert not np.isnan(boots).any() + + +def test_nanaware_func_warning(random): + + x = np.random.normal(size=10) + x[0] = np.nan + with pytest.warns(UserWarning, match="Data contain nans but"): + boots = algo.bootstrap(x, func="ptp") + assert np.isnan(boots).any() diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 89601564eb..e17fc1581a 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -1,32 +1,32 @@ -import warnings - import numpy as np import pandas as pd -from scipy import stats import matplotlib as mpl import matplotlib.pyplot as plt import pytest -import nose.tools as nt import numpy.testing as npt +from numpy.testing import assert_array_equal try: import pandas.testing as tm except ImportError: import pandas.util.testing as tm -from distutils.version import LooseVersion - -from .. import axisgrid as ag +from .._core import categorical_order from .. import rcmod from ..palettes import color_palette -from ..distributions import kdeplot, _freedman_diaconis_bins +from ..relational import scatterplot +from ..distributions import histplot, kdeplot, distplot from ..categorical import pointplot -from ..utils import categorical_order +from .. 
diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 89601564eb..e17fc1581a 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -1,32 +1,32 @@ -import warnings - import numpy as np import pandas as pd -from scipy import stats import matplotlib as mpl import matplotlib.pyplot as plt import pytest -import nose.tools as nt import numpy.testing as npt +from numpy.testing import assert_array_equal try: import pandas.testing as tm except ImportError: import pandas.util.testing as tm -from distutils.version import LooseVersion - -from .. import axisgrid as ag +from .._core import categorical_order from .. import rcmod from ..palettes import color_palette -from ..distributions import kdeplot, _freedman_diaconis_bins +from ..relational import scatterplot +from ..distributions import histplot, kdeplot, distplot from ..categorical import pointplot -from ..utils import categorical_order +from .. import axisgrid as ag +from .._testing import ( + assert_plots_equal, + assert_colors_equal, +) rs = np.random.RandomState(0) -class TestFacetGrid(object): +class TestFacetGrid: df = pd.DataFrame(dict(x=rs.normal(size=60), y=rs.gamma(4, size=60), @@ -38,55 +38,54 @@ class TestFacetGrid(object): def test_self_data(self): g = ag.FacetGrid(self.df) - nt.assert_is(g.data, self.df) + assert g.data is self.df def test_self_fig(self): g = ag.FacetGrid(self.df) - nt.assert_is_instance(g.fig, plt.Figure) + assert isinstance(g.fig, plt.Figure) def test_self_axes(self): g = ag.FacetGrid(self.df, row="a", col="b", hue="c") for ax in g.axes.flat: - nt.assert_is_instance(ax, plt.Axes) + assert isinstance(ax, plt.Axes) def test_axes_array_size(self): - g1 = ag.FacetGrid(self.df) - nt.assert_equal(g1.axes.shape, (1, 1)) - - g2 = ag.FacetGrid(self.df, row="a") - nt.assert_equal(g2.axes.shape, (3, 1)) + g = ag.FacetGrid(self.df) + assert g.axes.shape == (1, 1) - g3 = ag.FacetGrid(self.df, col="b") - nt.assert_equal(g3.axes.shape, (1, 2)) + g = ag.FacetGrid(self.df, row="a") + assert g.axes.shape == (3, 1) - g4 = ag.FacetGrid(self.df, hue="c") - nt.assert_equal(g4.axes.shape, (1, 1)) + g = ag.FacetGrid(self.df, col="b") + assert g.axes.shape == (1, 2) - g5 = ag.FacetGrid(self.df, row="a", col="b", hue="c") - nt.assert_equal(g5.axes.shape, (3, 2)) + g = ag.FacetGrid(self.df, hue="c") + assert g.axes.shape == (1, 1) - for ax in g5.axes.flat: - nt.assert_is_instance(ax, plt.Axes) + g = ag.FacetGrid(self.df, row="a", col="b", hue="c") + assert g.axes.shape == (3, 2) + for ax in g.axes.flat: + assert isinstance(ax, plt.Axes) def test_single_axes(self): - g1 = ag.FacetGrid(self.df) - nt.assert_is_instance(g1.ax, plt.Axes) + g = ag.FacetGrid(self.df) + assert isinstance(g.ax, plt.Axes) - g2 = ag.FacetGrid(self.df, row="a") - with nt.assert_raises(AttributeError): - g2.ax + g = ag.FacetGrid(self.df, row="a") + with pytest.raises(AttributeError): + g.ax - g3 = ag.FacetGrid(self.df, col="a") - with nt.assert_raises(AttributeError): - g3.ax + g = ag.FacetGrid(self.df, col="a") + with pytest.raises(AttributeError): + g.ax - g4 = ag.FacetGrid(self.df, col="a", row="b") - with nt.assert_raises(AttributeError): - g4.ax + g = ag.FacetGrid(self.df, col="a", row="b") + with pytest.raises(AttributeError): + g.ax def test_col_wrap(self): @@ -160,6 +159,33 @@ def test_wrapped_axes(self): npt.assert_array_equal(g._not_left_axes, g.axes[np.array([1])].flat) npt.assert_array_equal(g._inner_axes, null) + def test_axes_dict(self): + + g = ag.FacetGrid(self.df) + assert isinstance(g.axes_dict, dict) + assert not g.axes_dict + + g = ag.FacetGrid(self.df, row="c") + assert list(g.axes_dict.keys()) == g.row_names + for (name, ax) in zip(g.row_names, g.axes.flat): + assert g.axes_dict[name] is ax + + g = ag.FacetGrid(self.df, col="c") + assert list(g.axes_dict.keys()) == g.col_names + for (name, ax) in zip(g.col_names, g.axes.flat): + assert g.axes_dict[name] is ax + + g = ag.FacetGrid(self.df, col="a", col_wrap=2) + assert list(g.axes_dict.keys()) == g.col_names + for (name, ax) in zip(g.col_names, g.axes.flat): + assert g.axes_dict[name] is ax + + g = ag.FacetGrid(self.df, row="a", col="c") + for (row_var, col_var), ax in g.axes_dict.items(): + i = g.row_names.index(row_var) + j = g.col_names.index(col_var) + assert g.axes[i, j] is ax + def test_figure_size(self): g = ag.FacetGrid(self.df, row="a", col="b") @@ -173,94 +199,112 @@ def test_figure_size(self): def test_figure_size_with_legend(self): - g1 = ag.FacetGrid(self.df, col="a",
hue="c", height=4, aspect=.5) - npt.assert_array_equal(g1.fig.get_size_inches(), (6, 4)) - g1.add_legend() - nt.assert_greater(g1.fig.get_size_inches()[0], 6) + g = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5) + npt.assert_array_equal(g.fig.get_size_inches(), (6, 4)) + g.add_legend() + assert g.fig.get_size_inches()[0] > 6 - g2 = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5, - legend_out=False) - npt.assert_array_equal(g2.fig.get_size_inches(), (6, 4)) - g2.add_legend() - npt.assert_array_equal(g2.fig.get_size_inches(), (6, 4)) + g = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5, + legend_out=False) + npt.assert_array_equal(g.fig.get_size_inches(), (6, 4)) + g.add_legend() + npt.assert_array_equal(g.fig.get_size_inches(), (6, 4)) def test_legend_data(self): - g1 = ag.FacetGrid(self.df, hue="a") - g1.map(plt.plot, "x", "y") - g1.add_legend() + g = ag.FacetGrid(self.df, hue="a") + g.map(plt.plot, "x", "y") + g.add_legend() palette = color_palette(n_colors=3) - nt.assert_equal(g1._legend.get_title().get_text(), "a") + assert g._legend.get_title().get_text() == "a" a_levels = sorted(self.df.a.unique()) - lines = g1._legend.get_lines() - nt.assert_equal(len(lines), len(a_levels)) + lines = g._legend.get_lines() + assert len(lines) == len(a_levels) for line, hue in zip(lines, palette): - nt.assert_equal(line.get_color(), hue) + assert_colors_equal(line.get_color(), hue) - labels = g1._legend.get_texts() - nt.assert_equal(len(labels), len(a_levels)) + labels = g._legend.get_texts() + assert len(labels) == len(a_levels) for label, level in zip(labels, a_levels): - nt.assert_equal(label.get_text(), level) + assert label.get_text() == level def test_legend_data_missing_level(self): - g1 = ag.FacetGrid(self.df, hue="a", hue_order=list("azbc")) - g1.map(plt.plot, "x", "y") - g1.add_legend() + g = ag.FacetGrid(self.df, hue="a", hue_order=list("azbc")) + g.map(plt.plot, "x", "y") + g.add_legend() - b, g, r, p = color_palette(n_colors=4) - palette = [b, r, p] + c1, c2, c3, c4 = color_palette(n_colors=4) + palette = [c1, c3, c4] - nt.assert_equal(g1._legend.get_title().get_text(), "a") + assert g._legend.get_title().get_text() == "a" a_levels = sorted(self.df.a.unique()) - lines = g1._legend.get_lines() - nt.assert_equal(len(lines), len(a_levels)) + lines = g._legend.get_lines() + assert len(lines) == len(a_levels) for line, hue in zip(lines, palette): - nt.assert_equal(line.get_color(), hue) + assert_colors_equal(line.get_color(), hue) - labels = g1._legend.get_texts() - nt.assert_equal(len(labels), 4) + labels = g._legend.get_texts() + assert len(labels) == 4 for label, level in zip(labels, list("azbc")): - nt.assert_equal(label.get_text(), level) + assert label.get_text() == level def test_get_boolean_legend_data(self): self.df["b_bool"] = self.df.b == "m" - g1 = ag.FacetGrid(self.df, hue="b_bool") - g1.map(plt.plot, "x", "y") - g1.add_legend() + g = ag.FacetGrid(self.df, hue="b_bool") + g.map(plt.plot, "x", "y") + g.add_legend() palette = color_palette(n_colors=2) - nt.assert_equal(g1._legend.get_title().get_text(), "b_bool") + assert g._legend.get_title().get_text() == "b_bool" b_levels = list(map(str, categorical_order(self.df.b_bool))) - lines = g1._legend.get_lines() - nt.assert_equal(len(lines), len(b_levels)) + lines = g._legend.get_lines() + assert len(lines) == len(b_levels) for line, hue in zip(lines, palette): - nt.assert_equal(line.get_color(), hue) + assert_colors_equal(line.get_color(), hue) - labels = g1._legend.get_texts() - 
nt.assert_equal(len(labels), len(b_levels)) + labels = g._legend.get_texts() + assert len(labels) == len(b_levels) for label, level in zip(labels, b_levels): - nt.assert_equal(label.get_text(), level) + assert label.get_text() == level + + def test_legend_tuples(self): + + g = ag.FacetGrid(self.df, hue="a") + g.map(plt.plot, "x", "y") + + handles, labels = g.ax.get_legend_handles_labels() + label_tuples = [("", l) for l in labels] + legend_data = dict(zip(label_tuples, handles)) + g.add_legend(legend_data, label_tuples) + for entry, label in zip(g._legend.get_texts(), labels): + assert entry.get_text() == label def test_legend_options(self): - g1 = ag.FacetGrid(self.df, hue="b") - g1.map(plt.plot, "x", "y") - g1.add_legend() + g = ag.FacetGrid(self.df, hue="b") + g.map(plt.plot, "x", "y") + g.add_legend() + + g1 = ag.FacetGrid(self.df, hue="b", legend_out=False) + g1.add_legend(adjust_subtitles=True) + + g1 = ag.FacetGrid(self.df, hue="b", legend_out=False) + g1.add_legend(adjust_subtitles=False) def test_legendout_with_colwrap(self): @@ -269,12 +313,24 @@ def test_legendout_with_colwrap(self): g.map(plt.plot, "x", "y", linewidth=3) g.add_legend() + def test_legend_tight_layout(self): + + g = ag.FacetGrid(self.df, hue='b') + g.map(plt.plot, "x", "y", linewidth=3) + g.add_legend() + g.tight_layout() + + axes_right_edge = g.ax.get_window_extent().xmax + legend_left_edge = g._legend.get_window_extent().xmin + + assert axes_right_edge < legend_left_edge + def test_subplot_kws(self): g = ag.FacetGrid(self.df, despine=False, subplot_kws=dict(projection="polar")) for ax in g.axes.flat: - nt.assert_true("PolarAxesSubplot" in str(type(ax))) + assert "PolarAxesSubplot" in str(type(ax)) def test_gridspec_kws(self): ratios = [3, 1, 2] @@ -296,51 +352,48 @@ def test_gridspec_kws_col_wrap(self): ratios = [3, 1, 2, 1, 1] gskws = dict(width_ratios=ratios) - with warnings.catch_warnings(): - warnings.resetwarnings() - warnings.simplefilter("always") - npt.assert_warns(UserWarning, ag.FacetGrid, self.df, col='d', - col_wrap=5, gridspec_kws=gskws) + with pytest.warns(UserWarning): + ag.FacetGrid(self.df, col='d', col_wrap=5, gridspec_kws=gskws) def test_data_generator(self): g = ag.FacetGrid(self.df, row="a") d = list(g.facet_data()) - nt.assert_equal(len(d), 3) + assert len(d) == 3 tup, data = d[0] - nt.assert_equal(tup, (0, 0, 0)) - nt.assert_true((data["a"] == "a").all()) + assert tup == (0, 0, 0) + assert (data["a"] == "a").all() tup, data = d[1] - nt.assert_equal(tup, (1, 0, 0)) - nt.assert_true((data["a"] == "b").all()) + assert tup == (1, 0, 0) + assert (data["a"] == "b").all() g = ag.FacetGrid(self.df, row="a", col="b") d = list(g.facet_data()) - nt.assert_equal(len(d), 6) + assert len(d) == 6 tup, data = d[0] - nt.assert_equal(tup, (0, 0, 0)) - nt.assert_true((data["a"] == "a").all()) - nt.assert_true((data["b"] == "m").all()) + assert tup == (0, 0, 0) + assert (data["a"] == "a").all() + assert (data["b"] == "m").all() tup, data = d[1] - nt.assert_equal(tup, (0, 1, 0)) - nt.assert_true((data["a"] == "a").all()) - nt.assert_true((data["b"] == "n").all()) + assert tup == (0, 1, 0) + assert (data["a"] == "a").all() + assert (data["b"] == "n").all() tup, data = d[2] - nt.assert_equal(tup, (1, 0, 0)) - nt.assert_true((data["a"] == "b").all()) - nt.assert_true((data["b"] == "m").all()) + assert tup == (1, 0, 0) + assert (data["a"] == "b").all() + assert (data["b"] == "m").all() g = ag.FacetGrid(self.df, hue="c") d = list(g.facet_data()) - nt.assert_equal(len(d), 3) + assert len(d) == 3 tup, data = d[1] - 
nt.assert_equal(tup, (0, 0, 1)) - nt.assert_true((data["c"] == "u").all()) + assert tup == (0, 0, 1) + assert (data["c"] == "u").all() def test_map(self): @@ -348,10 +401,10 @@ def test_map(self): g.map(plt.plot, "x", "y", linewidth=3) lines = g.axes[0, 0].lines - nt.assert_equal(len(lines), 3) + assert len(lines) == 3 line1, _, _ = lines - nt.assert_equal(line1.get_linewidth(), 3) + assert line1.get_linewidth() == 3 x, y = line1.get_data() mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t") npt.assert_array_equal(x, self.df.x[mask]) @@ -363,14 +416,16 @@ def test_map_dataframe(self): def plot(x, y, data=None, **kws): plt.plot(data[x], data[y], **kws) + # Modify __module__ so this doesn't look like a seaborn function + plot.__module__ = "test" g.map_dataframe(plot, "x", "y", linestyle="--") lines = g.axes[0, 0].lines - nt.assert_equal(len(lines), 3) + assert len(g.axes[0, 0].lines) == 3 line1, _, _ = lines - nt.assert_equal(line1.get_linestyle(), "--") + assert line1.get_linestyle() == "--" x, y = line1.get_data() mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t") npt.assert_array_equal(x, self.df.x[mask]) @@ -396,23 +451,23 @@ def test_set_titles(self): g.map(plt.plot, "x", "y") # Test the default titles - nt.assert_equal(g.axes[0, 0].get_title(), "a = a | b = m") - nt.assert_equal(g.axes[0, 1].get_title(), "a = a | b = n") - nt.assert_equal(g.axes[1, 0].get_title(), "a = b | b = m") + assert g.axes[0, 0].get_title() == "a = a | b = m" + assert g.axes[0, 1].get_title() == "a = a | b = n" + assert g.axes[1, 0].get_title() == "a = b | b = m" # Test a provided title g.set_titles("{row_var} == {row_name} \\/ {col_var} == {col_name}") - nt.assert_equal(g.axes[0, 0].get_title(), "a == a \\/ b == m") - nt.assert_equal(g.axes[0, 1].get_title(), "a == a \\/ b == n") - nt.assert_equal(g.axes[1, 0].get_title(), "a == b \\/ b == m") + assert g.axes[0, 0].get_title() == "a == a \\/ b == m" + assert g.axes[0, 1].get_title() == "a == a \\/ b == n" + assert g.axes[1, 0].get_title() == "a == b \\/ b == m" # Test a single row - g = ag.FacetGrid(self.df, col="b") + g = ag.FacetGrid(self.df, col="b") g.map(plt.plot, "x", "y") # Test the default titles - nt.assert_equal(g.axes[0, 0].get_title(), "b = m") - nt.assert_equal(g.axes[0, 1].get_title(), "b = n") + assert g.axes[0, 0].get_title() == "b = m" + assert g.axes[0, 1].get_title() == "b = n" # test with dropna=False g = ag.FacetGrid(self.df, col="b", hue="b", dropna=False) @@ -424,37 +479,43 @@ def test_set_titles_margin_titles(self): g.map(plt.plot, "x", "y") # Test the default titles - nt.assert_equal(g.axes[0, 0].get_title(), "b = m") - nt.assert_equal(g.axes[0, 1].get_title(), "b = n") - nt.assert_equal(g.axes[1, 0].get_title(), "") + assert g.axes[0, 0].get_title() == "b = m" + assert g.axes[0, 1].get_title() == "b = n" + assert g.axes[1, 0].get_title() == "" # Test the row "titles" - nt.assert_equal(g.axes[0, 1].texts[0].get_text(), "a = a") - nt.assert_equal(g.axes[1, 1].texts[0].get_text(), "a = b") + assert g.axes[0, 1].texts[0].get_text() == "a = a" + assert g.axes[1, 1].texts[0].get_text() == "a = b" + assert g.axes[0, 1].texts[0] is g._margin_titles_texts[0] - # Test a provided title - g.set_titles(col_template="{col_var} == {col_name}") - nt.assert_equal(g.axes[0, 0].get_title(), "b == m") - nt.assert_equal(g.axes[0, 1].get_title(), "b == n") - nt.assert_equal(g.axes[1, 0].get_title(), "") + # Test provided titles + g.set_titles(col_template="{col_name}", row_template="{row_name}") + assert g.axes[0, 
0].get_title() == "m" + assert g.axes[0, 1].get_title() == "n" + assert g.axes[1, 0].get_title() == "" + + assert len(g.axes[1, 1].texts) == 1 + assert g.axes[1, 1].texts[0].get_text() == "b" def test_set_ticklabels(self): g = ag.FacetGrid(self.df, row="a", col="b") g.map(plt.plot, "x", "y") - xlab = [l.get_text() + "h" for l in g.axes[1, 0].get_xticklabels()] - ylab = [l.get_text() for l in g.axes[1, 0].get_yticklabels()] + + ax = g.axes[-1, 0] + xlab = [l.get_text() + "h" for l in ax.get_xticklabels()] + ylab = [l.get_text() + "i" for l in ax.get_yticklabels()] g.set_xticklabels(xlab) g.set_yticklabels(ylab) - got_x = [l.get_text() for l in g.axes[1, 1].get_xticklabels()] + got_x = [l.get_text() for l in g.axes[-1, 1].get_xticklabels()] got_y = [l.get_text() for l in g.axes[0, 0].get_yticklabels()] npt.assert_array_equal(got_x, xlab) npt.assert_array_equal(got_y, ylab) x, y = np.arange(10), np.arange(10) df = pd.DataFrame(np.c_[x, y], columns=["x", "y"]) - g = ag.FacetGrid(df).map(pointplot, "x", "y", order=x) + g = ag.FacetGrid(df).map_dataframe(pointplot, x="x", y="y", order=x) g.set_xticklabels(step=2) got_x = [int(l.get_text()) for l in g.axes[0, 0].get_xticklabels()] npt.assert_array_equal(x[::2], got_x) @@ -465,10 +526,10 @@ def test_set_ticklabels(self): g.set_yticklabels(rotation=75) for ax in g._bottom_axes: for l in ax.get_xticklabels(): - nt.assert_equal(l.get_rotation(), 45) + assert l.get_rotation() == 45 for ax in g._left_axes: for l in ax.get_yticklabels(): - nt.assert_equal(l.get_rotation(), 75) + assert l.get_rotation() == 75 def test_set_axis_labels(self): @@ -484,40 +545,49 @@ def test_set_axis_labels(self): npt.assert_array_equal(got_x, xlab) npt.assert_array_equal(got_y, ylab) + for ax in g.axes.flat: + ax.set(xlabel="x", ylabel="y") + + g.set_axis_labels(xlab, ylab) + for ax in g._not_bottom_axes: + assert not ax.get_xlabel() + for ax in g._not_left_axes: + assert not ax.get_ylabel() + def test_axis_lims(self): g = ag.FacetGrid(self.df, row="a", col="b", xlim=(0, 4), ylim=(-2, 3)) - nt.assert_equal(g.axes[0, 0].get_xlim(), (0, 4)) - nt.assert_equal(g.axes[0, 0].get_ylim(), (-2, 3)) + assert g.axes[0, 0].get_xlim() == (0, 4) + assert g.axes[0, 0].get_ylim() == (-2, 3) def test_data_orders(self): g = ag.FacetGrid(self.df, row="a", col="b", hue="c") - nt.assert_equal(g.row_names, list("abc")) - nt.assert_equal(g.col_names, list("mn")) - nt.assert_equal(g.hue_names, list("tuv")) - nt.assert_equal(g.axes.shape, (3, 2)) + assert g.row_names == list("abc") + assert g.col_names == list("mn") + assert g.hue_names == list("tuv") + assert g.axes.shape == (3, 2) g = ag.FacetGrid(self.df, row="a", col="b", hue="c", row_order=list("bca"), col_order=list("nm"), hue_order=list("vtu")) - nt.assert_equal(g.row_names, list("bca")) - nt.assert_equal(g.col_names, list("nm")) - nt.assert_equal(g.hue_names, list("vtu")) - nt.assert_equal(g.axes.shape, (3, 2)) + assert g.row_names == list("bca") + assert g.col_names == list("nm") + assert g.hue_names == list("vtu") + assert g.axes.shape == (3, 2) g = ag.FacetGrid(self.df, row="a", col="b", hue="c", row_order=list("bcda"), col_order=list("nom"), hue_order=list("qvtu")) - nt.assert_equal(g.row_names, list("bcda")) - nt.assert_equal(g.col_names, list("nom")) - nt.assert_equal(g.hue_names, list("qvtu")) - nt.assert_equal(g.axes.shape, (4, 3)) + assert g.row_names == list("bcda") + assert g.col_names == list("nom") + assert g.hue_names == list("qvtu") + assert g.axes.shape == (4, 3) def test_palette(self): @@ -549,158 +619,19 @@ def 
test_hue_kws(self): g.map(plt.plot, "x", "y") for line, marker in zip(g.axes[0, 0].lines, kws["marker"]): - nt.assert_equal(line.get_marker(), marker) + assert line.get_marker() == marker def test_dropna(self): df = self.df.copy() - hasna = pd.Series(np.tile(np.arange(6), 10), dtype=np.float) + hasna = pd.Series(np.tile(np.arange(6), 10), dtype=float) hasna[hasna == 5] = np.nan df["hasna"] = hasna g = ag.FacetGrid(df, dropna=False, row="hasna") - nt.assert_equal(g._not_na.sum(), 60) + assert g._not_na.sum() == 60 g = ag.FacetGrid(df, dropna=True, row="hasna") - nt.assert_equal(g._not_na.sum(), 50) - - def test_unicode_column_label_with_rows(self): - - # use a smaller copy of the default testing data frame: - df = self.df.copy() - df = df[["a", "b", "x"]] - - # rename column 'a' (which will be used for the columns in the grid) - # by using a Unicode string: - unicode_column_label = u"\u01ff\u02ff\u03ff" - df = df.rename(columns={"a": unicode_column_label}) - - # ensure that the data frame columns have the expected names: - nt.assert_equal(list(df.columns), [unicode_column_label, "b", "x"]) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, col=unicode_column_label, row="b") - g = g.map(plt.plot, "x") - - def test_unicode_column_label_no_rows(self): - - # use a smaller copy of the default testing data frame: - df = self.df.copy() - df = df[["a", "x"]] - - # rename column 'a' (which will be used for the columns in the grid) - # by using a Unicode string: - unicode_column_label = u"\u01ff\u02ff\u03ff" - df = df.rename(columns={"a": unicode_column_label}) - - # ensure that the data frame columns have the expected names: - nt.assert_equal(list(df.columns), [unicode_column_label, "x"]) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, col=unicode_column_label) - g = g.map(plt.plot, "x") - - def test_unicode_row_label_with_columns(self): - - # use a smaller copy of the default testing data frame: - df = self.df.copy() - df = df[["a", "b", "x"]] - - # rename column 'b' (which will be used for the rows in the grid) - # by using a Unicode string: - unicode_row_label = u"\u01ff\u02ff\u03ff" - df = df.rename(columns={"b": unicode_row_label}) - - # ensure that the data frame columns have the expected names: - nt.assert_equal(list(df.columns), ["a", unicode_row_label, "x"]) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, col="a", row=unicode_row_label) - g = g.map(plt.plot, "x") - - def test_unicode_row_label_no_columns(self): - - # use a smaller copy of the default testing data frame: - df = self.df.copy() - df = df[["b", "x"]] - - # rename column 'b' (which will be used for the rows in the grid) - # by using a Unicode string: - unicode_row_label = u"\u01ff\u02ff\u03ff" - df = df.rename(columns={"b": unicode_row_label}) - - # ensure that the data frame columns have the expected names: - nt.assert_equal(list(df.columns), [unicode_row_label, "x"]) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, row=unicode_row_label) - g = g.map(plt.plot, "x") - - @pytest.mark.skipif(pd.__version__.startswith("0.24"), - reason="known bug in pandas") - def test_unicode_content_with_row_and_column(self): - - df = self.df.copy() - - # replace content of column 'a' (which will form the columns in the - # grid) by Unicode characters: - unicode_column_val = np.repeat((u'\u01ff', u'\u02ff', u'\u03ff'), 20) - df["a"] = 
unicode_column_val - - # make sure that the replacement worked as expected: - nt.assert_equal( - list(df["a"]), - [u'\u01ff'] * 20 + [u'\u02ff'] * 20 + [u'\u03ff'] * 20) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, col="a", row="b") - g = g.map(plt.plot, "x") - - @pytest.mark.skipif(pd.__version__.startswith("0.24"), - reason="known bug in pandas") - def test_unicode_content_no_rows(self): - - df = self.df.copy() - - # replace content of column 'a' (which will form the columns in the - # grid) by Unicode characters: - unicode_column_val = np.repeat((u'\u01ff', u'\u02ff', u'\u03ff'), 20) - df["a"] = unicode_column_val - - # make sure that the replacement worked as expected: - nt.assert_equal( - list(df["a"]), - [u'\u01ff'] * 20 + [u'\u02ff'] * 20 + [u'\u03ff'] * 20) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, col="a") - g = g.map(plt.plot, "x") - - @pytest.mark.skipif(pd.__version__.startswith("0.24"), - reason="known bug in pandas") - def test_unicode_content_no_columns(self): - - df = self.df.copy() - - # replace content of column 'a' (which will form the rows in the - # grid) by Unicode characters: - unicode_column_val = np.repeat((u'\u01ff', u'\u02ff', u'\u03ff'), 20) - df["b"] = unicode_column_val - - # make sure that the replacement worked as expected: - nt.assert_equal( - list(df["b"]), - [u'\u01ff'] * 20 + [u'\u02ff'] * 20 + [u'\u03ff'] * 20) - - # plot the grid -- if successful, no UnicodeEncodingError should - # occur: - g = ag.FacetGrid(df, row="b") - g = g.map(plt.plot, "x") + assert g._not_na.sum() == 50 def test_categorical_column_missing_categories(self): @@ -709,18 +640,16 @@ def test_categorical_column_missing_categories(self): g = ag.FacetGrid(df[df['a'] == 'a'], col="a", col_wrap=1) - nt.assert_equal(g.axes.shape, (len(df['a'].cat.categories),)) + assert g.axes.shape == (len(df['a'].cat.categories),) def test_categorical_warning(self): g = ag.FacetGrid(self.df, col="b") - with warnings.catch_warnings(): - warnings.resetwarnings() - warnings.simplefilter("always") - npt.assert_warns(UserWarning, g.map, pointplot, "b", "x") + with pytest.warns(UserWarning): + g.map(pointplot, "b", "x") -class TestPairGrid(object): +class TestPairGrid: rs = np.random.RandomState(sum(map(ord, "PairGrid"))) df = pd.DataFrame(dict(x=rs.normal(size=60), @@ -732,7 +661,7 @@ class TestPairGrid(object): def test_self_data(self): g = ag.PairGrid(self.df) - nt.assert_is(g.data, self.df) + assert g.data is self.df def test_ignore_datelike_data(self): @@ -745,30 +674,30 @@ def test_ignore_datelike_data(self): def test_self_fig(self): g = ag.PairGrid(self.df) - nt.assert_is_instance(g.fig, plt.Figure) + assert isinstance(g.fig, plt.Figure) def test_self_axes(self): g = ag.PairGrid(self.df) for ax in g.axes.flat: - nt.assert_is_instance(ax, plt.Axes) + assert isinstance(ax, plt.Axes) def test_default_axes(self): g = ag.PairGrid(self.df) - nt.assert_equal(g.axes.shape, (3, 3)) - nt.assert_equal(g.x_vars, ["x", "y", "z"]) - nt.assert_equal(g.y_vars, ["x", "y", "z"]) - nt.assert_true(g.square_grid) + assert g.axes.shape == (3, 3) + assert g.x_vars == ["x", "y", "z"] + assert g.y_vars == ["x", "y", "z"] + assert g.square_grid - def test_specific_square_axes(self): + @pytest.mark.parametrize("vars", [["z", "x"], np.array(["z", "x"])]) + def test_specific_square_axes(self, vars): - vars = ["z", "x"] g = ag.PairGrid(self.df, vars=vars) - nt.assert_equal(g.axes.shape, (len(vars), len(vars))) - 
nt.assert_equal(g.x_vars, vars) - nt.assert_equal(g.y_vars, vars) - nt.assert_true(g.square_grid) + assert g.axes.shape == (len(vars), len(vars)) + assert g.x_vars == list(vars) + assert g.y_vars == list(vars) + assert g.square_grid def test_remove_hue_from_default(self): @@ -782,58 +711,40 @@ def test_remove_hue_from_default(self): assert hue in g.x_vars assert hue in g.y_vars - def test_specific_nonsquare_axes(self): + @pytest.mark.parametrize( + "x_vars, y_vars", + [ + (["x", "y"], ["z", "y", "x"]), + (["x", "y"], "z"), + (np.array(["x", "y"]), np.array(["z", "y", "x"])), + ], + ) + def test_specific_nonsquare_axes(self, x_vars, y_vars): - x_vars = ["x", "y"] - y_vars = ["z", "y", "x"] g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars) - nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars))) - nt.assert_equal(g.x_vars, x_vars) - nt.assert_equal(g.y_vars, y_vars) - nt.assert_true(not g.square_grid) + assert g.axes.shape == (len(y_vars), len(x_vars)) + assert g.x_vars == list(x_vars) + assert g.y_vars == list(y_vars) + assert not g.square_grid - x_vars = ["x", "y"] - y_vars = "z" - g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars) - nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars))) - nt.assert_equal(g.x_vars, list(x_vars)) - nt.assert_equal(g.y_vars, list(y_vars)) - nt.assert_true(not g.square_grid) - - def test_specific_square_axes_with_array(self): - - vars = np.array(["z", "x"]) - g = ag.PairGrid(self.df, vars=vars) - nt.assert_equal(g.axes.shape, (len(vars), len(vars))) - nt.assert_equal(g.x_vars, list(vars)) - nt.assert_equal(g.y_vars, list(vars)) - nt.assert_true(g.square_grid) - - def test_specific_nonsquare_axes_with_array(self): - - x_vars = np.array(["x", "y"]) - y_vars = np.array(["z", "y", "x"]) - g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars) - nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars))) - nt.assert_equal(g.x_vars, list(x_vars)) - nt.assert_equal(g.y_vars, list(y_vars)) - nt.assert_true(not g.square_grid) - - @pytest.mark.xfail(LooseVersion(mpl.__version__) < "1.5", - reason="Expected failure on older matplotlib") def test_corner(self): plot_vars = ["x", "y", "z"] - g1 = ag.PairGrid(self.df, vars=plot_vars, corner=True) + g = ag.PairGrid(self.df, vars=plot_vars, corner=True) corner_size = sum([i + 1 for i in range(len(plot_vars))]) - assert len(g1.fig.axes) == corner_size + assert len(g.fig.axes) == corner_size - g1.map_diag(plt.hist) - assert len(g1.fig.axes) == (corner_size + len(plot_vars)) + g.map_diag(plt.hist) + assert len(g.fig.axes) == (corner_size + len(plot_vars)) - for ax in np.diag(g1.axes): + for ax in np.diag(g.axes): assert not ax.yaxis.get_visible() - assert not g1.axes[0, 0].get_ylabel() + assert not g.axes[0, 0].get_ylabel() + + plot_vars = ["x", "y", "z"] + g = ag.PairGrid(self.df, vars=plot_vars, corner=True) + g.map(scatterplot) + assert len(g.fig.axes) == corner_size def test_size(self): @@ -847,6 +758,11 @@ def test_size(self): height=2, aspect=2) npt.assert_array_equal(g3.fig.get_size_inches(), (8, 2)) + def test_empty_grid(self): + + with pytest.raises(ValueError, match="No variables found"): + ag.PairGrid(self.df[["a", "b"]]) + def test_map(self): vars = ["x", "y", "z"] @@ -861,7 +777,7 @@ def test_map(self): npt.assert_array_equal(x_in, x_out) npt.assert_array_equal(y_in, y_out) - g2 = ag.PairGrid(self.df, "a") + g2 = ag.PairGrid(self.df, hue="a") g2.map(plt.scatter) for i, axes_i in enumerate(g2.axes): @@ -906,7 +822,7 @@ def test_map_lower(self): for i, j in zip(*np.triu_indices_from(g.axes)): ax = 
g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 def test_map_upper(self): @@ -924,96 +840,133 @@ def test_map_upper(self): for i, j in zip(*np.tril_indices_from(g.axes)): ax = g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 + + def test_map_mixed_funcsig(self): + + vars = ["x", "y", "z"] + g = ag.PairGrid(self.df, vars=vars) + g.map_lower(scatterplot) + g.map_upper(plt.scatter) + + for i, j in zip(*np.triu_indices_from(g.axes, 1)): + ax = g.axes[i, j] + x_in = self.df[vars[j]] + y_in = self.df[vars[i]] + x_out, y_out = ax.collections[0].get_offsets().T + npt.assert_array_equal(x_in, x_out) + npt.assert_array_equal(y_in, y_out) def test_map_diag(self): - g1 = ag.PairGrid(self.df) - g1.map_diag(plt.hist) + g = ag.PairGrid(self.df) + g.map_diag(plt.hist) - for var, ax in zip(g1.diag_vars, g1.diag_axes): - nt.assert_equal(len(ax.patches), 10) + for var, ax in zip(g.diag_vars, g.diag_axes): + assert len(ax.patches) == 10 assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min() - g2 = ag.PairGrid(self.df, hue="a") - g2.map_diag(plt.hist) + g = ag.PairGrid(self.df, hue="a") + g.map_diag(plt.hist) - for ax in g2.diag_axes: - nt.assert_equal(len(ax.patches), 30) + for ax in g.diag_axes: + assert len(ax.patches) == 30 - g3 = ag.PairGrid(self.df, hue="a") - g3.map_diag(plt.hist, histtype='step') + g = ag.PairGrid(self.df, hue="a") + g.map_diag(plt.hist, histtype='step') - for ax in g3.diag_axes: + for ax in g.diag_axes: for ptch in ax.patches: - nt.assert_equal(ptch.fill, False) + assert not ptch.fill def test_map_diag_rectangular(self): x_vars = ["x", "y"] - y_vars = ["x", "y", "z"] + y_vars = ["x", "z", "y"] g1 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars) g1.map_diag(plt.hist) + g1.map_offdiag(plt.scatter) assert set(g1.diag_vars) == (set(x_vars) & set(y_vars)) for var, ax in zip(g1.diag_vars, g1.diag_axes): - nt.assert_equal(len(ax.patches), 10) + assert len(ax.patches) == 10 assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min() - for i, ax in enumerate(np.diag(g1.axes)): - assert ax.bbox.bounds == g1.diag_axes[i].bbox.bounds + for j, x_var in enumerate(x_vars): + for i, y_var in enumerate(y_vars): + + ax = g1.axes[i, j] + if x_var == y_var: + diag_ax = g1.diag_axes[j] # because fewer x than y vars + assert ax.bbox.bounds == diag_ax.bbox.bounds + + else: + x, y = ax.collections[0].get_offsets().T + assert_array_equal(x, self.df[x_var]) + assert_array_equal(y, self.df[y_var]) g2 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars, hue="a") g2.map_diag(plt.hist) + g2.map_offdiag(plt.scatter) assert set(g2.diag_vars) == (set(x_vars) & set(y_vars)) for ax in g2.diag_axes: - nt.assert_equal(len(ax.patches), 30) + assert len(ax.patches) == 30 x_vars = ["x", "y", "z"] - y_vars = ["x", "y"] + y_vars = ["x", "z"] g3 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars) g3.map_diag(plt.hist) + g3.map_offdiag(plt.scatter) assert set(g3.diag_vars) == (set(x_vars) & set(y_vars)) for var, ax in zip(g3.diag_vars, g3.diag_axes): - nt.assert_equal(len(ax.patches), 10) + assert len(ax.patches) == 10 assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min() - for i, ax in enumerate(np.diag(g3.axes)): - assert ax.bbox.bounds == g3.diag_axes[i].bbox.bounds + for j, x_var in enumerate(x_vars): + for i, y_var in enumerate(y_vars): + + ax = g3.axes[i, j] + if x_var == y_var: + diag_ax = g3.diag_axes[i] # because fewer y than x vars + assert ax.bbox.bounds == diag_ax.bbox.bounds + else: + x, y = 
ax.collections[0].get_offsets().T + assert_array_equal(x, self.df[x_var]) + assert_array_equal(y, self.df[y_var]) def test_map_diag_color(self): color = "red" - rgb_color = mpl.colors.colorConverter.to_rgba(color) g1 = ag.PairGrid(self.df) g1.map_diag(plt.hist, color=color) for ax in g1.diag_axes: for patch in ax.patches: - assert patch.get_facecolor() == rgb_color + assert_colors_equal(patch.get_facecolor(), color) g2 = ag.PairGrid(self.df) g2.map_diag(kdeplot, color='red') for ax in g2.diag_axes: for line in ax.lines: - assert line.get_color() == color + assert_colors_equal(line.get_color(), color) def test_map_diag_palette(self): - pal = color_palette(n_colors=len(self.df.a.unique())) - g = ag.PairGrid(self.df, hue="a") + palette = "muted" + pal = color_palette(palette, n_colors=len(self.df.a.unique())) + g = ag.PairGrid(self.df, hue="a", palette=palette) g.map_diag(kdeplot) for ax in g.diag_axes: - for line, color in zip(ax.lines, pal): - assert line.get_color() == color + for line, color in zip(ax.lines[::-1], pal): + assert_colors_equal(line.get_color(), color) def test_map_diag_and_offdiag(self): @@ -1023,7 +976,7 @@ def test_map_diag_and_offdiag(self): g.map_diag(plt.hist) for ax in g.diag_axes: - nt.assert_equal(len(ax.patches), 10) + assert len(ax.patches) == 10 for i, j in zip(*np.triu_indices_from(g.axes, 1)): ax = g.axes[i, j] @@ -1043,7 +996,7 @@ def test_map_diag_and_offdiag(self): for i, j in zip(*np.diag_indices_from(g.axes)): ax = g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 def test_diag_sharey(self): @@ -1052,6 +1005,20 @@ def test_diag_sharey(self): for ax in g.diag_axes[1:]: assert ax.get_ylim() == g.diag_axes[0].get_ylim() + def test_map_diag_matplotlib(self): + + bins = 10 + g = ag.PairGrid(self.df) + g.map_diag(plt.hist, bins=bins) + for ax in g.diag_axes: + assert len(ax.patches) == bins + + levels = len(self.df["a"].unique()) + g = ag.PairGrid(self.df, hue="a") + g.map_diag(plt.hist, bins=bins) + for ax in g.diag_axes: + assert len(ax.patches) == (bins * levels) + def test_palette(self): rcmod.set() @@ -1082,14 +1049,14 @@ def test_hue_kws(self): g.map(plt.plot) for line, marker in zip(g.axes[0, 0].lines, kws["marker"]): - nt.assert_equal(line.get_marker(), marker) + assert line.get_marker() == marker g = ag.PairGrid(self.df, hue="a", hue_kws=kws, hue_order=list("dcab")) g.map(plt.plot) for line, marker in zip(g.axes[0, 0].lines, kws["marker"]): - nt.assert_equal(line.get_marker(), marker) + assert line.get_marker() == marker def test_hue_order(self): @@ -1193,7 +1160,7 @@ def test_nondefault_index(self): npt.assert_array_equal(x_in, x_out) npt.assert_array_equal(y_in, y_out) - g2 = ag.PairGrid(df, "a") + g2 = ag.PairGrid(df, hue="a") g2.map(plt.scatter) for i, axes_i in enumerate(g2.axes): @@ -1204,10 +1171,11 @@ def test_nondefault_index(self): x_in_k = x_in[self.df.a == k_level] y_in_k = y_in[self.df.a == k_level] x_out, y_out = ax.collections[k].get_offsets().T - npt.assert_array_equal(x_in_k, x_out) - npt.assert_array_equal(y_in_k, y_out) + npt.assert_array_equal(x_in_k, x_out) + npt.assert_array_equal(y_in_k, y_out) - def test_dropna(self): + @pytest.mark.parametrize("func", [scatterplot, plt.scatter]) + def test_dropna(self, func): df = self.df.copy() n_null = 20 @@ -1216,7 +1184,7 @@ def test_dropna(self): plot_vars = ["x", "y", "z"] g1 = ag.PairGrid(df, vars=plot_vars, dropna=True) - g1.map(plt.scatter) + g1.map(func) for i, axes_i in enumerate(g1.axes): for j, ax in enumerate(axes_i): @@ -1229,6 +1197,21 @@ 
def test_dropna(self): assert n_valid == len(x_out) assert n_valid == len(y_out) + g1.map_diag(histplot) + for i, ax in enumerate(g1.diag_axes): + var = plot_vars[i] + count = sum([p.get_height() for p in ax.patches]) + assert count == df[var].notna().sum() + + def test_histplot_legend(self): + + # Tests _extract_legend_handles + g = ag.PairGrid(self.df, vars=["x", "y"], hue="a") + g.map_offdiag(histplot) + g.add_legend() + + assert len(g._legend.legendHandles) == len(self.df["a"].unique()) + def test_pairplot(self): vars = ["x", "y", "z"] @@ -1255,13 +1238,12 @@ def test_pairplot(self): for i, j in zip(*np.diag_indices_from(g.axes)): ax = g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 g = ag.pairplot(self.df, hue="a") n = len(self.df.a.unique()) for ax in g.diag_axes: - assert len(ax.lines) == n assert len(ax.collections) == n def test_pairplot_reg(self): @@ -1270,7 +1252,7 @@ def test_pairplot_reg(self): g = ag.pairplot(self.df, diag_kind="hist", kind="reg") for ax in g.diag_axes: - nt.assert_equal(len(ax.patches), 10) + assert len(ax.patches) for i, j in zip(*np.triu_indices_from(g.axes, 1)): ax = g.axes[i, j] @@ -1280,8 +1262,8 @@ def test_pairplot_reg(self): npt.assert_array_equal(x_in, x_out) npt.assert_array_equal(y_in, y_out) - nt.assert_equal(len(ax.lines), 1) - nt.assert_equal(len(ax.collections), 2) + assert len(ax.lines) == 1 + assert len(ax.collections) == 2 for i, j in zip(*np.tril_indices_from(g.axes, -1)): ax = g.axes[i, j] @@ -1291,20 +1273,34 @@ def test_pairplot_reg(self): npt.assert_array_equal(x_in, x_out) npt.assert_array_equal(y_in, y_out) - nt.assert_equal(len(ax.lines), 1) - nt.assert_equal(len(ax.collections), 2) + assert len(ax.lines) == 1 + assert len(ax.collections) == 2 for i, j in zip(*np.diag_indices_from(g.axes)): ax = g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 - def test_pairplot_kde(self): + def test_pairplot_reg_hue(self): + + markers = ["o", "s", "d"] + g = ag.pairplot(self.df, kind="reg", hue="a", markers=markers) + + ax = g.axes[-1, 0] + c1 = ax.collections[0] + c2 = ax.collections[2] + + assert not np.array_equal(c1.get_facecolor(), c2.get_facecolor()) + assert not np.array_equal( + c1.get_paths()[0].vertices, c2.get_paths()[0].vertices, + ) + + def test_pairplot_diag_kde(self): vars = ["x", "y", "z"] g = ag.pairplot(self.df, diag_kind="kde") for ax in g.diag_axes: - nt.assert_equal(len(ax.lines), 1) + assert len(ax.collections) == 1 for i, j in zip(*np.triu_indices_from(g.axes, 1)): ax = g.axes[i, j] @@ -1324,21 +1320,62 @@ def test_pairplot_kde(self): for i, j in zip(*np.diag_indices_from(g.axes)): ax = g.axes[i, j] - nt.assert_equal(len(ax.collections), 0) + assert len(ax.collections) == 0 + + def test_pairplot_kde(self): + + f, ax1 = plt.subplots() + kdeplot(data=self.df, x="x", y="y", ax=ax1) + + g = ag.pairplot(self.df, kind="kde") + ax2 = g.axes[1, 0] + + assert_plots_equal(ax1, ax2, labels=False) + + def test_pairplot_hist(self): + + f, ax1 = plt.subplots() + histplot(data=self.df, x="x", y="y", ax=ax1) + + g = ag.pairplot(self.df, kind="hist") + ax2 = g.axes[1, 0] + + assert_plots_equal(ax1, ax2, labels=False) def test_pairplot_markers(self): vars = ["x", "y", "z"] - markers = ["o", "x", "s"] + markers = ["o", "X", "s"] g = ag.pairplot(self.df, hue="a", vars=vars, markers=markers) - assert g.hue_kws["marker"] == markers - plt.close("all") + m1 = g._legend.legendHandles[0].get_paths()[0] + m2 = g._legend.legendHandles[1].get_paths()[0] + assert m1 != m2 
with pytest.raises(ValueError): g = ag.pairplot(self.df, hue="a", vars=vars, markers=markers[:-2]) + def test_corner_despine(self): + + g = ag.PairGrid(self.df, corner=True, despine=False) + g.map_diag(histplot) + assert g.axes[0, 0].spines["top"].get_visible() + + def test_corner_set(self): + + g = ag.PairGrid(self.df, corner=True, despine=False) + g.set(xlim=(0, 10)) + assert g.axes[-1, 0].get_xlim() == (0, 10) -class TestJointGrid(object): + def test_legend(self): + + g1 = ag.pairplot(self.df, hue="a") + assert isinstance(g1.legend, mpl.legend.Legend) + + g2 = ag.pairplot(self.df) + assert g2.legend is None + + +class TestJointGrid: rs = np.random.RandomState(sum(map(ord, "JointGrid"))) x = rs.randn(100) @@ -1350,74 +1387,78 @@ class TestJointGrid(object): def test_margin_grid_from_lists(self): - g = ag.JointGrid(self.x.tolist(), self.y.tolist()) + g = ag.JointGrid(x=self.x.tolist(), y=self.y.tolist()) npt.assert_array_equal(g.x, self.x) npt.assert_array_equal(g.y, self.y) def test_margin_grid_from_arrays(self): - g = ag.JointGrid(self.x, self.y) + g = ag.JointGrid(x=self.x, y=self.y) npt.assert_array_equal(g.x, self.x) npt.assert_array_equal(g.y, self.y) def test_margin_grid_from_series(self): - g = ag.JointGrid(self.data.x, self.data.y) + g = ag.JointGrid(x=self.data.x, y=self.data.y) npt.assert_array_equal(g.x, self.x) npt.assert_array_equal(g.y, self.y) def test_margin_grid_from_dataframe(self): - g = ag.JointGrid("x", "y", self.data) + g = ag.JointGrid(x="x", y="y", data=self.data) npt.assert_array_equal(g.x, self.x) npt.assert_array_equal(g.y, self.y) def test_margin_grid_from_dataframe_bad_variable(self): - with nt.assert_raises(ValueError): - ag.JointGrid("x", "bad_column", self.data) + with pytest.raises(ValueError): + ag.JointGrid(x="x", y="bad_column", data=self.data) def test_margin_grid_axis_labels(self): - g = ag.JointGrid("x", "y", self.data) + g = ag.JointGrid(x="x", y="y", data=self.data) xlabel, ylabel = g.ax_joint.get_xlabel(), g.ax_joint.get_ylabel() - nt.assert_equal(xlabel, "x") - nt.assert_equal(ylabel, "y") + assert xlabel == "x" + assert ylabel == "y" g.set_axis_labels("x variable", "y variable") xlabel, ylabel = g.ax_joint.get_xlabel(), g.ax_joint.get_ylabel() - nt.assert_equal(xlabel, "x variable") - nt.assert_equal(ylabel, "y variable") + assert xlabel == "x variable" + assert ylabel == "y variable" def test_dropna(self): - g = ag.JointGrid("x_na", "y", self.data, dropna=False) - nt.assert_equal(len(g.x), len(self.x_na)) + g = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=False) + assert len(g.x) == len(self.x_na) - g = ag.JointGrid("x_na", "y", self.data, dropna=True) - nt.assert_equal(len(g.x), pd.notnull(self.x_na).sum()) + g = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=True) + assert len(g.x) == pd.notnull(self.x_na).sum() def test_axlims(self): lim = (-3, 3) - g = ag.JointGrid("x", "y", self.data, xlim=lim, ylim=lim) + g = ag.JointGrid(x="x", y="y", data=self.data, xlim=lim, ylim=lim) - nt.assert_equal(g.ax_joint.get_xlim(), lim) - nt.assert_equal(g.ax_joint.get_ylim(), lim) + assert g.ax_joint.get_xlim() == lim + assert g.ax_joint.get_ylim() == lim - nt.assert_equal(g.ax_marg_x.get_xlim(), lim) - nt.assert_equal(g.ax_marg_y.get_ylim(), lim) + assert g.ax_marg_x.get_xlim() == lim + assert g.ax_marg_y.get_ylim() == lim def test_marginal_ticks(self): - g = ag.JointGrid("x", "y", self.data) - nt.assert_true(~len(g.ax_marg_x.get_xticks())) - nt.assert_true(~len(g.ax_marg_y.get_yticks())) + g = ag.JointGrid(marginal_ticks=False) + assert not 
sum(t.get_visible() for t in g.ax_marg_x.get_yticklabels()) + assert not sum(t.get_visible() for t in g.ax_marg_y.get_xticklabels()) + + g = ag.JointGrid(marginal_ticks=True) + assert sum(t.get_visible() for t in g.ax_marg_x.get_yticklabels()) + assert sum(t.get_visible() for t in g.ax_marg_y.get_xticklabels()) def test_bivariate_plot(self): - g = ag.JointGrid("x", "y", self.data) + g = ag.JointGrid(x="x", y="y", data=self.data) g.plot_joint(plt.plot) x, y = g.ax_joint.lines[0].get_xydata().T @@ -1426,16 +1467,35 @@ def test_bivariate_plot(self): def test_univariate_plot(self): - g = ag.JointGrid("x", "x", self.data) + g = ag.JointGrid(x="x", y="x", data=self.data) g.plot_marginals(kdeplot) _, y1 = g.ax_marg_x.lines[0].get_xydata().T y2, _ = g.ax_marg_y.lines[0].get_xydata().T npt.assert_array_equal(y1, y2) + def test_univariate_plot_distplot(self): + + bins = 10 + g = ag.JointGrid(x="x", y="x", data=self.data) + with pytest.warns(FutureWarning): + g.plot_marginals(distplot, bins=bins) + assert len(g.ax_marg_x.patches) == bins + assert len(g.ax_marg_y.patches) == bins + for x, y in zip(g.ax_marg_x.patches, g.ax_marg_y.patches): + assert x.get_height() == y.get_width() + + def test_univariate_plot_matplotlib(self): + + bins = 10 + g = ag.JointGrid(x="x", y="x", data=self.data) + g.plot_marginals(plt.hist, bins=bins) + assert len(g.ax_marg_x.patches) == bins + assert len(g.ax_marg_y.patches) == bins + def test_plot(self): - g = ag.JointGrid("x", "x", self.data) + g = ag.JointGrid(x="x", y="x", data=self.data) g.plot(plt.plot, kdeplot) x, y = g.ax_joint.lines[0].get_xydata().T @@ -1446,46 +1506,44 @@ def test_plot(self): y2, _ = g.ax_marg_y.lines[0].get_xydata().T npt.assert_array_equal(y1, y2) - def test_annotate(self): - - g = ag.JointGrid("x", "y", self.data) - rp = stats.pearsonr(self.x, self.y) - - g.annotate(stats.pearsonr) - annotation = g.ax_joint.legend_.texts[0].get_text() - nt.assert_equal(annotation, "pearsonr = %.2g; p = %.2g" % rp) + def test_space(self): - g.annotate(stats.pearsonr, stat="correlation") - annotation = g.ax_joint.legend_.texts[0].get_text() - nt.assert_equal(annotation, "correlation = %.2g; p = %.2g" % rp) + g = ag.JointGrid(x="x", y="y", data=self.data, space=0) - def rsquared(x, y): - return stats.pearsonr(x, y)[0] ** 2 + joint_bounds = g.ax_joint.bbox.bounds + marg_x_bounds = g.ax_marg_x.bbox.bounds + marg_y_bounds = g.ax_marg_y.bbox.bounds - r2 = rsquared(self.x, self.y) - g.annotate(rsquared) - annotation = g.ax_joint.legend_.texts[0].get_text() - nt.assert_equal(annotation, "rsquared = %.2g" % r2) + assert joint_bounds[2] == marg_x_bounds[2] + assert joint_bounds[3] == marg_y_bounds[3] - template = "{stat} = {val:.3g} (p = {p:.3g})" - g.annotate(stats.pearsonr, template=template) - annotation = g.ax_joint.legend_.texts[0].get_text() - nt.assert_equal(annotation, template.format(stat="pearsonr", - val=rp[0], p=rp[1])) + @pytest.mark.parametrize( + "as_vector", [True, False], + ) + def test_hue(self, long_df, as_vector): - def test_space(self): + if as_vector: + data = None + x, y, hue = long_df["x"], long_df["y"], long_df["a"] + else: + data = long_df + x, y, hue = "x", "y", "a" - g = ag.JointGrid("x", "y", self.data, space=0) + g = ag.JointGrid(data=data, x=x, y=y, hue=hue) + g.plot_joint(scatterplot) + g.plot_marginals(histplot) - joint_bounds = g.ax_joint.bbox.bounds - marg_x_bounds = g.ax_marg_x.bbox.bounds - marg_y_bounds = g.ax_marg_y.bbox.bounds + g2 = ag.JointGrid() + scatterplot(data=long_df, x=x, y=y, hue=hue, ax=g2.ax_joint) + 
histplot(data=long_df, x=x, hue=hue, ax=g2.ax_marg_x) + histplot(data=long_df, y=y, hue=hue, ax=g2.ax_marg_y) - nt.assert_equal(joint_bounds[2], marg_x_bounds[2]) - nt.assert_equal(joint_bounds[3], marg_y_bounds[3]) + assert_plots_equal(g.ax_joint, g2.ax_joint) + assert_plots_equal(g.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g.ax_marg_y, g2.ax_marg_y, labels=False) -class TestJointPlot(object): +class TestJointPlot: rs = np.random.RandomState(sum(map(ord, "jointplot"))) x = rs.randn(100) @@ -1494,97 +1552,163 @@ class TestJointPlot(object): def test_scatter(self): - g = ag.jointplot("x", "y", self.data) - nt.assert_equal(len(g.ax_joint.collections), 1) + g = ag.jointplot(x="x", y="y", data=self.data) + assert len(g.ax_joint.collections) == 1 x, y = g.ax_joint.collections[0].get_offsets().T - npt.assert_array_equal(self.x, x) - npt.assert_array_equal(self.y, y) + assert_array_equal(self.x, x) + assert_array_equal(self.y, y) - x_bins = _freedman_diaconis_bins(self.x) - nt.assert_equal(len(g.ax_marg_x.patches), x_bins) + assert_array_equal( + [b.get_x() for b in g.ax_marg_x.patches], + np.histogram_bin_edges(self.x, "auto")[:-1], + ) - y_bins = _freedman_diaconis_bins(self.y) - nt.assert_equal(len(g.ax_marg_y.patches), y_bins) + assert_array_equal( + [b.get_y() for b in g.ax_marg_y.patches], + np.histogram_bin_edges(self.y, "auto")[:-1], + ) + + def test_scatter_hue(self, long_df): + + g1 = ag.jointplot(data=long_df, x="x", y="y", hue="a") + + g2 = ag.JointGrid() + scatterplot(data=long_df, x="x", y="y", hue="a", ax=g2.ax_joint) + kdeplot(data=long_df, x="x", hue="a", ax=g2.ax_marg_x, fill=True) + kdeplot(data=long_df, y="y", hue="a", ax=g2.ax_marg_y, fill=True) + + assert_plots_equal(g1.ax_joint, g2.ax_joint) + assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False) def test_reg(self): - g = ag.jointplot("x", "y", self.data, kind="reg") - nt.assert_equal(len(g.ax_joint.collections), 2) + g = ag.jointplot(x="x", y="y", data=self.data, kind="reg") + assert len(g.ax_joint.collections) == 2 x, y = g.ax_joint.collections[0].get_offsets().T - npt.assert_array_equal(self.x, x) - npt.assert_array_equal(self.y, y) + assert_array_equal(self.x, x) + assert_array_equal(self.y, y) - x_bins = _freedman_diaconis_bins(self.x) - nt.assert_equal(len(g.ax_marg_x.patches), x_bins) + assert g.ax_marg_x.patches + assert g.ax_marg_y.patches - y_bins = _freedman_diaconis_bins(self.y) - nt.assert_equal(len(g.ax_marg_y.patches), y_bins) - - nt.assert_equal(len(g.ax_joint.lines), 1) - nt.assert_equal(len(g.ax_marg_x.lines), 1) - nt.assert_equal(len(g.ax_marg_y.lines), 1) + assert g.ax_marg_x.lines + assert g.ax_marg_y.lines def test_resid(self): - g = ag.jointplot("x", "y", self.data, kind="resid") - nt.assert_equal(len(g.ax_joint.collections), 1) - nt.assert_equal(len(g.ax_joint.lines), 1) - nt.assert_equal(len(g.ax_marg_x.lines), 0) - nt.assert_equal(len(g.ax_marg_y.lines), 1) + g = ag.jointplot(x="x", y="y", data=self.data, kind="resid") + assert g.ax_joint.collections + assert g.ax_joint.lines + assert not g.ax_marg_x.lines + assert not g.ax_marg_y.lines + + def test_hist(self, long_df): + + bins = 3, 6 + g1 = ag.jointplot(data=long_df, x="x", y="y", kind="hist", bins=bins) + + g2 = ag.JointGrid() + histplot(data=long_df, x="x", y="y", ax=g2.ax_joint, bins=bins) + histplot(data=long_df, x="x", ax=g2.ax_marg_x, bins=bins[0]) + histplot(data=long_df, y="y", ax=g2.ax_marg_y, bins=bins[1]) + + assert_plots_equal(g1.ax_joint, 
g2.ax_joint) + assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False) def test_hex(self): - g = ag.jointplot("x", "y", self.data, kind="hex") - nt.assert_equal(len(g.ax_joint.collections), 1) + g = ag.jointplot(x="x", y="y", data=self.data, kind="hex") + assert g.ax_joint.collections + assert g.ax_marg_x.patches + assert g.ax_marg_y.patches + + def test_kde(self, long_df): + + g1 = ag.jointplot(data=long_df, x="x", y="y", kind="kde") - x_bins = _freedman_diaconis_bins(self.x) - nt.assert_equal(len(g.ax_marg_x.patches), x_bins) + g2 = ag.JointGrid() + kdeplot(data=long_df, x="x", y="y", ax=g2.ax_joint) + kdeplot(data=long_df, x="x", ax=g2.ax_marg_x) + kdeplot(data=long_df, y="y", ax=g2.ax_marg_y) - y_bins = _freedman_diaconis_bins(self.y) - nt.assert_equal(len(g.ax_marg_y.patches), y_bins) + assert_plots_equal(g1.ax_joint, g2.ax_joint) + assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False) - def test_kde(self): + def test_kde_hue(self, long_df): - g = ag.jointplot("x", "y", self.data, kind="kde") + g1 = ag.jointplot(data=long_df, x="x", y="y", hue="a", kind="kde") - nt.assert_true(len(g.ax_joint.collections) > 0) - nt.assert_equal(len(g.ax_marg_x.collections), 1) - nt.assert_equal(len(g.ax_marg_y.collections), 1) + g2 = ag.JointGrid() + kdeplot(data=long_df, x="x", y="y", hue="a", ax=g2.ax_joint) + kdeplot(data=long_df, x="x", hue="a", ax=g2.ax_marg_x) + kdeplot(data=long_df, y="y", hue="a", ax=g2.ax_marg_y) - nt.assert_equal(len(g.ax_marg_x.lines), 1) - nt.assert_equal(len(g.ax_marg_y.lines), 1) + assert_plots_equal(g1.ax_joint, g2.ax_joint) + assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False) def test_color(self): - g = ag.jointplot("x", "y", self.data, color="purple") + g = ag.jointplot(x="x", y="y", data=self.data, color="purple") - purple = mpl.colors.colorConverter.to_rgb("purple") - scatter_color = g.ax_joint.collections[0].get_facecolor()[0, :3] - nt.assert_equal(tuple(scatter_color), purple) + scatter_color = g.ax_joint.collections[0].get_facecolor() + assert_colors_equal(scatter_color, "purple") hist_color = g.ax_marg_x.patches[0].get_facecolor()[:3] - nt.assert_equal(hist_color, purple) + assert_colors_equal(hist_color, "purple") - def test_annotation(self): + def test_palette(self, long_df): - g = ag.jointplot("x", "y", self.data, stat_func=stats.pearsonr) - nt.assert_equal(len(g.ax_joint.legend_.get_texts()), 1) + kws = dict(data=long_df, hue="a", palette="Set2") - g = ag.jointplot("x", "y", self.data, stat_func=None) - nt.assert_is(g.ax_joint.legend_, None) + g1 = ag.jointplot(x="x", y="y", **kws) + + g2 = ag.JointGrid() + scatterplot(x="x", y="y", ax=g2.ax_joint, **kws) + kdeplot(x="x", ax=g2.ax_marg_x, fill=True, **kws) + kdeplot(y="y", ax=g2.ax_marg_y, fill=True, **kws) + + assert_plots_equal(g1.ax_joint, g2.ax_joint) + assert_plots_equal(g1.ax_marg_x, g2.ax_marg_x, labels=False) + assert_plots_equal(g1.ax_marg_y, g2.ax_marg_y, labels=False) def test_hex_customise(self): # test that default gridsize can be overridden - g = ag.jointplot("x", "y", self.data, kind="hex", + g = ag.jointplot(x="x", y="y", data=self.data, kind="hex", joint_kws=dict(gridsize=5)) - nt.assert_equal(len(g.ax_joint.collections), 1) + assert len(g.ax_joint.collections) == 1 a = g.ax_joint.collections[0].get_array() - nt.assert_equal(28, a.shape[0]) # 28 hexagons expected for gridsize 5 + 
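The TestJointPlot changes above pin down an equivalence: `jointplot` with a `hue` semantic or one of the newer kinds must build the same axes as a hand-assembled `JointGrid`. A minimal sketch of that contract, assuming seaborn >= 0.11 and the bundled penguins example dataset:

    import seaborn as sns

    df = sns.load_dataset("penguins")

    # One-shot figure...
    g1 = sns.jointplot(data=df, x="bill_length_mm", y="bill_depth_mm",
                       hue="species", kind="kde")

    # ...should match the same grid assembled by hand
    g2 = sns.JointGrid()
    sns.kdeplot(data=df, x="bill_length_mm", y="bill_depth_mm",
                hue="species", ax=g2.ax_joint)
    sns.kdeplot(data=df, x="bill_length_mm", hue="species", ax=g2.ax_marg_x)
    sns.kdeplot(data=df, y="bill_depth_mm", hue="species", ax=g2.ax_marg_y)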
assert a.shape[0] == 28 # 28 hexagons expected for gridsize 5 def test_bad_kind(self): - with nt.assert_raises(ValueError): - ag.jointplot("x", "y", self.data, kind="not_a_kind") + with pytest.raises(ValueError): + ag.jointplot(x="x", y="y", data=self.data, kind="not_a_kind") + + def test_unsupported_hue_kind(self): + + for kind in ["reg", "resid", "hex"]: + with pytest.raises(ValueError): + ag.jointplot(x="x", y="y", hue="a", data=self.data, kind=kind) + + def test_leaky_dict(self): + # Validate input dicts are unchanged by jointplot plotting function + + for kwarg in ("joint_kws", "marginal_kws"): + for kind in ("hex", "kde", "resid", "reg", "scatter"): + empty_dict = {} + ag.jointplot(x="x", y="y", data=self.data, kind=kind, + **{kwarg: empty_dict}) + assert empty_dict == {} + + def test_distplot_kwarg_warning(self, long_df): + + with pytest.warns(UserWarning): + g = ag.jointplot(data=long_df, x="x", y="y", marginal_kws=dict(rug=True)) + assert g.ax_marg_x.patches diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 71d4155fd9..0a46a7e669 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -1,32 +1,109 @@ +import itertools +from functools import partial + import numpy as np import pandas as pd -import scipy -from scipy import stats, spatial import matplotlib as mpl import matplotlib.pyplot as plt -from matplotlib.colors import rgb2hex - -from distutils.version import LooseVersion +from matplotlib.colors import rgb2hex, to_rgb, to_rgba import pytest -import nose.tools as nt +from pytest import approx import numpy.testing as npt +from distutils.version import LooseVersion +from numpy.testing import ( + assert_array_equal, + assert_array_less, +) from .. import categorical as cat from .. 
import palettes - -pandas_has_categoricals = LooseVersion(pd.__version__) >= "0.15" -mpl_barplot_change = LooseVersion("2.0.1") - - -class CategoricalFixture(object): +from .._core import categorical_order +from ..categorical import ( + _CategoricalPlotterNew, + Beeswarm, + catplot, + stripplot, + swarmplot, +) +from ..palettes import color_palette +from ..utils import _normal_quantile_func, _draw_figure +from .._testing import assert_plots_equal + + +PLOT_FUNCS = [ + catplot, + stripplot, + swarmplot, +] + + +class TestCategoricalPlotterNew: + + @pytest.mark.parametrize( + "func,kwargs", + itertools.product( + PLOT_FUNCS, + [ + {"x": "x", "y": "a"}, + {"x": "a", "y": "y"}, + {"x": "y"}, + {"y": "x"}, + ], + ), + ) + def test_axis_labels(self, long_df, func, kwargs): + + func(data=long_df, **kwargs) + + ax = plt.gca() + for axis in "xy": + val = kwargs.get(axis, "") + label_func = getattr(ax, f"get_{axis}label") + assert label_func() == val + + @pytest.mark.parametrize("func", PLOT_FUNCS) + def test_empty(self, func): + + func() + ax = plt.gca() + assert not ax.collections + assert not ax.patches + assert not ax.lines + + func(x=[], y=[]) + ax = plt.gca() + assert not ax.collections + assert not ax.patches + assert not ax.lines + + def test_redundant_hue_backcompat(self, long_df): + + p = _CategoricalPlotterNew( + data=long_df, + variables={"x": "s", "y": "y"}, + ) + + color = None + palette = dict(zip(long_df["s"].unique(), color_palette())) + hue_order = None + + palette, _ = p._hue_backcompat(color, palette, hue_order, force_hue=True) + + assert p.variables["hue"] == "s" + assert_array_equal(p.plot_data["hue"], p.plot_data["x"]) + assert all(isinstance(k, str) for k in palette) + + +class CategoricalFixture: """Test boxplot (also base class for things like violinplots).""" rs = np.random.RandomState(30) n_total = 60 x = rs.randn(int(n_total / 3), 3) x_df = pd.DataFrame(x, columns=pd.Series(list("XYZ"), name="big")) y = pd.Series(rs.randn(n_total), name="y_data") + y_perm = y.reindex(rs.choice(y.index, y.size, replace=False)) g = pd.Series(np.repeat(list("abc"), int(n_total / 3)), name="small") h = pd.Series(np.tile(list("mn"), int(n_total / 2)), name="medium") u = pd.Series(np.tile(list("jkh"), int(n_total / 3))) @@ -48,17 +125,17 @@ def test_wide_df_data(self): npt.assert_array_equal(x, y) # Check semantic attributes - nt.assert_equal(p.orient, "v") - nt.assert_is(p.plot_hues, None) - nt.assert_is(p.group_label, "big") - nt.assert_is(p.value_label, None) + assert p.orient == "v" + assert p.plot_hues is None + assert p.group_label == "big" + assert p.value_label is None # Test wide dataframe with forced horizontal orientation p.establish_variables(data=self.x_df, orient="horiz") - nt.assert_equal(p.orient, "h") + assert p.orient == "h" - # Text exception by trying to hue-group with a wide dataframe - with nt.assert_raises(ValueError): + # Test exception by trying to hue-group with a wide dataframe + with pytest.raises(ValueError): p.establish_variables(hue="d", data=self.x_df) def test_1d_input_data(self): @@ -68,28 +145,29 @@ def test_1d_input_data(self): # Test basic vector data x_1d_array = self.x.ravel() p.establish_variables(data=x_1d_array) - nt.assert_equal(len(p.plot_data), 1) - nt.assert_equal(len(p.plot_data[0]), self.n_total) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert len(p.plot_data) == 1 + assert len(p.plot_data[0]) == self.n_total + assert p.group_label is None + assert p.value_label is None # Test basic vector data in list form 
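`categorical_order`, imported above from `seaborn._core` (its home at this point in the tree), is what the new tests lean on to predict level order: an explicit `order` wins, then a categorical dtype, then order of appearance, with purely numeric values sorted. A sketch of that contract:

    import pandas as pd
    from seaborn._core import categorical_order

    categorical_order(pd.Series(["b", "a", "b"]))   # ['b', 'a'] (appearance)
    categorical_order(pd.Series([2, 1, 2]))         # [1, 2] (numeric: sorted)

    s = pd.Series(["b", "a"], dtype=pd.CategoricalDtype(["a", "b"]))
    categorical_order(s)                            # ['a', 'b'] (dtype order)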
x_1d_list = x_1d_array.tolist() p.establish_variables(data=x_1d_list) - nt.assert_equal(len(p.plot_data), 1) - nt.assert_equal(len(p.plot_data[0]), self.n_total) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert len(p.plot_data) == 1 + assert len(p.plot_data[0]) == self.n_total + assert p.group_label is None + assert p.value_label is None # Test an object array that looks 1D but isn't x_notreally_1d = np.array([self.x.ravel(), - self.x.ravel()[:int(self.n_total / 2)]]) + self.x.ravel()[:int(self.n_total / 2)]], + dtype=object) p.establish_variables(data=x_notreally_1d) - nt.assert_equal(len(p.plot_data), 2) - nt.assert_equal(len(p.plot_data[0]), self.n_total) - nt.assert_equal(len(p.plot_data[1]), self.n_total / 2) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert len(p.plot_data) == 2 + assert len(p.plot_data[0]) == self.n_total + assert len(p.plot_data[1]) == self.n_total / 2 + assert p.group_label is None + assert p.value_label is None def test_2d_input_data(self): @@ -99,17 +177,17 @@ def test_2d_input_data(self): # Test vector data that looks 2D but doesn't really have columns p.establish_variables(data=x[:, np.newaxis]) - nt.assert_equal(len(p.plot_data), 1) - nt.assert_equal(len(p.plot_data[0]), self.x.shape[0]) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert len(p.plot_data) == 1 + assert len(p.plot_data[0]) == self.x.shape[0] + assert p.group_label is None + assert p.value_label is None # Test vector data that looks 2D but doesn't really have rows p.establish_variables(data=x[np.newaxis, :]) - nt.assert_equal(len(p.plot_data), 1) - nt.assert_equal(len(p.plot_data[0]), self.x.shape[0]) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert len(p.plot_data) == 1 + assert len(p.plot_data[0]) == self.x.shape[0] + assert p.group_label is None + assert p.value_label is None def test_3d_input_data(self): @@ -117,7 +195,7 @@ def test_3d_input_data(self): # Test that passing actually 3D data raises x = np.zeros((5, 5, 5)) - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): p.establish_variables(data=x) def test_list_of_array_input_data(self): @@ -127,13 +205,13 @@ def test_list_of_array_input_data(self): # Test 2D input in list form x_list = self.x.T.tolist() p.establish_variables(data=x_list) - nt.assert_equal(len(p.plot_data), 3) + assert len(p.plot_data) == 3 lengths = [len(v_i) for v_i in p.plot_data] - nt.assert_equal(lengths, [self.n_total / 3] * 3) + assert lengths == [self.n_total / 3] * 3 - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert p.group_label is None + assert p.value_label is None def test_wide_array_input_data(self): @@ -141,11 +219,11 @@ def test_wide_array_input_data(self): # Test 2D input in array form p.establish_variables(data=self.x) - nt.assert_equal(np.shape(p.plot_data), (3, self.n_total / 3)) + assert np.shape(p.plot_data) == (3, self.n_total / 3) npt.assert_array_equal(p.plot_data, self.x.T) - nt.assert_is(p.group_label, None) - nt.assert_is(p.value_label, None) + assert p.group_label is None + assert p.value_label is None def test_single_long_direct_inputs(self): @@ -154,23 +232,29 @@ def test_single_long_direct_inputs(self): # Test passing a series to the x variable p.establish_variables(x=self.y) npt.assert_equal(p.plot_data, [self.y]) - nt.assert_equal(p.orient, "h") - nt.assert_equal(p.value_label, "y_data") - nt.assert_is(p.group_label, None) + assert p.orient == "h" + assert 
p.value_label == "y_data" + assert p.group_label is None # Test passing a series to the y variable p.establish_variables(y=self.y) npt.assert_equal(p.plot_data, [self.y]) - nt.assert_equal(p.orient, "v") - nt.assert_equal(p.value_label, "y_data") - nt.assert_is(p.group_label, None) + assert p.orient == "v" + assert p.value_label == "y_data" + assert p.group_label is None # Test passing an array to the y variable p.establish_variables(y=self.y.values) npt.assert_equal(p.plot_data, [self.y]) - nt.assert_equal(p.orient, "v") - nt.assert_is(p.value_label, None) - nt.assert_is(p.group_label, None) + assert p.orient == "v" + assert p.group_label is None + assert p.value_label is None + + # Test array and series with non-default index + x = pd.Series([1, 1, 1, 1], index=[0, 2, 4, 6]) + y = np.array([1, 2, 3, 4]) + p.establish_variables(x, y) + assert len(p.plot_data[0]) == 4 def test_single_long_indirect_inputs(self): @@ -179,29 +263,29 @@ def test_single_long_indirect_inputs(self): # Test referencing a DataFrame series in the x variable p.establish_variables(x="y", data=self.df) npt.assert_equal(p.plot_data, [self.y]) - nt.assert_equal(p.orient, "h") - nt.assert_equal(p.value_label, "y") - nt.assert_is(p.group_label, None) + assert p.orient == "h" + assert p.value_label == "y" + assert p.group_label is None # Test referencing a DataFrame series in the y variable p.establish_variables(y="y", data=self.df) npt.assert_equal(p.plot_data, [self.y]) - nt.assert_equal(p.orient, "v") - nt.assert_equal(p.value_label, "y") - nt.assert_is(p.group_label, None) + assert p.orient == "v" + assert p.value_label == "y" + assert p.group_label is None def test_longform_groupby(self): p = cat._CategoricalPlotter() # Test a vertically oriented grouped and nested plot - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_equal(len(p.plot_data), 3) - nt.assert_equal(len(p.plot_hues), 3) - nt.assert_equal(p.orient, "v") - nt.assert_equal(p.value_label, "y") - nt.assert_equal(p.group_label, "g") - nt.assert_equal(p.hue_title, "h") + p.establish_variables("g", "y", hue="h", data=self.df) + assert len(p.plot_data) == 3 + assert len(p.plot_hues) == 3 + assert p.orient == "v" + assert p.value_label == "y" + assert p.group_label == "g" + assert p.hue_title == "h" for group, vals in zip(["a", "b", "c"], p.plot_data): npt.assert_array_equal(vals, self.y[self.g == group]) @@ -211,8 +295,8 @@ def test_longform_groupby(self): # Test a grouped and nested plot with direct array value data p.establish_variables("g", self.y.values, "h", self.df) - nt.assert_is(p.value_label, None) - nt.assert_equal(p.group_label, "g") + assert p.value_label is None + assert p.group_label == "g" for group, vals in zip(["a", "b", "c"], p.plot_data): npt.assert_array_equal(vals, self.y[self.g == group]) @@ -224,34 +308,41 @@ def test_longform_groupby(self): npt.assert_array_equal(hues, self.h[self.g == group]) # Test categorical grouping data - if pandas_has_categoricals: - df = self.df.copy() - df.g = df.g.astype("category") + df = self.df.copy() + df.g = df.g.astype("category") - # Test that horizontal orientation is automatically detected - p.establish_variables("y", "g", "h", data=df) - nt.assert_equal(len(p.plot_data), 3) - nt.assert_equal(len(p.plot_hues), 3) - nt.assert_equal(p.orient, "h") - nt.assert_equal(p.value_label, "y") - nt.assert_equal(p.group_label, "g") - nt.assert_equal(p.hue_title, "h") + # Test that horizontal orientation is automatically detected + p.establish_variables("y", "g", hue="h", data=df) + assert 
len(p.plot_data) == 3 + assert len(p.plot_hues) == 3 + assert p.orient == "h" + assert p.value_label == "y" + assert p.group_label == "g" + assert p.hue_title == "h" - for group, vals in zip(["a", "b", "c"], p.plot_data): - npt.assert_array_equal(vals, self.y[self.g == group]) + for group, vals in zip(["a", "b", "c"], p.plot_data): + npt.assert_array_equal(vals, self.y[self.g == group]) - for group, hues in zip(["a", "b", "c"], p.plot_hues): - npt.assert_array_equal(hues, self.h[self.g == group]) + for group, hues in zip(["a", "b", "c"], p.plot_hues): + npt.assert_array_equal(hues, self.h[self.g == group]) + + # Test grouped data that matches on index + p1 = cat._CategoricalPlotter() + p1.establish_variables(self.g, self.y, hue=self.h) + p2 = cat._CategoricalPlotter() + p2.establish_variables(self.g, self.y[::-1], self.h) + for i, (d1, d2) in enumerate(zip(p1.plot_data, p2.plot_data)): + assert np.array_equal(d1.sort_index(), d2.sort_index()) def test_input_validation(self): p = cat._CategoricalPlotter() kws = dict(x="g", y="y", hue="h", units="u", data=self.df) - for input in ["x", "y", "hue", "units"]: + for var in ["x", "y", "hue", "units"]: input_kws = kws.copy() - input_kws[input] = "bad_input" - with nt.assert_raises(ValueError): + input_kws[var] = "bad_input" + with pytest.raises(ValueError): p.establish_variables(**input_kws) def test_order(self): @@ -260,106 +351,79 @@ def test_order(self): # Test inferred order from a wide dataframe input p.establish_variables(data=self.x_df) - nt.assert_equal(p.group_names, ["X", "Y", "Z"]) + assert p.group_names == ["X", "Y", "Z"] # Test specified order with a wide dataframe input p.establish_variables(data=self.x_df, order=["Y", "Z", "X"]) - nt.assert_equal(p.group_names, ["Y", "Z", "X"]) + assert p.group_names == ["Y", "Z", "X"] for group, vals in zip(["Y", "Z", "X"], p.plot_data): npt.assert_array_equal(vals, self.x_df[group]) - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): p.establish_variables(data=self.x, order=[1, 2, 0]) # Test inferred order from a grouped longform input p.establish_variables("g", "y", data=self.df) - nt.assert_equal(p.group_names, ["a", "b", "c"]) + assert p.group_names == ["a", "b", "c"] # Test specified order from a grouped longform input p.establish_variables("g", "y", data=self.df, order=["b", "a", "c"]) - nt.assert_equal(p.group_names, ["b", "a", "c"]) + assert p.group_names == ["b", "a", "c"] for group, vals in zip(["b", "a", "c"], p.plot_data): npt.assert_array_equal(vals, self.y[self.g == group]) # Test inferred order from a grouped input with categorical groups - if pandas_has_categoricals: - df = self.df.copy() - df.g = df.g.astype("category") - df.g = df.g.cat.reorder_categories(["c", "b", "a"]) - p.establish_variables("g", "y", data=df) - nt.assert_equal(p.group_names, ["c", "b", "a"]) + df = self.df.copy() + df.g = df.g.astype("category") + df.g = df.g.cat.reorder_categories(["c", "b", "a"]) + p.establish_variables("g", "y", data=df) + assert p.group_names == ["c", "b", "a"] - for group, vals in zip(["c", "b", "a"], p.plot_data): - npt.assert_array_equal(vals, self.y[self.g == group]) + for group, vals in zip(["c", "b", "a"], p.plot_data): + npt.assert_array_equal(vals, self.y[self.g == group]) - df.g = (df.g.cat.add_categories("d") - .cat.reorder_categories(["c", "b", "d", "a"])) - p.establish_variables("g", "y", data=df) - nt.assert_equal(p.group_names, ["c", "b", "d", "a"]) + df.g = (df.g.cat.add_categories("d") + .cat.reorder_categories(["c", "b", "d", "a"])) + 
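test_order checks this precedence at the private `establish_variables` layer; the user-facing effect, sketched with the high-level API on a hypothetical long-form frame `df` with columns `g` and `y`:

    # An explicit order= beats everything else
    sns.boxplot(data=df, x="g", y="y", order=["b", "a", "c"])

    # Otherwise a categorical dtype fixes the order, including unused levels
    df["g"] = df["g"].astype("category").cat.reorder_categories(["c", "b", "a"])
    sns.boxplot(data=df, x="g", y="y")   # groups drawn c, b, a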
p.establish_variables("g", "y", data=df) + assert p.group_names == ["c", "b", "d", "a"] def test_hue_order(self): p = cat._CategoricalPlotter() # Test inferred hue order - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_equal(p.hue_names, ["m", "n"]) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.hue_names == ["m", "n"] # Test specified hue order - p.establish_variables("g", "y", "h", data=self.df, + p.establish_variables("g", "y", hue="h", data=self.df, hue_order=["n", "m"]) - nt.assert_equal(p.hue_names, ["n", "m"]) + assert p.hue_names == ["n", "m"] # Test inferred hue order from a categorical hue input - if pandas_has_categoricals: - df = self.df.copy() - df.h = df.h.astype("category") - df.h = df.h.cat.reorder_categories(["n", "m"]) - p.establish_variables("g", "y", "h", data=df) - nt.assert_equal(p.hue_names, ["n", "m"]) - - df.h = (df.h.cat.add_categories("o") - .cat.reorder_categories(["o", "m", "n"])) - p.establish_variables("g", "y", "h", data=df) - nt.assert_equal(p.hue_names, ["o", "m", "n"]) + df = self.df.copy() + df.h = df.h.astype("category") + df.h = df.h.cat.reorder_categories(["n", "m"]) + p.establish_variables("g", "y", hue="h", data=df) + assert p.hue_names == ["n", "m"] + + df.h = (df.h.cat.add_categories("o") + .cat.reorder_categories(["o", "m", "n"])) + p.establish_variables("g", "y", hue="h", data=df) + assert p.hue_names == ["o", "m", "n"] def test_plot_units(self): p = cat._CategoricalPlotter() - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_is(p.plot_units, None) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.plot_units is None - p.establish_variables("g", "y", "h", data=self.df, units="u") + p.establish_variables("g", "y", hue="h", data=self.df, units="u") for group, units in zip(["a", "b", "c"], p.plot_units): npt.assert_array_equal(units, self.u[self.g == group]) - def test_infer_orient(self): - - p = cat._CategoricalPlotter() - - cats = pd.Series(["a", "b", "c"] * 10) - nums = pd.Series(self.rs.randn(30)) - - nt.assert_equal(p.infer_orient(cats, nums), "v") - nt.assert_equal(p.infer_orient(nums, cats), "h") - nt.assert_equal(p.infer_orient(nums, None), "h") - nt.assert_equal(p.infer_orient(None, nums), "v") - nt.assert_equal(p.infer_orient(nums, nums, "vert"), "v") - nt.assert_equal(p.infer_orient(nums, nums, "hori"), "h") - - with nt.assert_raises(ValueError): - p.infer_orient(cats, cats) - - if pandas_has_categoricals: - cats = pd.Series([0, 1, 2] * 10, dtype="category") - nt.assert_equal(p.infer_orient(cats, nums), "v") - nt.assert_equal(p.infer_orient(nums, cats), "h") - - with nt.assert_raises(ValueError): - p.infer_orient(cats, cats) - def test_default_palettes(self): p = cat._CategoricalPlotter() @@ -367,12 +431,12 @@ def test_default_palettes(self): # Test palette mapping the x position p.establish_variables("g", "y", data=self.df) p.establish_colors(None, None, 1) - nt.assert_equal(p.colors, palettes.color_palette(n_colors=3)) + assert p.colors == palettes.color_palette(n_colors=3) # Test palette mapping the hue position - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) p.establish_colors(None, None, 1) - nt.assert_equal(p.colors, palettes.color_palette(n_colors=2)) + assert p.colors == palettes.color_palette(n_colors=2) def test_default_palette_with_many_levels(self): @@ -391,10 +455,10 @@ def test_specific_color(self): p.establish_variables("g", "y", data=self.df) p.establish_colors("blue", None, 1) blue_rgb = 
mpl.colors.colorConverter.to_rgb("blue") - nt.assert_equal(p.colors, [blue_rgb] * 3) + assert p.colors == [blue_rgb] * 3 # Test a color-based blend for the hue mapping - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) p.establish_colors("#ff0022", None, 1) rgba_array = np.array(palettes.light_palette("#ff0022", 2)) npt.assert_array_almost_equal(p.colors, @@ -407,38 +471,36 @@ def test_specific_palette(self): # Test palette mapping the x position p.establish_variables("g", "y", data=self.df) p.establish_colors(None, "dark", 1) - nt.assert_equal(p.colors, palettes.color_palette("dark", 3)) + assert p.colors == palettes.color_palette("dark", 3) # Test that non-None `color` and `hue` raises an error - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) p.establish_colors(None, "muted", 1) - nt.assert_equal(p.colors, palettes.color_palette("muted", 2)) + assert p.colors == palettes.color_palette("muted", 2) # Test that specified palette overrides specified color p = cat._CategoricalPlotter() p.establish_variables("g", "y", data=self.df) p.establish_colors("blue", "deep", 1) - nt.assert_equal(p.colors, palettes.color_palette("deep", 3)) + assert p.colors == palettes.color_palette("deep", 3) def test_dict_as_palette(self): p = cat._CategoricalPlotter() - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) pal = {"m": (0, 0, 1), "n": (1, 0, 0)} p.establish_colors(None, pal, 1) - nt.assert_equal(p.colors, [(0, 0, 1), (1, 0, 0)]) + assert p.colors == [(0, 0, 1), (1, 0, 0)] def test_palette_desaturation(self): p = cat._CategoricalPlotter() p.establish_variables("g", "y", data=self.df) p.establish_colors((0, 0, 1), None, .5) - nt.assert_equal(p.colors, [(.25, .25, .75)] * 3) + assert p.colors == [(.25, .25, .75)] * 3 p.establish_colors(None, [(0, 0, 1), (1, 0, 0), "w"], .5) - nt.assert_equal(p.colors, [(.25, .25, .75), - (.75, .25, .25), - (1, 1, 1)]) + assert p.colors == [(.25, .25, .75), (.75, .25, .25), (1, 1, 1)] class TestCategoricalStatPlotter(CategoricalFixture): @@ -447,11 +509,11 @@ def test_no_bootstrappig(self): p = cat._CategoricalStatPlotter() p.establish_variables("g", "y", data=self.df) - p.estimate_statistic(np.mean, None, 100) + p.estimate_statistic(np.mean, None, 100, None) npt.assert_array_equal(p.confint, np.array([])) - p.establish_variables("g", "y", "h", data=self.df) - p.estimate_statistic(np.mean, None, 100) + p.establish_variables("g", "y", hue="h", data=self.df) + p.estimate_statistic(np.mean, None, 100, None) npt.assert_array_equal(p.confint, np.array([[], [], []])) def test_single_layer_stats(self): @@ -462,19 +524,18 @@ def test_single_layer_stats(self): y = pd.Series(np.random.RandomState(0).randn(300)) p.establish_variables(g, y) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) - nt.assert_equal(p.statistic.shape, (3,)) - nt.assert_equal(p.confint.shape, (3, 2)) + assert p.statistic.shape == (3,) + assert p.confint.shape == (3, 2) npt.assert_array_almost_equal(p.statistic, y.groupby(g).mean()) for ci, (_, grp_y) in zip(p.confint, y.groupby(g)): - sem = stats.sem(grp_y) + sem = grp_y.std() / np.sqrt(len(grp_y)) mean = grp_y.mean() - stats.norm.ppf(.975) - half_ci = stats.norm.ppf(.975) * sem + half_ci = _normal_quantile_func(.975) * sem ci_want = mean - half_ci, mean + half_ci npt.assert_array_almost_equal(ci_want, ci, 2) @@ -489,11 +550,11 @@ def 
test_single_layer_stats_with_units(self): y[u == "y"] += 3 p.establish_variables(g, y) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) stat1, ci1 = p.statistic, p.confint p.establish_variables(g, y, units=u) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) stat2, ci2 = p.statistic, p.confint npt.assert_array_equal(stat1, stat2) @@ -509,14 +570,15 @@ def test_single_layer_stats_with_missing_data(self): y = pd.Series(np.random.RandomState(0).randn(300)) p.establish_variables(g, y, order=list("abdc")) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) - nt.assert_equal(p.statistic.shape, (4,)) - nt.assert_equal(p.confint.shape, (4, 2)) + assert p.statistic.shape == (4,) + assert p.confint.shape == (4, 2) - mean = y[g == "b"].mean() - sem = stats.sem(y[g == "b"]) - half_ci = stats.norm.ppf(.975) * sem + rows = g == "b" + mean = y[rows].mean() + sem = y[rows].std() / np.sqrt(rows.sum()) + half_ci = _normal_quantile_func(.975) * sem ci = mean - half_ci, mean + half_ci npt.assert_almost_equal(p.statistic[1], mean) npt.assert_array_almost_equal(p.confint[1], ci, 2) @@ -533,22 +595,38 @@ def test_nested_stats(self): y = pd.Series(np.random.RandomState(0).randn(300)) p.establish_variables(g, y, h) - p.estimate_statistic(np.mean, 95, 50000) + p.estimate_statistic(np.mean, 95, 50000, None) - nt.assert_equal(p.statistic.shape, (3, 2)) - nt.assert_equal(p.confint.shape, (3, 2, 2)) + assert p.statistic.shape == (3, 2) + assert p.confint.shape == (3, 2, 2) npt.assert_array_almost_equal(p.statistic, y.groupby([g, h]).mean().unstack()) for ci_g, (_, grp_y) in zip(p.confint, y.groupby(g)): for ci, hue_y in zip(ci_g, [grp_y[::2], grp_y[1::2]]): - sem = stats.sem(hue_y) + sem = hue_y.std() / np.sqrt(len(hue_y)) mean = hue_y.mean() - half_ci = stats.norm.ppf(.975) * sem + half_ci = _normal_quantile_func(.975) * sem ci_want = mean - half_ci, mean + half_ci npt.assert_array_almost_equal(ci_want, ci, 2) + def test_bootstrap_seed(self): + + p = cat._CategoricalStatPlotter() + + g = pd.Series(np.repeat(list("abc"), 100)) + h = pd.Series(np.tile(list("xy"), 150)) + y = pd.Series(np.random.RandomState(0).randn(300)) + + p.establish_variables(g, y, h) + p.estimate_statistic(np.mean, 95, 1000, 0) + confint_1 = p.confint + p.estimate_statistic(np.mean, 95, 1000, 0) + confint_2 = p.confint + + npt.assert_array_equal(confint_1, confint_2) + def test_nested_stats_with_units(self): p = cat._CategoricalStatPlotter() @@ -561,11 +639,11 @@ def test_nested_stats_with_units(self): y[u == "k"] += 3 p.establish_variables(g, y, h) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) stat1, ci1 = p.statistic, p.confint p.establish_variables(g, y, h, units=u) - p.estimate_statistic(np.mean, 95, 10000) + p.estimate_statistic(np.mean, 95, 10000, None) stat2, ci2 = p.statistic, p.confint npt.assert_array_equal(stat1, stat2) @@ -584,14 +662,15 @@ def test_nested_stats_with_missing_data(self): p.establish_variables(g, y, h, order=list("abdc"), hue_order=list("zyx")) - p.estimate_statistic(np.mean, 95, 50000) + p.estimate_statistic(np.mean, 95, 50000, None) - nt.assert_equal(p.statistic.shape, (4, 3)) - nt.assert_equal(p.confint.shape, (4, 3, 2)) + assert p.statistic.shape == (4, 3) + assert p.confint.shape == (4, 3, 2) - mean = y[(g == "b") & (h == "x")].mean() - sem = stats.sem(y[(g == "b") & (h == "x")]) - half_ci = stats.norm.ppf(.975) * sem + rows = (g == "b") & (h == 
"x") + mean = y[rows].mean() + sem = y[rows].std() / np.sqrt(rows.sum()) + half_ci = _normal_quantile_func(.975) * sem ci = mean - half_ci, mean + half_ci npt.assert_almost_equal(p.statistic[1, 2], mean) npt.assert_array_almost_equal(p.confint[1, 2], ci, 2) @@ -611,10 +690,10 @@ def test_sd_error_bars(self): y = pd.Series(np.random.RandomState(0).randn(300)) p.establish_variables(g, y) - p.estimate_statistic(np.mean, "sd", None) + p.estimate_statistic(np.mean, "sd", None, None) - nt.assert_equal(p.statistic.shape, (3,)) - nt.assert_equal(p.confint.shape, (3, 2)) + assert p.statistic.shape == (3,) + assert p.confint.shape == (3, 2) npt.assert_array_almost_equal(p.statistic, y.groupby(g).mean()) @@ -634,10 +713,10 @@ def test_nested_sd_error_bars(self): y = pd.Series(np.random.RandomState(0).randn(300)) p.establish_variables(g, y, h) - p.estimate_statistic(np.mean, "sd", None) + p.estimate_statistic(np.mean, "sd", None, None) - nt.assert_equal(p.statistic.shape, (3, 2)) - nt.assert_equal(p.confint.shape, (3, 2, 2)) + assert p.statistic.shape == (3, 2) + assert p.confint.shape == (3, 2, 2) npt.assert_array_almost_equal(p.statistic, y.groupby([g, h]).mean().unstack()) @@ -667,7 +746,7 @@ def test_draw_cis(self): x, y = line.get_xydata().T npt.assert_array_equal(x, [at, at]) npt.assert_array_equal(y, ci) - nt.assert_equal(line.get_color(), c) + assert line.get_color() == c plt.close("all") @@ -682,7 +761,7 @@ def test_draw_cis(self): x, y = line.get_xydata().T npt.assert_array_equal(x, ci) npt.assert_array_equal(y, [at, at]) - nt.assert_equal(line.get_color(), c) + assert line.get_color() == c plt.close("all") @@ -695,8 +774,8 @@ def test_draw_cis(self): caplinestart = capline.get_xdata()[0] caplineend = capline.get_xdata()[1] caplinelength = abs(caplineend - caplinestart) - nt.assert_almost_equal(caplinelength, 0.3) - nt.assert_equal(len(ax.lines), 6) + assert caplinelength == approx(0.3) + assert len(ax.lines) == 6 plt.close("all") @@ -709,23 +788,23 @@ def test_draw_cis(self): caplinestart = capline.get_ydata()[0] caplineend = capline.get_ydata()[1] caplinelength = abs(caplineend - caplinestart) - nt.assert_almost_equal(caplinelength, 0.3) - nt.assert_equal(len(ax.lines), 6) + assert caplinelength == approx(0.3) + assert len(ax.lines) == 6 # Test extra keyword arguments f, ax = plt.subplots() p.draw_confints(ax, at_group, confints, colors, lw=4) line = ax.lines[0] - nt.assert_equal(line.get_linewidth(), 4) + assert line.get_linewidth() == 4 plt.close("all") # Test errwidth is set appropriately f, ax = plt.subplots() p.draw_confints(ax, at_group, confints, colors, errwidth=2) - capline = ax.lines[len(ax.lines)-1] - nt.assert_equal(capline._linewidth, 2) - nt.assert_equal(len(ax.lines), 2) + capline = ax.lines[len(ax.lines) - 1] + assert capline._linewidth == 2 + assert len(ax.lines) == 2 plt.close("all") @@ -742,31 +821,31 @@ def test_nested_width(self): kws = self.default_kws.copy() p = cat._BoxPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_equal(p.nested_width, .4 * .98) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.nested_width == .4 * .98 kws = self.default_kws.copy() kws["width"] = .6 p = cat._BoxPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_equal(p.nested_width, .3 * .98) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.nested_width == .3 * .98 kws = self.default_kws.copy() kws["dodge"] = False p = cat._BoxPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) - 
nt.assert_equal(p.nested_width, .8) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.nested_width == .8 def test_hue_offsets(self): p = cat._BoxPlotter(**self.default_kws) - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) npt.assert_array_equal(p.hue_offsets, [-.2, .2]) kws = self.default_kws.copy() kws["width"] = .6 p = cat._BoxPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) npt.assert_array_equal(p.hue_offsets, [-.15, .15]) p = cat._BoxPlotter(**kws) @@ -775,37 +854,37 @@ def test_hue_offsets(self): def test_axes_data(self): - ax = cat.boxplot("g", "y", data=self.df) - nt.assert_equal(len(ax.artists), 3) + ax = cat.boxplot(x="g", y="y", data=self.df) + assert len(ax.artists) == 3 plt.close("all") - ax = cat.boxplot("g", "y", "h", data=self.df) - nt.assert_equal(len(ax.artists), 6) + ax = cat.boxplot(x="g", y="y", hue="h", data=self.df) + assert len(ax.artists) == 6 plt.close("all") def test_box_colors(self): - ax = cat.boxplot("g", "y", data=self.df, saturation=1) + ax = cat.boxplot(x="g", y="y", data=self.df, saturation=1) pal = palettes.color_palette(n_colors=3) for patch, color in zip(ax.artists, pal): - nt.assert_equal(patch.get_facecolor()[:3], color) + assert patch.get_facecolor()[:3] == color plt.close("all") - ax = cat.boxplot("g", "y", "h", data=self.df, saturation=1) + ax = cat.boxplot(x="g", y="y", hue="h", data=self.df, saturation=1) pal = palettes.color_palette(n_colors=2) for patch, color in zip(ax.artists, pal * 2): - nt.assert_equal(patch.get_facecolor()[:3], color) + assert patch.get_facecolor()[:3] == color plt.close("all") def test_draw_missing_boxes(self): - ax = cat.boxplot("g", "y", data=self.df, + ax = cat.boxplot(x="g", y="y", data=self.df, order=["a", "b", "c", "d"]) - nt.assert_equal(len(ax.artists), 3) + assert len(ax.artists) == 3 def test_missing_data(self): @@ -814,60 +893,77 @@ def test_missing_data(self): y = self.rs.randn(8) y[-2:] = np.nan - ax = cat.boxplot(x, y) - nt.assert_equal(len(ax.artists), 3) + ax = cat.boxplot(x=x, y=y) + assert len(ax.artists) == 3 plt.close("all") y[-1] = 0 - ax = cat.boxplot(x, y, h) - nt.assert_equal(len(ax.artists), 7) + ax = cat.boxplot(x=x, y=y, hue=h) + assert len(ax.artists) == 7 plt.close("all") + def test_unaligned_index(self): + + f, (ax1, ax2) = plt.subplots(2) + cat.boxplot(x=self.g, y=self.y, ax=ax1) + cat.boxplot(x=self.g, y=self.y_perm, ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert np.array_equal(l1.get_xydata(), l2.get_xydata()) + + f, (ax1, ax2) = plt.subplots(2) + hue_order = self.h.unique() + cat.boxplot(x=self.g, y=self.y, hue=self.h, + hue_order=hue_order, ax=ax1) + cat.boxplot(x=self.g, y=self.y_perm, hue=self.h, + hue_order=hue_order, ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert np.array_equal(l1.get_xydata(), l2.get_xydata()) + def test_boxplots(self): # Smoke test the high level boxplot options - cat.boxplot("y", data=self.df) + cat.boxplot(x="y", data=self.df) plt.close("all") cat.boxplot(y="y", data=self.df) plt.close("all") - cat.boxplot("g", "y", data=self.df) + cat.boxplot(x="g", y="y", data=self.df) plt.close("all") - cat.boxplot("y", "g", data=self.df, orient="h") + cat.boxplot(x="y", y="g", data=self.df, orient="h") plt.close("all") - cat.boxplot("g", "y", "h", data=self.df) + cat.boxplot(x="g", y="y", hue="h", data=self.df) plt.close("all") - cat.boxplot("g", "y", "h", order=list("nabc"), data=self.df) + 
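test_unaligned_index above encodes a pandas-alignment guarantee: inputs are matched on index, not position, so reordering one vector must not change the figure. The invariant, sketched on the same hypothetical `df`:

    y_perm = df["y"].sample(frac=1)     # same index/value pairs, shuffled rows
    sns.boxplot(x=df["g"], y=df["y"])   # draws the same boxes as...
    sns.boxplot(x=df["g"], y=y_perm)    # ...this reordered input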
cat.boxplot(x="g", y="y", hue="h", order=list("nabc"), data=self.df) plt.close("all") - cat.boxplot("g", "y", "h", hue_order=list("omn"), data=self.df) + cat.boxplot(x="g", y="y", hue="h", hue_order=list("omn"), data=self.df) plt.close("all") - cat.boxplot("y", "g", "h", data=self.df, orient="h") + cat.boxplot(x="y", y="g", hue="h", data=self.df, orient="h") plt.close("all") def test_axes_annotation(self): - ax = cat.boxplot("g", "y", data=self.df) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") - nt.assert_equal(ax.get_xlim(), (-.5, 2.5)) + ax = cat.boxplot(x="g", y="y", data=self.df) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" + assert ax.get_xlim() == (-.5, 2.5) npt.assert_array_equal(ax.get_xticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_xticklabels()], ["a", "b", "c"]) plt.close("all") - ax = cat.boxplot("g", "y", "h", data=self.df) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.boxplot(x="g", y="y", hue="h", data=self.df) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" npt.assert_array_equal(ax.get_xticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_xticklabels()], ["a", "b", "c"]) @@ -876,10 +972,10 @@ def test_axes_annotation(self): plt.close("all") - ax = cat.boxplot("y", "g", data=self.df, orient="h") - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") - nt.assert_equal(ax.get_ylim(), (2.5, -.5)) + ax = cat.boxplot(x="y", y="g", data=self.df, orient="h") + assert ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" + assert ax.get_ylim() == (2.5, -.5) npt.assert_array_equal(ax.get_yticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_yticklabels()], ["a", "b", "c"]) @@ -901,7 +997,7 @@ def test_split_error(self): kws = self.default_kws.copy() kws.update(dict(x="h", y="y", hue="g", data=self.df, split=True)) - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): cat._ViolinPlotter(**kws) def test_no_observations(self): @@ -914,34 +1010,34 @@ def test_no_observations(self): p.establish_variables(x, y) p.estimate_densities("scott", 2, "area", True, 20) - nt.assert_equal(len(p.support[0]), 20) - nt.assert_equal(len(p.support[1]), 0) + assert len(p.support[0]) == 20 + assert len(p.support[1]) == 0 - nt.assert_equal(len(p.density[0]), 20) - nt.assert_equal(len(p.density[1]), 1) + assert len(p.density[0]) == 20 + assert len(p.density[1]) == 1 - nt.assert_equal(p.density[1].item(), 1) + assert p.density[1].item() == 1 p.estimate_densities("scott", 2, "count", True, 20) - nt.assert_equal(p.density[1].item(), 0) + assert p.density[1].item() == 0 x = ["a"] * 4 + ["b"] * 2 y = self.rs.randn(6) h = ["m", "n"] * 2 + ["m"] * 2 - p.establish_variables(x, y, h) + p.establish_variables(x, y, hue=h) p.estimate_densities("scott", 2, "area", True, 20) - nt.assert_equal(len(p.support[1][0]), 20) - nt.assert_equal(len(p.support[1][1]), 0) + assert len(p.support[1][0]) == 20 + assert len(p.support[1][1]) == 0 - nt.assert_equal(len(p.density[1][0]), 20) - nt.assert_equal(len(p.density[1][1]), 1) + assert len(p.density[1][0]) == 20 + assert len(p.density[1][1]) == 1 - nt.assert_equal(p.density[1][1].item(), 1) + assert p.density[1][1].item() == 1 p.estimate_densities("scott", 2, "count", False, 20) - nt.assert_equal(p.density[1][1].item(), 0) + assert p.density[1][1].item() == 0 def test_single_observation(self): @@ -952,34 +1048,34 @@ def test_single_observation(self): 
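The violin tests that follow work through density-estimation edge cases before smoke-testing every `inner` style and the `split` option. The high-level surface they cover, sketched (the hue column is assumed to have exactly two levels, as `split` requires):

    for inner in ["box", "quartile", "point", "stick", None]:
        sns.violinplot(data=df, x="g", y="y", inner=inner)

    sns.violinplot(data=df, x="g", y="y", hue="h", split=True)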
p.establish_variables(x, y) p.estimate_densities("scott", 2, "area", True, 20) - nt.assert_equal(len(p.support[0]), 20) - nt.assert_equal(len(p.support[1]), 1) + assert len(p.support[0]) == 20 + assert len(p.support[1]) == 1 - nt.assert_equal(len(p.density[0]), 20) - nt.assert_equal(len(p.density[1]), 1) + assert len(p.density[0]) == 20 + assert len(p.density[1]) == 1 - nt.assert_equal(p.density[1].item(), 1) + assert p.density[1].item() == 1 p.estimate_densities("scott", 2, "count", True, 20) - nt.assert_equal(p.density[1].item(), .5) + assert p.density[1].item() == .5 x = ["b"] * 4 + ["a"] * 3 y = self.rs.randn(7) h = (["m", "n"] * 4)[:-1] - p.establish_variables(x, y, h) + p.establish_variables(x, y, hue=h) p.estimate_densities("scott", 2, "area", True, 20) - nt.assert_equal(len(p.support[1][0]), 20) - nt.assert_equal(len(p.support[1][1]), 1) + assert len(p.support[1][0]) == 20 + assert len(p.support[1][1]) == 1 - nt.assert_equal(len(p.density[1][0]), 20) - nt.assert_equal(len(p.density[1][1]), 1) + assert len(p.density[1][0]) == 20 + assert len(p.density[1][1]) == 1 - nt.assert_equal(p.density[1][1].item(), 1) + assert p.density[1][1].item() == 1 p.estimate_densities("scott", 2, "count", False, 20) - nt.assert_equal(p.density[1][1].item(), .5) + assert p.density[1][1].item() == .5 def test_dwidth(self): @@ -987,19 +1083,19 @@ def test_dwidth(self): kws.update(dict(x="g", y="y", data=self.df)) p = cat._ViolinPlotter(**kws) - nt.assert_equal(p.dwidth, .4) + assert p.dwidth == .4 kws.update(dict(width=.4)) p = cat._ViolinPlotter(**kws) - nt.assert_equal(p.dwidth, .2) + assert p.dwidth == .2 kws.update(dict(hue="h", width=.8)) p = cat._ViolinPlotter(**kws) - nt.assert_equal(p.dwidth, .2) + assert p.dwidth == .2 kws.update(dict(split=True)) p = cat._ViolinPlotter(**kws) - nt.assert_equal(p.dwidth, .4) + assert p.dwidth == .4 def test_scale_area(self): @@ -1013,11 +1109,11 @@ def test_scale_area(self): max_before = np.array([d.max() for d in density]) p.scale_area(density, max_before, False) max_after = np.array([d.max() for d in density]) - nt.assert_equal(max_after[0], 1) + assert max_after[0] == 1 before_ratio = max_before[1] / max_before[0] after_ratio = max_after[1] / max_after[0] - nt.assert_equal(before_ratio, after_ratio) + assert before_ratio == after_ratio # Test nested grouping scaling across all densities p.hue_names = ["foo", "bar"] @@ -1027,11 +1123,11 @@ def test_scale_area(self): max_before = np.array([[r.max() for r in row] for row in density]) p.scale_area(density, max_before, False) max_after = np.array([[r.max() for r in row] for row in density]) - nt.assert_equal(max_after[0, 0], 1) + assert max_after[0, 0] == 1 before_ratio = max_before[1, 1] / max_before[0, 0] after_ratio = max_after[1, 1] / max_after[0, 0] - nt.assert_equal(before_ratio, after_ratio) + assert before_ratio == after_ratio # Test nested grouping scaling within hue p.hue_names = ["foo", "bar"] @@ -1041,12 +1137,12 @@ def test_scale_area(self): max_before = np.array([[r.max() for r in row] for row in density]) p.scale_area(density, max_before, True) max_after = np.array([[r.max() for r in row] for row in density]) - nt.assert_equal(max_after[0, 0], 1) - nt.assert_equal(max_after[1, 0], 1) + assert max_after[0, 0] == 1 + assert max_after[1, 0] == 1 before_ratio = max_before[1, 1] / max_before[1, 0] after_ratio = max_after[1, 1] / max_after[1, 0] - nt.assert_equal(before_ratio, after_ratio) + assert before_ratio == after_ratio def test_scale_width(self): @@ -1108,7 +1204,7 @@ def test_bad_scale(self): kws = 
self.default_kws.copy() kws["scale"] = "not_a_scale_type" - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): cat._ViolinPlotter(**kws) def test_kde_fit(self): @@ -1117,25 +1213,15 @@ def test_kde_fit(self): data = self.y data_std = data.std(ddof=1) - # Bandwidth behavior depends on scipy version - if LooseVersion(scipy.__version__) < "0.11": - # Test ignoring custom bandwidth on old scipy - kde, bw = p.fit_kde(self.y, .2) - nt.assert_is_instance(kde, stats.gaussian_kde) - nt.assert_equal(kde.factor, kde.scotts_factor()) + # Test reference rule bandwidth + kde, bw = p.fit_kde(data, "scott") + assert kde.factor == kde.scotts_factor() + assert bw == kde.scotts_factor() * data_std - else: - # Test reference rule bandwidth - kde, bw = p.fit_kde(data, "scott") - nt.assert_is_instance(kde, stats.gaussian_kde) - nt.assert_equal(kde.factor, kde.scotts_factor()) - nt.assert_equal(bw, kde.scotts_factor() * data_std) - - # Test numeric scale factor - kde, bw = p.fit_kde(self.y, .2) - nt.assert_is_instance(kde, stats.gaussian_kde) - nt.assert_equal(kde.factor, .2) - nt.assert_equal(bw, .2 * data_std) + # Test numeric scale factor + kde, bw = p.fit_kde(self.y, .2) + assert kde.factor == .2 + assert bw == .2 * data_std def test_draw_to_density(self): @@ -1231,14 +1317,14 @@ def test_draw_box_lines(self): _, ax = plt.subplots() p.draw_box_lines(ax, self.y, p.support[0], p.density[0], 0) - nt.assert_equal(len(ax.lines), 2) + assert len(ax.lines) == 2 q25, q50, q75 = np.percentile(self.y, [25, 50, 75]) _, y = ax.lines[1].get_xydata().T npt.assert_array_equal(y, [q25, q75]) _, y = ax.collections[0].get_offsets().T - nt.assert_equal(y, q50) + assert y == q50 plt.close("all") @@ -1249,14 +1335,14 @@ def test_draw_box_lines(self): _, ax = plt.subplots() p.draw_box_lines(ax, self.y, p.support[0], p.density[0], 0) - nt.assert_equal(len(ax.lines), 2) + assert len(ax.lines) == 2 q25, q50, q75 = np.percentile(self.y, [25, 50, 75]) x, _ = ax.lines[1].get_xydata().T npt.assert_array_equal(x, [q25, q75]) x, _ = ax.collections[0].get_offsets().T - nt.assert_equal(x, q50) + assert x == q50 plt.close("all") @@ -1320,7 +1406,7 @@ def test_validate_inner(self): kws = self.default_kws.copy() kws.update(dict(inner="bad_inner")) - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): cat._ViolinPlotter(**kws) def test_draw_violinplots(self): @@ -1334,7 +1420,7 @@ def test_draw_violinplots(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 1) + assert len(ax.collections) == 1 npt.assert_array_equal(ax.collections[0].get_facecolors(), [(1, 0, 0, 1)]) plt.close("all") @@ -1345,7 +1431,7 @@ def test_draw_violinplots(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 1) + assert len(ax.collections) == 1 npt.assert_array_equal(ax.collections[0].get_facecolors(), [(0, 1, 0, 1)]) plt.close("all") @@ -1356,7 +1442,7 @@ def test_draw_violinplots(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 3) + assert len(ax.collections) == 3 for violin, color in zip(ax.collections, palettes.color_palette()): npt.assert_array_equal(violin.get_facecolors()[0, :-1], color) plt.close("all") @@ -1367,7 +1453,7 @@ def test_draw_violinplots(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 6) + assert len(ax.collections) == 6 for violin, color in zip(ax.collections, palettes.color_palette(n_colors=2) * 3): npt.assert_array_equal(violin.get_facecolors()[0, :-1], color) @@ 
-1379,7 +1465,7 @@ def test_draw_violinplots(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 6) + assert len(ax.collections) == 6 for violin, color in zip(ax.collections, palettes.color_palette("muted", n_colors=2) * 3): @@ -1400,8 +1486,8 @@ def test_draw_violinplots_no_observations(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), 0) + assert len(ax.collections) == 1 + assert len(ax.lines) == 0 plt.close("all") # Test nested hue grouping @@ -1413,8 +1499,8 @@ def test_draw_violinplots_no_observations(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 3) - nt.assert_equal(len(ax.lines), 0) + assert len(ax.collections) == 3 + assert len(ax.lines) == 0 plt.close("all") def test_draw_violinplots_single_observations(self): @@ -1430,8 +1516,8 @@ def test_draw_violinplots_single_observations(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), 1) + assert len(ax.collections) == 1 + assert len(ax.lines) == 1 plt.close("all") # Test nested hue grouping @@ -1443,8 +1529,8 @@ def test_draw_violinplots_single_observations(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 3) - nt.assert_equal(len(ax.lines), 1) + assert len(ax.collections) == 3 + assert len(ax.lines) == 1 plt.close("all") # Test nested hue grouping with split @@ -1453,396 +1539,650 @@ def test_draw_violinplots_single_observations(self): _, ax = plt.subplots() p.draw_violins(ax) - nt.assert_equal(len(ax.collections), 3) - nt.assert_equal(len(ax.lines), 1) + assert len(ax.collections) == 3 + assert len(ax.lines) == 1 plt.close("all") def test_violinplots(self): # Smoke test the high level violinplot options - cat.violinplot("y", data=self.df) + cat.violinplot(x="y", data=self.df) plt.close("all") cat.violinplot(y="y", data=self.df) plt.close("all") - cat.violinplot("g", "y", data=self.df) + cat.violinplot(x="g", y="y", data=self.df) plt.close("all") - cat.violinplot("y", "g", data=self.df, orient="h") + cat.violinplot(x="y", y="g", data=self.df, orient="h") plt.close("all") - cat.violinplot("g", "y", "h", data=self.df) + cat.violinplot(x="g", y="y", hue="h", data=self.df) plt.close("all") - cat.violinplot("g", "y", "h", order=list("nabc"), data=self.df) + order = list("nabc") + cat.violinplot(x="g", y="y", hue="h", order=order, data=self.df) plt.close("all") - cat.violinplot("g", "y", "h", hue_order=list("omn"), data=self.df) + order = list("omn") + cat.violinplot(x="g", y="y", hue="h", hue_order=order, data=self.df) plt.close("all") - cat.violinplot("y", "g", "h", data=self.df, orient="h") + cat.violinplot(x="y", y="g", hue="h", data=self.df, orient="h") plt.close("all") for inner in ["box", "quart", "point", "stick", None]: - cat.violinplot("g", "y", data=self.df, inner=inner) + cat.violinplot(x="g", y="y", data=self.df, inner=inner) plt.close("all") - cat.violinplot("g", "y", "h", data=self.df, inner=inner) + cat.violinplot(x="g", y="y", hue="h", data=self.df, inner=inner) plt.close("all") - cat.violinplot("g", "y", "h", data=self.df, + cat.violinplot(x="g", y="y", hue="h", data=self.df, inner=inner, split=True) plt.close("all") -class TestCategoricalScatterPlotter(CategoricalFixture): - - def test_group_point_colors(self): - - p = cat._CategoricalScatterPlotter() - - p.establish_variables(x="g", y="y", data=self.df) - p.establish_colors(None, "deep", 1) - - point_colors = 
p.point_colors - nt.assert_equal(len(point_colors), self.g.unique().size) - deep_colors = palettes.color_palette("deep", self.g.unique().size) - - for i, group_colors in enumerate(point_colors): - nt.assert_equal(tuple(deep_colors[i]), tuple(group_colors[0])) - for channel in group_colors.T: - assert np.unique(channel).size == 1 - - def test_hue_point_colors(self): - - p = cat._CategoricalScatterPlotter() - - hue_order = self.h.unique().tolist() - p.establish_variables(x="g", y="y", hue="h", - hue_order=hue_order, data=self.df) - p.establish_colors(None, "deep", 1) - - point_colors = p.point_colors - nt.assert_equal(len(point_colors), self.g.unique().size) - deep_colors = palettes.color_palette("deep", len(hue_order)) - - for i, group_colors in enumerate(point_colors): - for j, point_color in enumerate(group_colors): - hue_level = p.plot_hues[i][j] - nt.assert_equal(tuple(point_color), - deep_colors[hue_order.index(hue_level)]) +# ==================================================================================== +# ==================================================================================== - def test_scatterplot_legend(self): - p = cat._CategoricalScatterPlotter() +class SharedAxesLevelTests: - hue_order = ["m", "n"] - p.establish_variables(x="g", y="y", hue="h", - hue_order=hue_order, data=self.df) - p.establish_colors(None, "deep", 1) - deep_colors = palettes.color_palette("deep", self.h.unique().size) + def test_color(self, long_df): - f, ax = plt.subplots() - p.add_legend_data(ax) - leg = ax.legend() - - for i, t in enumerate(leg.get_texts()): - nt.assert_equal(t.get_text(), hue_order[i]) - - for i, h in enumerate(leg.legendHandles): - rgb = h.get_facecolor()[0, :3] - nt.assert_equal(tuple(rgb), tuple(deep_colors[i])) - - -class TestStripPlotter(CategoricalFixture): - - def test_stripplot_vertical(self): - - pal = palettes.color_palette() + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", ax=ax) + assert self.get_last_color(ax) == to_rgba("C0") - ax = cat.stripplot("g", "y", jitter=False, data=self.df) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", ax=ax) + self.func(data=long_df, x="a", y="y", ax=ax) + assert self.get_last_color(ax) == to_rgba("C1") - x, y = ax.collections[i].get_offsets().T + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", color="C2", ax=ax) + assert self.get_last_color(ax) == to_rgba("C2") - npt.assert_array_equal(x, np.ones(len(x)) * i) - npt.assert_array_equal(y, vals) + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", color="C3", ax=ax) + assert self.get_last_color(ax) == to_rgba("C3") - npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i]) - def test_stripplot_horiztonal(self): +class SharedScatterTests(SharedAxesLevelTests): + """Tests functionality common to stripplot and swarmplot.""" - df = self.df.copy() - df.g = df.g.astype("category") + def get_last_color(self, ax): - ax = cat.stripplot("y", "g", jitter=False, data=df) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + colors = ax.collections[-1].get_facecolors() + unique_colors = np.unique(colors, axis=0) + assert len(unique_colors) == 1 + return to_rgba(unique_colors.squeeze()) - x, y = ax.collections[i].get_offsets().T + # ------------------------------------------------------------------------------ - npt.assert_array_equal(x, vals) - npt.assert_array_equal(y, np.ones(len(x)) * i) + def test_color(self, long_df): - def 
test_stripplot_jitter(self): + super().test_color(long_df) - pal = palettes.color_palette() + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", facecolor="C4", ax=ax) + assert self.get_last_color(ax) == to_rgba("C4") - ax = cat.stripplot("g", "y", data=self.df, jitter=True) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + if LooseVersion(mpl.__version__) >= "3.1.0": + # https://github.com/matplotlib/matplotlib/pull/12851 - x, y = ax.collections[i].get_offsets().T + ax = plt.figure().subplots() + self.func(data=long_df, x="a", y="y", fc="C5", ax=ax) + assert self.get_last_color(ax) == to_rgba("C5") - npt.assert_array_less(np.ones(len(x)) * i - .1, x) - npt.assert_array_less(x, np.ones(len(x)) * i + .1) - npt.assert_array_equal(y, vals) + def test_supplied_color_array(self, long_df): - npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i]) + cmap = mpl.cm.get_cmap("Blues") + norm = mpl.colors.Normalize() + colors = cmap(norm(long_df["y"].to_numpy())) - def test_dodge_nested_stripplot_vertical(self): + keys = ["c", "facecolor", "facecolors"] - pal = palettes.color_palette() + if LooseVersion(mpl.__version__) >= "3.1.0": + # https://github.com/matplotlib/matplotlib/pull/12851 + keys.append("fc") - ax = cat.stripplot("g", "y", "h", data=self.df, - jitter=False, dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + for key in keys: - x, y = ax.collections[i * 2 + j].get_offsets().T + ax = plt.figure().subplots() + self.func(x=long_df["y"], **{key: colors}) + _draw_figure(ax.figure) + assert_array_equal(ax.collections[0].get_facecolors(), colors) - npt.assert_array_equal(x, np.ones(len(x)) * i + [-.2, .2][j]) - npt.assert_array_equal(y, vals) + ax = plt.figure().subplots() + self.func(x=long_df["y"], c=long_df["y"], cmap=cmap) + _draw_figure(ax.figure) + assert_array_equal(ax.collections[0].get_facecolors(), colors) - fc = ax.collections[i * 2 + j].get_facecolors()[0, :3] - npt.assert_equal(fc, pal[j]) + @pytest.mark.parametrize( + "orient,data_type", + itertools.product(["h", "v"], ["dataframe", "dict"]), + ) + def test_wide(self, wide_df, orient, data_type): - def test_dodge_nested_stripplot_horizontal(self): + if data_type == "dict": + wide_df = {k: v.to_numpy() for k, v in wide_df.items()} - df = self.df.copy() - df.g = df.g.astype("category") + ax = self.func(data=wide_df, orient=orient) + _draw_figure(ax.figure) + palette = color_palette() - ax = cat.stripplot("y", "g", "h", data=df, - jitter=False, dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + cat_idx = 0 if orient == "v" else 1 + val_idx = int(not cat_idx) - x, y = ax.collections[i * 2 + j].get_offsets().T + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] - npt.assert_array_equal(x, vals) - npt.assert_array_equal(y, np.ones(len(x)) * i + [-.2, .2][j]) + for i, label in enumerate(cat_axis.get_majorticklabels()): - def test_nested_stripplot_vertical(self): + key = label.get_text() + points = ax.collections[i] + point_pos = points.get_offsets().T + val_pos = point_pos[val_idx] + cat_pos = point_pos[cat_idx] - # Test a simple vertical strip plot - ax = cat.stripplot("g", "y", "h", data=self.df, - jitter=False, dodge=False) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): + assert_array_equal(cat_pos.round(), i) + assert_array_equal(val_pos, wide_df[key]) - x, y = ax.collections[i].get_offsets().T + 
for point_color in points.get_facecolors(): + assert tuple(point_color) == to_rgba(palette[i]) - npt.assert_array_equal(x, np.ones(len(x)) * i) - npt.assert_array_equal(y, group_vals) + @pytest.mark.parametrize("orient", ["h", "v"]) + def test_flat(self, flat_series, orient): - def test_nested_stripplot_horizontal(self): + ax = self.func(data=flat_series, orient=orient) + _draw_figure(ax.figure) - df = self.df.copy() - df.g = df.g.astype("category") + cat_idx = 0 if orient == "v" else 1 + val_idx = int(not cat_idx) - ax = cat.stripplot("y", "g", "h", data=df, - jitter=False, dodge=False) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] - x, y = ax.collections[i].get_offsets().T + for i, label in enumerate(cat_axis.get_majorticklabels()): - npt.assert_array_equal(x, group_vals) - npt.assert_array_equal(y, np.ones(len(x)) * i) + points = ax.collections[i] + point_pos = points.get_offsets().T + val_pos = point_pos[val_idx] + cat_pos = point_pos[cat_idx] + + key = int(label.get_text()) # because fixture has integer index + assert_array_equal(val_pos, flat_series[key]) + assert_array_equal(cat_pos, i) + + @pytest.mark.parametrize( + "variables,orient", + [ + # Order matters for assigning to x/y + ({"cat": "a", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "a", "hue": None}, None), + ({"cat": "a", "val": "y", "hue": "a"}, None), + ({"val": "y", "cat": "a", "hue": "a"}, None), + ({"cat": "a", "val": "y", "hue": "b"}, None), + ({"val": "y", "cat": "a", "hue": "x"}, None), + ({"cat": "s", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "s", "hue": None}, "h"), + ({"cat": "a", "val": "b", "hue": None}, None), + ({"val": "a", "cat": "b", "hue": None}, "h"), + ({"cat": "a", "val": "t", "hue": None}, None), + ({"val": "t", "cat": "a", "hue": None}, None), + ({"cat": "d", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "d", "hue": None}, None), + ({"cat": "a_cat", "val": "y", "hue": None}, None), + ({"val": "y", "cat": "s_cat", "hue": None}, None), + ], + ) + def test_positions(self, long_df, variables, orient): + + cat_var = variables["cat"] + val_var = variables["val"] + hue_var = variables["hue"] + var_names = list(variables.values()) + x_var, y_var, *_ = var_names + + ax = self.func( + data=long_df, x=x_var, y=y_var, hue=hue_var, orient=orient, + ) + + _draw_figure(ax.figure) + + cat_idx = var_names.index(cat_var) + val_idx = var_names.index(val_var) + + axis_objs = ax.xaxis, ax.yaxis + cat_axis = axis_objs[cat_idx] + val_axis = axis_objs[val_idx] + + cat_data = long_df[cat_var] + cat_levels = categorical_order(cat_data) + + for i, label in enumerate(cat_levels): + + vals = long_df.loc[cat_data == label, val_var] + + points = ax.collections[i].get_offsets().T + cat_pos = points[var_names.index(cat_var)] + val_pos = points[var_names.index(val_var)] + + assert_array_equal(val_pos, val_axis.convert_units(vals)) + assert_array_equal(cat_pos.round(), i) + assert 0 <= np.ptp(cat_pos) <= .8 + + label = pd.Index([label]).astype(str)[0] + assert cat_axis.get_majorticklabels()[i].get_text() == label + + @pytest.mark.parametrize( + "variables", + [ + # Order matters for assigning to x/y + {"cat": "a", "val": "y", "hue": "b"}, + {"val": "y", "cat": "a", "hue": "c"}, + {"cat": "a", "val": "y", "hue": "f"}, + ], + ) + def test_positions_dodged(self, long_df, variables): + + cat_var = variables["cat"] + val_var = variables["val"] + hue_var = variables["hue"] + var_names = list(variables.values()) + x_var, 
y_var, *_ = var_names + + ax = self.func( + data=long_df, x=x_var, y=y_var, hue=hue_var, dodge=True, + ) + + cat_vals = categorical_order(long_df[cat_var]) + hue_vals = categorical_order(long_df[hue_var]) + + n_hue = len(hue_vals) + offsets = np.linspace(0, .8, n_hue + 1)[:-1] + offsets -= offsets.mean() + nest_width = .8 / n_hue + + for i, cat_val in enumerate(cat_vals): + for j, hue_val in enumerate(hue_vals): + rows = (long_df[cat_var] == cat_val) & (long_df[hue_var] == hue_val) + vals = long_df.loc[rows, val_var] + + points = ax.collections[n_hue * i + j].get_offsets().T + cat_pos = points[var_names.index(cat_var)] + val_pos = points[var_names.index(val_var)] + + if pd.api.types.is_datetime64_any_dtype(vals): + vals = mpl.dates.date2num(vals) + + assert_array_equal(val_pos, vals) + + assert_array_equal(cat_pos.round(), i) + assert_array_equal((cat_pos - (i + offsets[j])).round() / nest_width, 0) + assert 0 <= np.ptp(cat_pos) <= nest_width + + @pytest.mark.parametrize("cat_var", ["a", "s", "d"]) + def test_positions_unfixed(self, long_df, cat_var): + + long_df = long_df.sort_values(cat_var) + + kws = dict(size=.001) + if "stripplot" in str(self.func): # can't use __name__ with partial + kws["jitter"] = False + + ax = self.func(data=long_df, x=cat_var, y="y", fixed_scale=False, **kws) + + for i, (cat_level, cat_data) in enumerate(long_df.groupby(cat_var)): + + points = ax.collections[i].get_offsets().T + cat_pos = points[0] + val_pos = points[1] + + assert_array_equal(val_pos, cat_data["y"]) - def test_three_strip_points(self): + comp_level = np.squeeze(ax.xaxis.convert_units(cat_level)).item() + assert_array_equal(cat_pos.round(), comp_level) - x = np.arange(3) - ax = cat.stripplot(x=x) - facecolors = ax.collections[0].get_facecolor() - nt.assert_equal(facecolors.shape, (3, 4)) - npt.assert_array_equal(facecolors[0], facecolors[1]) + @pytest.mark.parametrize( + "x_type,order", + [ + (str, None), + (str, ["a", "b", "c"]), + (str, ["c", "a"]), + (str, ["a", "b", "c", "d"]), + (int, None), + (int, [3, 1, 2]), + (int, [3, 1]), + (int, [1, 2, 3, 4]), + (int, ["3", "1", "2"]), + ] + ) + def test_order(self, x_type, order): + if x_type is str: + x = ["b", "a", "c"] + else: + x = [2, 1, 3] + y = [1, 2, 3] -class TestSwarmPlotter(CategoricalFixture): + ax = self.func(x=x, y=y, order=order) + _draw_figure(ax.figure) - default_kws = dict(x=None, y=None, hue=None, data=None, - order=None, hue_order=None, dodge=False, - orient=None, color=None, palette=None) + if order is None: + order = x + if x_type is int: + order = np.sort(order) - def test_could_overlap(self): + assert len(ax.collections) == len(order) + tick_labels = ax.xaxis.get_majorticklabels() - p = cat._SwarmPlotter(**self.default_kws) - neighbors = p.could_overlap((1, 1), [(0, 0), (1, .5), (.5, .5)], 1) - npt.assert_array_equal(neighbors, [(1, .5), (.5, .5)]) + assert ax.get_xlim()[1] == (len(order) - .5) - def test_position_candidates(self): + for i, points in enumerate(ax.collections): + cat = order[i] + assert tick_labels[i].get_text() == str(cat) - p = cat._SwarmPlotter(**self.default_kws) - xy_i = (0, 1) - neighbors = [(0, 1), (0, 1.5)] - candidates = p.position_candidates(xy_i, neighbors, 1) - dx1 = 1.05 - dx2 = np.sqrt(1 - .5 ** 2) * 1.05 - npt.assert_array_equal(candidates, - [(0, 1), (-dx1, 1), (dx1, 1), - (dx2, 1), (-dx2, 1)]) - - def test_find_first_non_overlapping_candidate(self): - - p = cat._SwarmPlotter(**self.default_kws) - candidates = [(.5, 1), (1, 1), (1.5, 1)] - neighbors = np.array([(0, 1)]) + positions = 
points.get_offsets() + if x_type(cat) in x: + val = y[x.index(x_type(cat))] + assert positions[0, 1] == val + else: + assert not positions.size - first = p.first_non_overlapping_candidate(candidates, neighbors, 1) - npt.assert_array_equal(first, (1, 1)) + @pytest.mark.parametrize("hue_var", ["a", "b"]) + def test_hue_categorical(self, long_df, hue_var): - def test_beeswarm(self): + cat_var = "b" - p = cat._SwarmPlotter(**self.default_kws) - d = self.y.diff().mean() * 1.5 - x = np.zeros(self.y.size) - y = np.sort(self.y) - orig_xy = np.c_[x, y] - swarm = p.beeswarm(orig_xy, d) - dmat = spatial.distance.cdist(swarm, swarm) - triu = dmat[np.triu_indices_from(dmat, 1)] - npt.assert_array_less(d, triu) - npt.assert_array_equal(y, swarm[:, 1]) + hue_levels = categorical_order(long_df[hue_var]) + cat_levels = categorical_order(long_df[cat_var]) - def test_add_gutters(self): + pal_name = "muted" + palette = dict(zip(hue_levels, color_palette(pal_name))) + ax = self.func(data=long_df, x=cat_var, y="y", hue=hue_var, palette=pal_name) - p = cat._SwarmPlotter(**self.default_kws) - points = np.array([0, -1, .4, .8]) - points = p.add_gutters(points, 0, 1) - npt.assert_array_equal(points, - np.array([0, -.5, .4, .5])) + for i, level in enumerate(cat_levels): - def test_swarmplot_vertical(self): + sub_df = long_df[long_df[cat_var] == level] + point_hues = sub_df[hue_var] - pal = palettes.color_palette() + points = ax.collections[i] + point_colors = points.get_facecolors() - ax = cat.swarmplot("g", "y", data=self.df) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + assert len(point_hues) == len(point_colors) - x, y = ax.collections[i].get_offsets().T - npt.assert_array_almost_equal(y, np.sort(vals)) + for hue, color in zip(point_hues, point_colors): + assert tuple(color) == to_rgba(palette[hue]) - fc = ax.collections[i].get_facecolors()[0, :3] - npt.assert_equal(fc, pal[i]) + @pytest.mark.parametrize("hue_var", ["a", "b"]) + def test_hue_dodged(self, long_df, hue_var): - def test_swarmplot_horizontal(self): + ax = self.func(data=long_df, x="y", y="a", hue=hue_var, dodge=True) + colors = color_palette(n_colors=long_df[hue_var].nunique()) + collections = iter(ax.collections) - pal = palettes.color_palette() + # Slightly awkward logic to handle challenges of how the artists work. + # e.g. 
there are empty scatter collections, but the facecolors + # for the empty collections will return the default scatter color + while colors: + points = next(collections) + if points.get_offsets().any(): + face_color = tuple(points.get_facecolors()[0]) + expected_color = to_rgba(colors.pop(0)) + assert face_color == expected_color - ax = cat.swarmplot("y", "g", data=self.df, orient="h") - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + @pytest.mark.parametrize( + "val_var,val_col,hue_col", + itertools.product(["x", "y"], ["b", "y", "t"], [None, "a"]), + ) + def test_single(self, long_df, val_var, val_col, hue_col): - x, y = ax.collections[i].get_offsets().T - npt.assert_array_almost_equal(x, np.sort(vals)) + var_kws = {val_var: val_col, "hue": hue_col} + ax = self.func(data=long_df, **var_kws) + _draw_figure(ax.figure) - fc = ax.collections[i].get_facecolors()[0, :3] - npt.assert_equal(fc, pal[i]) + axis_vars = ["x", "y"] + val_idx = axis_vars.index(val_var) + cat_idx = int(not val_idx) + cat_var = axis_vars[cat_idx] - def test_dodge_nested_swarmplot_vetical(self): + cat_axis = getattr(ax, f"{cat_var}axis") + val_axis = getattr(ax, f"{val_var}axis") - pal = palettes.color_palette() + points = ax.collections[0] + point_pos = points.get_offsets().T + cat_pos = point_pos[cat_idx] + val_pos = point_pos[val_idx] + + assert_array_equal(cat_pos.round(), 0) + assert cat_pos.max() <= .4 + assert cat_pos.min() >= -.4 + + num_vals = val_axis.convert_units(long_df[val_col]) + assert_array_equal(val_pos, num_vals) + + if hue_col is not None: + palette = dict(zip( + categorical_order(long_df[hue_col]), color_palette() + )) + + facecolors = points.get_facecolors() + for i, color in enumerate(facecolors): + if hue_col is None: + assert tuple(color) == to_rgba("C0") + else: + hue_level = long_df.loc[i, hue_col] + expected_color = palette[hue_level] + assert tuple(color) == to_rgba(expected_color) - ax = cat.swarmplot("g", "y", "h", data=self.df, dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + ticklabels = cat_axis.get_majorticklabels() + assert len(ticklabels) == 1 + assert not ticklabels[0].get_text() - x, y = ax.collections[i * 2 + j].get_offsets().T - npt.assert_array_almost_equal(y, np.sort(vals)) + def test_attributes(self, long_df): - fc = ax.collections[i * 2 + j].get_facecolors()[0, :3] - npt.assert_equal(fc, pal[j]) + kwargs = dict( + size=2, + linewidth=1, + edgecolor="C2", + ) - def test_dodge_nested_swarmplot_horizontal(self): + ax = self.func(x=long_df["y"], **kwargs) + points, = ax.collections - pal = palettes.color_palette() + assert points.get_sizes().item() == kwargs["size"] ** 2 + assert points.get_linewidths().item() == kwargs["linewidth"] + assert tuple(points.get_edgecolors().squeeze()) == to_rgba(kwargs["edgecolor"]) - ax = cat.swarmplot("y", "g", "h", data=self.df, orient="h", dodge=True) - for i, (_, group_vals) in enumerate(self.y.groupby(self.g)): - for j, (_, vals) in enumerate(group_vals.groupby(self.h)): + def test_three_points(self): - x, y = ax.collections[i * 2 + j].get_offsets().T - npt.assert_array_almost_equal(x, np.sort(vals)) + x = np.arange(3) + ax = self.func(x=x) + for point_color in ax.collections[0].get_facecolor(): + assert tuple(point_color) == to_rgba("C0") - fc = ax.collections[i * 2 + j].get_facecolors()[0, :3] - npt.assert_equal(fc, pal[j])
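As an aside, the dodge geometry that test_positions_dodged above and the hue_offsets checks further down assert can be sketched directly. A minimal sketch, assuming only numpy; dodge_offsets is a hypothetical helper name used for illustration, not seaborn API:

import numpy as np

def dodge_offsets(n_hue, width=.8):
    # n_hue hue levels share one categorical slot of total width `width`;
    # each level gets a band of width / n_hue centered on these offsets.
    offsets = np.linspace(0, width, n_hue + 1)[:-1]
    return offsets - offsets.mean()

print(dodge_offsets(2))  # [-0.2  0.2], the offsets the tests assert for two hue levels
print(dodge_offsets(3))  # roughly [-0.267  0.  0.267] for three levels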
+ def test_palette_from_color_deprecation(self, long_df): - def test_nested_swarmplot_vertical(self): + color = (.9, .4, .5) + hex_color = mpl.colors.to_hex(color) - ax = cat.swarmplot("g", "y", "h", data=self.df) + hue_var = "a" + n_hue = long_df[hue_var].nunique() + palette = color_palette(f"dark:{hex_color}", n_hue) - pal = palettes.color_palette() - hue_names = self.h.unique().tolist() - grouped_hues = list(self.h.groupby(self.g)) + with pytest.warns(FutureWarning, match="Setting a gradient palette"): + ax = self.func(data=long_df, x="z", hue=hue_var, color=color) - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + points = ax.collections[0] + for point_color in points.get_facecolors(): + assert to_rgb(point_color) in palette - points = ax.collections[i] - x, y = points.get_offsets().T - sorter = np.argsort(vals) - npt.assert_array_almost_equal(y, vals.iloc[sorter]) + def test_log_scale(self): + + x = [1, 10, 100, 1000] + + ax = plt.figure().subplots() + ax.set_xscale("log") + self.func(x=x) + vals = ax.collections[0].get_offsets()[:, 0] + assert_array_equal(x, vals) + + y = [1, 2, 3, 4] + + ax = plt.figure().subplots() + ax.set_xscale("log") + self.func(x=x, y=y, fixed_scale=False) + for i, point in enumerate(ax.collections): + val = point.get_offsets()[0, 0] + assert val == pytest.approx(x[i]) + + x = y = np.ones(100) + + # The following test fails on pinned (but not latest) matplotlib, + # even though the visual output is ok (so it's not an actual bug). + # I'm not exactly sure why, so this version check is approximate + # and should be revisited on a version bump. + if LooseVersion(mpl.__version__) < "3.1": + pytest.xfail() + + ax = plt.figure().subplots() + ax.set_yscale("log") + self.func(x=x, y=y, orient="h", fixed_scale=False) + cat_points = ax.collections[0].get_offsets().copy()[:, 1] + assert np.ptp(np.log10(cat_points)) <= .8 + + @pytest.mark.parametrize( + "kwargs", + [ + dict(data="wide"), + dict(data="wide", orient="h"), + dict(data="long", x="x", color="C3"), + dict(data="long", y="y", hue="a", jitter=False), + # TODO XXX full numeric hue legend crashes pinned mpl, disabling for now + # dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5), + # dict(data="long", x="a_cat", y="y", hue="z"), + dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True), + dict(data="long", x="s", y="y", hue="c", fixed_scale=False), + ] + ) + def test_vs_catplot(self, long_df, wide_df, kwargs): + + kwargs = kwargs.copy() + if kwargs["data"] == "long": + kwargs["data"] = long_df + elif kwargs["data"] == "wide": + kwargs["data"] = wide_df + + try: + name = self.func.__name__[:-4] + except AttributeError: + name = self.func.func.__name__[:-4] + if name == "swarm": + kwargs.pop("jitter", None) + + np.random.seed(0) # for jitter + ax = self.func(**kwargs) + + np.random.seed(0) + g = catplot(**kwargs, kind=name) + + assert_plots_equal(ax, g.ax) + + +class TestStripPlot(SharedScatterTests): + + func = staticmethod(stripplot) + + def test_jitter_unfixed(self, long_df): + + ax1, ax2 = plt.figure().subplots(2) + kws = dict(data=long_df, x="y", orient="h", fixed_scale=False) + + np.random.seed(0) + stripplot(**kws, y="s", ax=ax1) + + np.random.seed(0) + stripplot(**kws, y=long_df["s"] * 2, ax=ax2) + + p1 = ax1.collections[0].get_offsets()[1] + p2 = ax2.collections[0].get_offsets()[1] + + assert p2.std() > p1.std() + + @pytest.mark.parametrize( + "orient,jitter", + itertools.product(["v", "h"], [True, .1]), + ) + def test_jitter(self, long_df, orient, jitter): + + cat_var, val_var = "a", "y" + if orient == "v": + x_var, y_var = cat_var, val_var + cat_idx, val_idx = 0, 1 + else: + x_var, y_var = 
val_var, cat_var + cat_idx, val_idx = 1, 0 - _, hue_vals = grouped_hues[i] - for hue, fc in zip(hue_vals.values[sorter.values], - points.get_facecolors()): + cat_vals = categorical_order(long_df[cat_var]) - npt.assert_equal(fc[:3], pal[hue_names.index(hue)]) + ax = stripplot( + data=long_df, x=x_var, y=y_var, jitter=jitter, + ) - def test_nested_swarmplot_horizontal(self): + if jitter is True: + jitter_range = .4 + else: + jitter_range = 2 * jitter - ax = cat.swarmplot("y", "g", "h", data=self.df, orient="h") + for i, level in enumerate(cat_vals): - pal = palettes.color_palette() - hue_names = self.h.unique().tolist() - grouped_hues = list(self.h.groupby(self.g)) + vals = long_df.loc[long_df[cat_var] == level, val_var] + points = ax.collections[i].get_offsets().T + cat_points = points[cat_idx] + val_points = points[val_idx] - for i, (_, vals) in enumerate(self.y.groupby(self.g)): + assert_array_equal(val_points, vals) + assert np.std(cat_points) > 0 + assert np.ptp(cat_points) <= jitter_range - points = ax.collections[i] - x, y = points.get_offsets().T - sorter = np.argsort(vals) - npt.assert_array_almost_equal(x, vals.iloc[sorter]) - _, hue_vals = grouped_hues[i] - for hue, fc in zip(hue_vals.values[sorter.values], - points.get_facecolors()): +class TestSwarmPlot(SharedScatterTests): - npt.assert_equal(fc[:3], pal[hue_names.index(hue)]) + func = staticmethod(partial(swarmplot, warn_thresh=1)) class TestBarPlotter(CategoricalFixture): - default_kws = dict(x=None, y=None, hue=None, data=None, - estimator=np.mean, ci=95, n_boot=100, units=None, - order=None, hue_order=None, - orient=None, color=None, palette=None, - saturation=.75, errcolor=".26", errwidth=None, - capsize=None, dodge=True) + default_kws = dict( + x=None, y=None, hue=None, data=None, + estimator=np.mean, ci=95, n_boot=100, units=None, seed=None, + order=None, hue_order=None, + orient=None, color=None, palette=None, + saturation=.75, errcolor=".26", errwidth=None, + capsize=None, dodge=True + ) def test_nested_width(self): kws = self.default_kws.copy() p = cat._BarPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) - nt.assert_equal(p.nested_width, .8 / 2) + p.establish_variables("g", "y", hue="h", data=self.df) + assert p.nested_width == .8 / 2 p = cat._BarPlotter(**kws) p.establish_variables("h", "y", "g", data=self.df) - nt.assert_equal(p.nested_width, .8 / 3) + assert p.nested_width == .8 / 3 kws["dodge"] = False p = cat._BarPlotter(**kws) p.establish_variables("h", "y", "g", data=self.df) - nt.assert_equal(p.nested_width, .8) + assert p.nested_width == .8 def test_draw_vertical_bars(self): @@ -1853,22 +2193,18 @@ def test_draw_vertical_bars(self): f, ax = plt.subplots() p.draw_bars(ax, {}) - nt.assert_equal(len(ax.patches), len(p.plot_data)) - nt.assert_equal(len(ax.lines), len(p.plot_data)) + assert len(ax.patches) == len(p.plot_data) + assert len(ax.lines) == len(p.plot_data) for bar, color in zip(ax.patches, p.colors): - nt.assert_equal(bar.get_facecolor()[:-1], color) + assert bar.get_facecolor()[:-1] == color positions = np.arange(len(p.plot_data)) - p.width / 2 for bar, pos, stat in zip(ax.patches, positions, p.statistic): - nt.assert_equal(bar.get_x(), pos) - nt.assert_equal(bar.get_width(), p.width) - if mpl.__version__ >= mpl_barplot_change: - nt.assert_equal(bar.get_y(), 0) - nt.assert_equal(bar.get_height(), stat) - else: - nt.assert_equal(bar.get_y(), min(0, stat)) - nt.assert_equal(bar.get_height(), abs(stat)) + assert bar.get_x() == pos + assert bar.get_width() == p.width + assert bar.get_y() 
== 0 + assert bar.get_height() == stat def test_draw_horizontal_bars(self): @@ -1879,22 +2215,18 @@ def test_draw_horizontal_bars(self): f, ax = plt.subplots() p.draw_bars(ax, {}) - nt.assert_equal(len(ax.patches), len(p.plot_data)) - nt.assert_equal(len(ax.lines), len(p.plot_data)) + assert len(ax.patches) == len(p.plot_data) + assert len(ax.lines) == len(p.plot_data) for bar, color in zip(ax.patches, p.colors): - nt.assert_equal(bar.get_facecolor()[:-1], color) + assert bar.get_facecolor()[:-1] == color positions = np.arange(len(p.plot_data)) - p.width / 2 for bar, pos, stat in zip(ax.patches, positions, p.statistic): - nt.assert_equal(bar.get_y(), pos) - nt.assert_equal(bar.get_height(), p.width) - if mpl.__version__ >= mpl_barplot_change: - nt.assert_equal(bar.get_x(), 0) - nt.assert_equal(bar.get_width(), stat) - else: - nt.assert_equal(bar.get_x(), min(0, stat)) - nt.assert_equal(bar.get_width(), abs(stat)) + assert bar.get_y() == pos + assert bar.get_height() == p.width + assert bar.get_x() == 0 + assert bar.get_width() == stat def test_draw_nested_vertical_bars(self): @@ -1906,26 +2238,22 @@ def test_draw_nested_vertical_bars(self): p.draw_bars(ax, {}) n_groups, n_hues = len(p.plot_data), len(p.hue_names) - nt.assert_equal(len(ax.patches), n_groups * n_hues) - nt.assert_equal(len(ax.lines), n_groups * n_hues) + assert len(ax.patches) == n_groups * n_hues + assert len(ax.lines) == n_groups * n_hues for bar in ax.patches[:n_groups]: - nt.assert_equal(bar.get_facecolor()[:-1], p.colors[0]) + assert bar.get_facecolor()[:-1] == p.colors[0] for bar in ax.patches[n_groups:]: - nt.assert_equal(bar.get_facecolor()[:-1], p.colors[1]) + assert bar.get_facecolor()[:-1] == p.colors[1] positions = np.arange(len(p.plot_data)) for bar, pos in zip(ax.patches[:n_groups], positions): - nt.assert_almost_equal(bar.get_x(), pos - p.width / 2) - nt.assert_almost_equal(bar.get_width(), p.nested_width) + assert bar.get_x() == approx(pos - p.width / 2) + assert bar.get_width() == approx(p.nested_width) for bar, stat in zip(ax.patches, p.statistic.T.flat): - if LooseVersion(mpl.__version__) >= mpl_barplot_change: - nt.assert_almost_equal(bar.get_y(), 0) - nt.assert_almost_equal(bar.get_height(), stat) - else: - nt.assert_almost_equal(bar.get_y(), min(0, stat)) - nt.assert_almost_equal(bar.get_height(), abs(stat)) + assert bar.get_y() == approx(0) + assert bar.get_height() == approx(stat) def test_draw_nested_horizontal_bars(self): @@ -1937,26 +2265,22 @@ def test_draw_nested_horizontal_bars(self): p.draw_bars(ax, {}) n_groups, n_hues = len(p.plot_data), len(p.hue_names) - nt.assert_equal(len(ax.patches), n_groups * n_hues) - nt.assert_equal(len(ax.lines), n_groups * n_hues) + assert len(ax.patches) == n_groups * n_hues + assert len(ax.lines) == n_groups * n_hues for bar in ax.patches[:n_groups]: - nt.assert_equal(bar.get_facecolor()[:-1], p.colors[0]) + assert bar.get_facecolor()[:-1] == p.colors[0] for bar in ax.patches[n_groups:]: - nt.assert_equal(bar.get_facecolor()[:-1], p.colors[1]) + assert bar.get_facecolor()[:-1] == p.colors[1] positions = np.arange(len(p.plot_data)) for bar, pos in zip(ax.patches[:n_groups], positions): - nt.assert_almost_equal(bar.get_y(), pos - p.width / 2) - nt.assert_almost_equal(bar.get_height(), p.nested_width) + assert bar.get_y() == approx(pos - p.width / 2) + assert bar.get_height() == approx(p.nested_width) for bar, stat in zip(ax.patches, p.statistic.T.flat): - if LooseVersion(mpl.__version__) >= mpl_barplot_change: - nt.assert_almost_equal(bar.get_x(), 0) - 
nt.assert_almost_equal(bar.get_width(), stat) - else: - nt.assert_almost_equal(bar.get_x(), min(0, stat)) - nt.assert_almost_equal(bar.get_width(), abs(stat)) + assert bar.get_x() == approx(0) + assert bar.get_width() == approx(stat) def test_draw_missing_bars(self): @@ -1969,8 +2293,8 @@ def test_draw_missing_bars(self): f, ax = plt.subplots() p.draw_bars(ax, {}) - nt.assert_equal(len(ax.patches), len(order)) - nt.assert_equal(len(ax.lines), len(order)) + assert len(ax.patches) == len(order) + assert len(ax.lines) == len(order) plt.close("all") @@ -1981,11 +2305,36 @@ def test_draw_missing_bars(self): f, ax = plt.subplots() p.draw_bars(ax, {}) - nt.assert_equal(len(ax.patches), len(p.plot_data) * len(hue_order)) - nt.assert_equal(len(ax.lines), len(p.plot_data) * len(hue_order)) + assert len(ax.patches) == len(p.plot_data) * len(hue_order) + assert len(ax.lines) == len(p.plot_data) * len(hue_order) plt.close("all") + def test_unaligned_index(self): + + f, (ax1, ax2) = plt.subplots(2) + cat.barplot(x=self.g, y=self.y, ci="sd", ax=ax1) + cat.barplot(x=self.g, y=self.y_perm, ci="sd", ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert approx(l1.get_xydata()) == l2.get_xydata() + for p1, p2 in zip(ax1.patches, ax2.patches): + assert approx(p1.get_xy()) == p2.get_xy() + assert approx(p1.get_height()) == p2.get_height() + assert approx(p1.get_width()) == p2.get_width() + + f, (ax1, ax2) = plt.subplots(2) + hue_order = self.h.unique() + cat.barplot(x=self.g, y=self.y, hue=self.h, + hue_order=hue_order, ci="sd", ax=ax1) + cat.barplot(x=self.g, y=self.y_perm, hue=self.h, + hue_order=hue_order, ci="sd", ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert approx(l1.get_xydata()) == l2.get_xydata() + for p1, p2 in zip(ax1.patches, ax2.patches): + assert approx(p1.get_xy()) == p2.get_xy() + assert approx(p1.get_height()) == p2.get_height() + assert approx(p1.get_width()) == p2.get_width() + def test_barplot_colors(self): # Test unnested palette colors @@ -1999,7 +2348,7 @@ def test_barplot_colors(self): palette = palettes.color_palette("muted", len(self.g.unique())) for patch, pal_color in zip(ax.patches, palette): - nt.assert_equal(patch.get_facecolor()[:-1], pal_color) + assert patch.get_facecolor()[:-1] == pal_color plt.close("all") @@ -2014,7 +2363,7 @@ def test_barplot_colors(self): p.draw_bars(ax, {}) for patch in ax.patches: - nt.assert_equal(patch.get_facecolor(), color) + assert patch.get_facecolor() == color plt.close("all") @@ -2029,49 +2378,49 @@ def test_barplot_colors(self): palette = palettes.color_palette("Set2", len(self.h.unique())) for patch in ax.patches[:len(self.g.unique())]: - nt.assert_equal(patch.get_facecolor()[:-1], palette[0]) + assert patch.get_facecolor()[:-1] == palette[0] for patch in ax.patches[len(self.g.unique()):]: - nt.assert_equal(patch.get_facecolor()[:-1], palette[1]) + assert patch.get_facecolor()[:-1] == palette[1] plt.close("all") def test_simple_barplots(self): - ax = cat.barplot("g", "y", data=self.df) - nt.assert_equal(len(ax.patches), len(self.g.unique())) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.barplot(x="g", y="y", data=self.df) + assert len(ax.patches) == len(self.g.unique()) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" plt.close("all") - ax = cat.barplot("y", "g", orient="h", data=self.df) - nt.assert_equal(len(ax.patches), len(self.g.unique())) - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") + ax = cat.barplot(x="y", y="g", 
orient="h", data=self.df) + assert len(ax.patches) == len(self.g.unique()) + assert ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" plt.close("all") - ax = cat.barplot("g", "y", "h", data=self.df) - nt.assert_equal(len(ax.patches), - len(self.g.unique()) * len(self.h.unique())) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.barplot(x="g", y="y", hue="h", data=self.df) + assert len(ax.patches) == len(self.g.unique()) * len(self.h.unique()) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" plt.close("all") - ax = cat.barplot("y", "g", "h", orient="h", data=self.df) - nt.assert_equal(len(ax.patches), - len(self.g.unique()) * len(self.h.unique())) - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") + ax = cat.barplot(x="y", y="g", hue="h", orient="h", data=self.df) + assert len(ax.patches) == len(self.g.unique()) * len(self.h.unique()) + assert ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" plt.close("all") class TestPointPlotter(CategoricalFixture): - default_kws = dict(x=None, y=None, hue=None, data=None, - estimator=np.mean, ci=95, n_boot=100, units=None, - order=None, hue_order=None, - markers="o", linestyles="-", dodge=0, - join=True, scale=1, - orient=None, color=None, palette=None) + default_kws = dict( + x=None, y=None, hue=None, data=None, + estimator=np.mean, ci=95, n_boot=100, units=None, seed=None, + order=None, hue_order=None, + markers="o", linestyles="-", dodge=0, + join=True, scale=1, + orient=None, color=None, palette=None, + ) def test_different_defualt_colors(self): @@ -2113,10 +2462,10 @@ def test_draw_vertical_points(self): f, ax = plt.subplots() p.draw_points(ax) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), len(p.plot_data) + 1) + assert len(ax.collections) == 1 + assert len(ax.lines) == len(p.plot_data) + 1 points = ax.collections[0] - nt.assert_equal(len(points.get_offsets()), len(p.plot_data)) + assert len(points.get_offsets()) == len(p.plot_data) x, y = points.get_offsets().T npt.assert_array_equal(x, np.arange(len(p.plot_data))) @@ -2135,10 +2484,10 @@ def test_draw_horizontal_points(self): f, ax = plt.subplots() p.draw_points(ax) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), len(p.plot_data) + 1) + assert len(ax.collections) == 1 + assert len(ax.lines) == len(p.plot_data) + 1 points = ax.collections[0] - nt.assert_equal(len(points.get_offsets()), len(p.plot_data)) + assert len(points.get_offsets()) == len(p.plot_data) x, y = points.get_offsets().T npt.assert_array_equal(x, p.statistic) @@ -2157,15 +2506,14 @@ def test_draw_vertical_nested_points(self): f, ax = plt.subplots() p.draw_points(ax) - nt.assert_equal(len(ax.collections), 2) - nt.assert_equal(len(ax.lines), - len(p.plot_data) * len(p.hue_names) + len(p.hue_names)) + assert len(ax.collections) == 2 + assert len(ax.lines) == len(p.plot_data) * len(p.hue_names) + len(p.hue_names) for points, numbers, color in zip(ax.collections, p.statistic.T, p.colors): - nt.assert_equal(len(points.get_offsets()), len(p.plot_data)) + assert len(points.get_offsets()) == len(p.plot_data) x, y = points.get_offsets().T npt.assert_array_equal(x, np.arange(len(p.plot_data))) @@ -2183,15 +2531,14 @@ def test_draw_horizontal_nested_points(self): f, ax = plt.subplots() p.draw_points(ax) - nt.assert_equal(len(ax.collections), 2) - nt.assert_equal(len(ax.lines), - len(p.plot_data) * len(p.hue_names) + len(p.hue_names)) + assert len(ax.collections) == 2 + assert len(ax.lines) == 
len(p.plot_data) * len(p.hue_names) + len(p.hue_names) for points, numbers, color in zip(ax.collections, p.statistic.T, p.colors): - nt.assert_equal(len(points.get_offsets()), len(p.plot_data)) + assert len(points.get_offsets()) == len(p.plot_data) x, y = points.get_offsets().T npt.assert_array_equal(x, numbers) @@ -2200,6 +2547,43 @@ def test_draw_horizontal_nested_points(self): for got_color in points.get_facecolors(): npt.assert_array_equal(got_color[:-1], color) + def test_draw_missing_points(self): + + kws = self.default_kws.copy() + df = self.df.copy() + + kws.update(x="g", y="y", hue="h", hue_order=["x", "y"], data=df) + p = cat._PointPlotter(**kws) + f, ax = plt.subplots() + p.draw_points(ax) + + df.loc[df["h"] == "m", "y"] = np.nan + kws.update(x="g", y="y", hue="h", data=df) + p = cat._PointPlotter(**kws) + f, ax = plt.subplots() + p.draw_points(ax) + + def test_unaligned_index(self): + + f, (ax1, ax2) = plt.subplots(2) + cat.pointplot(x=self.g, y=self.y, ci="sd", ax=ax1) + cat.pointplot(x=self.g, y=self.y_perm, ci="sd", ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert approx(l1.get_xydata()) == l2.get_xydata() + for p1, p2 in zip(ax1.collections, ax2.collections): + assert approx(p1.get_offsets()) == p2.get_offsets() + + f, (ax1, ax2) = plt.subplots(2) + hue_order = self.h.unique() + cat.pointplot(x=self.g, y=self.y, hue=self.h, + hue_order=hue_order, ci="sd", ax=ax1) + cat.pointplot(x=self.g, y=self.y_perm, hue=self.h, + hue_order=hue_order, ci="sd", ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert approx(l1.get_xydata()) == l2.get_xydata() + for p1, p2 in zip(ax1.collections, ax2.collections): + assert approx(p1.get_offsets()) == p2.get_offsets() + def test_pointplot_colors(self): # Test a single-color unnested plot @@ -2212,7 +2596,7 @@ def test_pointplot_colors(self): p.draw_points(ax) for line in ax.lines: - nt.assert_equal(line.get_color(), color[:-1]) + assert line.get_color() == color[:-1] for got_color in ax.collections[0].get_facecolors(): npt.assert_array_equal(rgb2hex(got_color), rgb2hex(color)) @@ -2224,7 +2608,7 @@ def test_pointplot_colors(self): kws.update(x="g", y="y", data=self.df, palette="Set1") p = cat._PointPlotter(**kws) - nt.assert_true(not p.join) + assert not p.join f, ax = plt.subplots() p.draw_points(ax) @@ -2247,9 +2631,9 @@ def test_pointplot_colors(self): p.draw_points(ax) for line in ax.lines[:(len(p.plot_data) + 1)]: - nt.assert_equal(line.get_color(), palette[0]) + assert line.get_color() == palette[0] for line in ax.lines[(len(p.plot_data) + 1):]: - nt.assert_equal(line.get_color(), palette[1]) + assert line.get_color() == palette[1] for i, pal_color in enumerate(palette): for point_color in ax.collections[i].get_facecolors(): @@ -2259,38 +2643,36 @@ def test_pointplot_colors(self): def test_simple_pointplots(self): - ax = cat.pointplot("g", "y", data=self.df) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), len(self.g.unique()) + 1) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.pointplot(x="g", y="y", data=self.df) + assert len(ax.collections) == 1 + assert len(ax.lines) == len(self.g.unique()) + 1 + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" plt.close("all") - ax = cat.pointplot("y", "g", orient="h", data=self.df) - nt.assert_equal(len(ax.collections), 1) - nt.assert_equal(len(ax.lines), len(self.g.unique()) + 1) - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") + ax = cat.pointplot(x="y", y="g", 
orient="h", data=self.df) + assert len(ax.collections) == 1 + assert len(ax.lines) == len(self.g.unique()) + 1 + assert ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" plt.close("all") - ax = cat.pointplot("g", "y", "h", data=self.df) - nt.assert_equal(len(ax.collections), len(self.h.unique())) - nt.assert_equal(len(ax.lines), - (len(self.g.unique()) * - len(self.h.unique()) + - len(self.h.unique()))) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.pointplot(x="g", y="y", hue="h", data=self.df) + assert len(ax.collections) == len(self.h.unique()) + assert len(ax.lines) == ( + len(self.g.unique()) * len(self.h.unique()) + len(self.h.unique()) + ) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" plt.close("all") - ax = cat.pointplot("y", "g", "h", orient="h", data=self.df) - nt.assert_equal(len(ax.collections), len(self.h.unique())) - nt.assert_equal(len(ax.lines), - (len(self.g.unique()) * - len(self.h.unique()) + - len(self.h.unique()))) - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") + ax = cat.pointplot(x="y", y="g", hue="h", orient="h", data=self.df) + assert len(ax.collections) == len(self.h.unique()) + assert len(ax.lines) == ( + len(self.g.unique()) * len(self.h.unique()) + len(self.h.unique()) + ) + assert ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" plt.close("all") @@ -2298,38 +2680,31 @@ class TestCountPlot(CategoricalFixture): def test_plot_elements(self): - ax = cat.countplot("g", data=self.df) - nt.assert_equal(len(ax.patches), self.g.unique().size) + ax = cat.countplot(x="g", data=self.df) + assert len(ax.patches) == self.g.unique().size for p in ax.patches: - nt.assert_equal(p.get_y(), 0) - nt.assert_equal(p.get_height(), - self.g.size / self.g.unique().size) + assert p.get_y() == 0 + assert p.get_height() == self.g.size / self.g.unique().size plt.close("all") ax = cat.countplot(y="g", data=self.df) - nt.assert_equal(len(ax.patches), self.g.unique().size) + assert len(ax.patches) == self.g.unique().size for p in ax.patches: - nt.assert_equal(p.get_x(), 0) - nt.assert_equal(p.get_width(), - self.g.size / self.g.unique().size) + assert p.get_x() == 0 + assert p.get_width() == self.g.size / self.g.unique().size plt.close("all") - ax = cat.countplot("g", hue="h", data=self.df) - nt.assert_equal(len(ax.patches), - self.g.unique().size * self.h.unique().size) + ax = cat.countplot(x="g", hue="h", data=self.df) + assert len(ax.patches) == self.g.unique().size * self.h.unique().size plt.close("all") ax = cat.countplot(y="g", hue="h", data=self.df) - nt.assert_equal(len(ax.patches), - self.g.unique().size * self.h.unique().size) + assert len(ax.patches) == self.g.unique().size * self.h.unique().size plt.close("all") def test_input_error(self): - with nt.assert_raises(TypeError): - cat.countplot() - - with nt.assert_raises(TypeError): + with pytest.raises(ValueError): cat.countplot(x="g", y="h", data=self.df) @@ -2337,135 +2712,205 @@ class TestCatPlot(CategoricalFixture): def test_facet_organization(self): - g = cat.catplot("g", "y", data=self.df) - nt.assert_equal(g.axes.shape, (1, 1)) + g = cat.catplot(x="g", y="y", data=self.df) + assert g.axes.shape == (1, 1) - g = cat.catplot("g", "y", col="h", data=self.df) - nt.assert_equal(g.axes.shape, (1, 2)) + g = cat.catplot(x="g", y="y", col="h", data=self.df) + assert g.axes.shape == (1, 2) - g = cat.catplot("g", "y", row="h", data=self.df) - nt.assert_equal(g.axes.shape, (2, 1)) + g = cat.catplot(x="g", y="y", row="h", 
data=self.df) + assert g.axes.shape == (2, 1) - g = cat.catplot("g", "y", col="u", row="h", data=self.df) - nt.assert_equal(g.axes.shape, (2, 3)) + g = cat.catplot(x="g", y="y", col="u", row="h", data=self.df) + assert g.axes.shape == (2, 3) def test_plot_elements(self): - g = cat.catplot("g", "y", data=self.df, kind="point") - nt.assert_equal(len(g.ax.collections), 1) + g = cat.catplot(x="g", y="y", data=self.df, kind="point") + assert len(g.ax.collections) == 1 want_lines = self.g.unique().size + 1 - nt.assert_equal(len(g.ax.lines), want_lines) + assert len(g.ax.lines) == want_lines - g = cat.catplot("g", "y", "h", data=self.df, kind="point") + g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="point") want_collections = self.h.unique().size - nt.assert_equal(len(g.ax.collections), want_collections) + assert len(g.ax.collections) == want_collections want_lines = (self.g.unique().size + 1) * self.h.unique().size - nt.assert_equal(len(g.ax.lines), want_lines) + assert len(g.ax.lines) == want_lines - g = cat.catplot("g", "y", data=self.df, kind="bar") + g = cat.catplot(x="g", y="y", data=self.df, kind="bar") want_elements = self.g.unique().size - nt.assert_equal(len(g.ax.patches), want_elements) - nt.assert_equal(len(g.ax.lines), want_elements) + assert len(g.ax.patches) == want_elements + assert len(g.ax.lines) == want_elements - g = cat.catplot("g", "y", "h", data=self.df, kind="bar") + g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="bar") want_elements = self.g.unique().size * self.h.unique().size - nt.assert_equal(len(g.ax.patches), want_elements) - nt.assert_equal(len(g.ax.lines), want_elements) + assert len(g.ax.patches) == want_elements + assert len(g.ax.lines) == want_elements - g = cat.catplot("g", data=self.df, kind="count") + g = cat.catplot(x="g", data=self.df, kind="count") want_elements = self.g.unique().size - nt.assert_equal(len(g.ax.patches), want_elements) - nt.assert_equal(len(g.ax.lines), 0) + assert len(g.ax.patches) == want_elements + assert len(g.ax.lines) == 0 - g = cat.catplot("g", hue="h", data=self.df, kind="count") + g = cat.catplot(x="g", hue="h", data=self.df, kind="count") want_elements = self.g.unique().size * self.h.unique().size - nt.assert_equal(len(g.ax.patches), want_elements) - nt.assert_equal(len(g.ax.lines), 0) + assert len(g.ax.patches) == want_elements + assert len(g.ax.lines) == 0 - g = cat.catplot("g", "y", data=self.df, kind="box") + g = cat.catplot(x="g", y="y", data=self.df, kind="box") want_artists = self.g.unique().size - nt.assert_equal(len(g.ax.artists), want_artists) + assert len(g.ax.artists) == want_artists - g = cat.catplot("g", "y", "h", data=self.df, kind="box") + g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="box") want_artists = self.g.unique().size * self.h.unique().size - nt.assert_equal(len(g.ax.artists), want_artists) + assert len(g.ax.artists) == want_artists - g = cat.catplot("g", "y", data=self.df, + g = cat.catplot(x="g", y="y", data=self.df, kind="violin", inner=None) want_elements = self.g.unique().size - nt.assert_equal(len(g.ax.collections), want_elements) + assert len(g.ax.collections) == want_elements - g = cat.catplot("g", "y", "h", data=self.df, + g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="violin", inner=None) want_elements = self.g.unique().size * self.h.unique().size - nt.assert_equal(len(g.ax.collections), want_elements) + assert len(g.ax.collections) == want_elements - g = cat.catplot("g", "y", data=self.df, kind="strip") + g = cat.catplot(x="g", y="y", data=self.df, 
kind="strip") want_elements = self.g.unique().size - nt.assert_equal(len(g.ax.collections), want_elements) + assert len(g.ax.collections) == want_elements - g = cat.catplot("g", "y", "h", data=self.df, kind="strip") + g = cat.catplot(x="g", y="y", hue="h", data=self.df, kind="strip") want_elements = self.g.unique().size + self.h.unique().size - nt.assert_equal(len(g.ax.collections), want_elements) + assert len(g.ax.collections) == want_elements def test_bad_plot_kind_error(self): - with nt.assert_raises(ValueError): - cat.catplot("g", "y", data=self.df, kind="not_a_kind") + with pytest.raises(ValueError): + cat.catplot(x="g", y="y", data=self.df, kind="not_a_kind") def test_count_x_and_y(self): - with nt.assert_raises(ValueError): - cat.catplot("g", "y", data=self.df, kind="count") + with pytest.raises(ValueError): + cat.catplot(x="g", y="y", data=self.df, kind="count") def test_plot_colors(self): - ax = cat.barplot("g", "y", data=self.df) - g = cat.catplot("g", "y", data=self.df, kind="bar") + ax = cat.barplot(x="g", y="y", data=self.df) + g = cat.catplot(x="g", y="y", data=self.df, kind="bar") for p1, p2 in zip(ax.patches, g.ax.patches): - nt.assert_equal(p1.get_facecolor(), p2.get_facecolor()) + assert p1.get_facecolor() == p2.get_facecolor() plt.close("all") - ax = cat.barplot("g", "y", data=self.df, color="purple") - g = cat.catplot("g", "y", data=self.df, + ax = cat.barplot(x="g", y="y", data=self.df, color="purple") + g = cat.catplot(x="g", y="y", data=self.df, kind="bar", color="purple") for p1, p2 in zip(ax.patches, g.ax.patches): - nt.assert_equal(p1.get_facecolor(), p2.get_facecolor()) + assert p1.get_facecolor() == p2.get_facecolor() plt.close("all") - ax = cat.barplot("g", "y", data=self.df, palette="Set2") - g = cat.catplot("g", "y", data=self.df, + ax = cat.barplot(x="g", y="y", data=self.df, palette="Set2") + g = cat.catplot(x="g", y="y", data=self.df, kind="bar", palette="Set2") for p1, p2 in zip(ax.patches, g.ax.patches): - nt.assert_equal(p1.get_facecolor(), p2.get_facecolor()) + assert p1.get_facecolor() == p2.get_facecolor() plt.close("all") - ax = cat.pointplot("g", "y", data=self.df) - g = cat.catplot("g", "y", data=self.df) + ax = cat.pointplot(x="g", y="y", data=self.df) + g = cat.catplot(x="g", y="y", data=self.df) for l1, l2 in zip(ax.lines, g.ax.lines): - nt.assert_equal(l1.get_color(), l2.get_color()) + assert l1.get_color() == l2.get_color() plt.close("all") - ax = cat.pointplot("g", "y", data=self.df, color="purple") - g = cat.catplot("g", "y", data=self.df, color="purple") + ax = cat.pointplot(x="g", y="y", data=self.df, color="purple") + g = cat.catplot(x="g", y="y", data=self.df, color="purple") for l1, l2 in zip(ax.lines, g.ax.lines): - nt.assert_equal(l1.get_color(), l2.get_color()) + assert l1.get_color() == l2.get_color() plt.close("all") - ax = cat.pointplot("g", "y", data=self.df, palette="Set2") - g = cat.catplot("g", "y", data=self.df, palette="Set2") + ax = cat.pointplot(x="g", y="y", data=self.df, palette="Set2") + g = cat.catplot(x="g", y="y", data=self.df, palette="Set2") for l1, l2 in zip(ax.lines, g.ax.lines): - nt.assert_equal(l1.get_color(), l2.get_color()) + assert l1.get_color() == l2.get_color() plt.close("all") + def test_ax_kwarg_removal(self): + + f, ax = plt.subplots() + with pytest.warns(UserWarning, match="catplot is a figure-level"): + g = cat.catplot(x="g", y="y", data=self.df, ax=ax) + assert len(ax.collections) == 0 + assert len(g.ax.collections) > 0 + def test_factorplot(self): with pytest.warns(UserWarning): - g = 
cat.factorplot("g", "y", data=self.df) + g = cat.factorplot(x="g", y="y", data=self.df) - nt.assert_equal(len(g.ax.collections), 1) + assert len(g.ax.collections) == 1 want_lines = self.g.unique().size + 1 - nt.assert_equal(len(g.ax.lines), want_lines) + assert len(g.ax.lines) == want_lines + + def test_share_xy(self): + + # Test that the default behavior works + g = cat.catplot(x="g", y="y", col="g", data=self.df, sharex=True) + for ax in g.axes.flat: + assert len(ax.collections) == len(self.df.g.unique()) + + g = cat.catplot(x="y", y="g", col="g", data=self.df, sharey=True) + for ax in g.axes.flat: + assert len(ax.collections) == len(self.df.g.unique()) + + # Test that unsharing works + with pytest.warns(UserWarning): + g = cat.catplot( + x="g", y="y", col="g", data=self.df, sharex=False, kind="bar", + ) + for ax in g.axes.flat: + assert len(ax.patches) == 1 + + with pytest.warns(UserWarning): + g = cat.catplot( + x="y", y="g", col="g", data=self.df, sharey=False, kind="bar", + ) + for ax in g.axes.flat: + assert len(ax.patches) == 1 + + # Make sure no warning is raised if color is provided on unshared plot + with pytest.warns(None) as record: + g = cat.catplot( + x="g", y="y", col="g", data=self.df, sharex=False, color="b" + ) + assert not len(record) + for ax in g.axes.flat: + assert ax.get_xlim() == (-.5, .5) + + with pytest.warns(None) as record: + g = cat.catplot( + x="y", y="g", col="g", data=self.df, sharey=False, color="r" + ) + assert not len(record) + for ax in g.axes.flat: + assert ax.get_ylim() == (.5, -.5) + + # Make sure order is used if given, regardless of sharex value + order = self.df.g.unique() + g = cat.catplot(x="g", y="y", col="g", data=self.df, sharex=False, order=order) + for ax in g.axes.flat: + assert len(ax.collections) == len(self.df.g.unique()) + + g = cat.catplot(x="y", y="g", col="g", data=self.df, sharey=False, order=order) + for ax in g.axes.flat: + assert len(ax.collections) == len(self.df.g.unique()) + + @pytest.mark.parametrize("var", ["col", "row"]) + def test_array_faceter(self, long_df, var): + + g1 = catplot(data=long_df, x="y", **{var: "a"}) + g2 = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()}) + + for ax1, ax2 in zip(g1.axes.flat, g2.axes.flat): + assert_plots_equal(ax1, ax2) class TestBoxenPlotter(CategoricalFixture): @@ -2474,13 +2919,18 @@ class TestBoxenPlotter(CategoricalFixture): order=None, hue_order=None, orient=None, color=None, palette=None, saturation=.75, width=.8, dodge=True, - k_depth='proportion', linewidth=None, - scale='exponential', outlier_prop=None) + k_depth='tukey', linewidth=None, + scale='exponential', outlier_prop=0.007, + trust_alpha=0.05, showfliers=True) def ispatch(self, c): return isinstance(c, mpl.collections.PatchCollection) + def ispath(self, c): + + return isinstance(c, mpl.collections.PathCollection) + def edge_calc(self, n, data): q = np.asanyarray([0.5 ** n, 1 - 0.5 ** n]) * 100 @@ -2488,73 +2938,174 @@ def edge_calc(self, n, data): return np.percentile(data, q)
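As an aside, the letter-value arithmetic that edge_calc above encodes, and that the depth tests below rely on, can be sketched directly. A minimal sketch, assuming only numpy; letter_value_edges is a hypothetical helper name used for illustration, not seaborn API:

import numpy as np

def letter_value_edges(data, k):
    # The box at depth i spans the (.5 ** i)-th and (1 - .5 ** i)-th
    # quantiles; iterate from the widest box inward, as the tests do.
    edges = []
    for i in range(k + 1, 1, -1):
        q = np.asanyarray([.5 ** i, 1 - .5 ** i]) * 100
        edges.append(np.percentile(data, q))
    return edges

data = np.arange(100)
k = max(int(np.log2(100)) - 3, 1)  # the "tukey" depth rule gives k == 3 for n == 100
print(letter_value_edges(data, k))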
def test_box_ends_finite(self): + p = cat._LVPlotter(**self.default_kws) p.establish_variables("g", "y", data=self.df) - box_k = np.asarray([[b, k] - for b, k in map(p._lv_box_ends, p.plot_data)]) - box_ends = box_k[:, 0] - k_vals = box_k[:, 1] + box_ends = [] + k_vals = [] + for s in p.plot_data: + b, k = p._lv_box_ends(s) + box_ends.append(b) + k_vals.append(k) # Check that all the box ends are finite and are within # the bounds of the data b_e = map(lambda a: np.all(np.isfinite(a)), box_ends) - npt.assert_equal(np.sum(list(b_e)), len(box_ends)) + assert np.sum(list(b_e)) == len(box_ends) def within(t): a, d = t - return ((np.ravel(a) <= d.max()) & - (np.ravel(a) >= d.min())).all() + return ((np.ravel(a) <= d.max()) + & (np.ravel(a) >= d.min())).all() b_w = map(within, zip(box_ends, p.plot_data)) - npt.assert_equal(np.sum(list(b_w)), len(box_ends)) + assert np.sum(list(b_w)) == len(box_ends) k_f = map(lambda k: (k > 0.) & np.isfinite(k), k_vals) - npt.assert_equal(np.sum(list(k_f)), len(k_vals)) + assert np.sum(list(k_f)) == len(k_vals) - def test_box_ends_correct(self): + def test_box_ends_correct_tukey(self): n = 100 linear_data = np.arange(n) - expected_k = int(np.log2(n)) - int(np.log2(n * 0.007)) + 1 + expected_k = max(int(np.log2(n)) - 3, 1) expected_edges = [self.edge_calc(i, linear_data) - for i in range(expected_k + 2, 1, -1)] + for i in range(expected_k + 1, 1, -1)] p = cat._LVPlotter(**self.default_kws) calc_edges, calc_k = p._lv_box_ends(linear_data) - npt.assert_equal(list(expected_edges), calc_edges) - npt.assert_equal(expected_k, calc_k) + npt.assert_array_equal(expected_edges, calc_edges) + assert expected_k == calc_k + + def test_box_ends_correct_proportion(self): + + n = 100 + linear_data = np.arange(n) + expected_k = int(np.log2(n)) - int(np.log2(n * 0.007)) + 1 + expected_edges = [self.edge_calc(i, linear_data) + for i in range(expected_k + 1, 1, -1)] + + kws = self.default_kws.copy() + kws["k_depth"] = "proportion" + p = cat._LVPlotter(**kws) + calc_edges, calc_k = p._lv_box_ends(linear_data) + + npt.assert_array_equal(expected_edges, calc_edges) + assert expected_k == calc_k + + @pytest.mark.parametrize( + "n,exp_k", + [(491, 6), (492, 7), (983, 7), (984, 8), (1966, 8), (1967, 9)], + ) + def test_box_ends_correct_trustworthy(self, n, exp_k): + + linear_data = np.arange(n) + kws = self.default_kws.copy() + kws["k_depth"] = "trustworthy" + p = cat._LVPlotter(**kws) + _, calc_k = p._lv_box_ends(linear_data) + + assert exp_k == calc_k def test_outliers(self): n = 100 outlier_data = np.append(np.arange(n - 1), 2 * n) - expected_k = int(np.log2(n)) - int(np.log2(n * 0.007)) + 1 + expected_k = max(int(np.log2(n)) - 3, 1) expected_edges = [self.edge_calc(i, outlier_data) - for i in range(expected_k + 2, 1, -1)] + for i in range(expected_k + 1, 1, -1)] p = cat._LVPlotter(**self.default_kws) calc_edges, calc_k = p._lv_box_ends(outlier_data) - npt.assert_equal(list(expected_edges), calc_edges) - - npt.assert_equal(expected_k, calc_k) + npt.assert_array_equal(calc_edges, expected_edges) + assert calc_k == expected_k out_calc = p._lv_outliers(outlier_data, calc_k) out_exp = p._lv_outliers(outlier_data, expected_k) - npt.assert_equal(out_exp, out_calc) + npt.assert_equal(out_calc, out_exp) + + def test_showfliers(self): + + ax = cat.boxenplot(x="g", y="y", data=self.df, k_depth="proportion", + showfliers=True) + ax_collections = list(filter(self.ispath, ax.collections)) + for c in ax_collections: + assert len(c.get_offsets()) == 2 + + # Test that all data points are in the plot + assert ax.get_ylim()[0] < self.df["y"].min() + assert ax.get_ylim()[1] > self.df["y"].max() + + plt.close("all") + + ax = cat.boxenplot(x="g", y="y", data=self.df, showfliers=False) + assert len(list(filter(self.ispath, ax.collections))) == 0 + + plt.close("all") + + def test_invalid_depths(self): + + kws = self.default_kws.copy() + + # Make sure illegal depth raises + kws["k_depth"] = "nosuchdepth" + with pytest.raises(ValueError): + cat._LVPlotter(**kws) + + # Make sure illegal outlier_prop raises + kws["k_depth"] = "proportion" + for p in (-13, 
37): + kws["outlier_prop"] = p + with pytest.raises(ValueError): + cat._LVPlotter(**kws) + + kws["k_depth"] = "trustworthy" + for alpha in (-13, 37): + kws["trust_alpha"] = alpha + with pytest.raises(ValueError): + cat._LVPlotter(**kws) + + @pytest.mark.parametrize("power", [1, 3, 7, 11, 13, 17]) + def test_valid_depths(self, power): + + x = np.random.standard_t(10, 2 ** power) + + valid_depths = ["proportion", "tukey", "trustworthy", "full"] + kws = self.default_kws.copy() + + for depth in valid_depths + [4]: + kws["k_depth"] = depth + box_ends, k = cat._LVPlotter(**kws)._lv_box_ends(x) + + if depth == "full": + assert k == int(np.log2(len(x))) + 1 + + def test_valid_scales(self): + + valid_scales = ["linear", "exponential", "area"] + kws = self.default_kws.copy() + + for scale in valid_scales + ["unknown_scale"]: + kws["scale"] = scale + if scale not in valid_scales: + with pytest.raises(ValueError): + cat._LVPlotter(**kws) + else: + cat._LVPlotter(**kws) def test_hue_offsets(self): p = cat._LVPlotter(**self.default_kws) - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) npt.assert_array_equal(p.hue_offsets, [-.2, .2]) kws = self.default_kws.copy() kws["width"] = .6 p = cat._LVPlotter(**kws) - p.establish_variables("g", "y", "h", data=self.df) + p.establish_variables("g", "y", hue="h", data=self.df) npt.assert_array_equal(p.hue_offsets, [-.15, .15]) p = cat._LVPlotter(**kws) @@ -2563,43 +3114,60 @@ def test_hue_offsets(self): def test_axes_data(self): - ax = cat.boxenplot("g", "y", data=self.df) + ax = cat.boxenplot(x="g", y="y", data=self.df) patches = filter(self.ispatch, ax.collections) - nt.assert_equal(len(list(patches)), 3) + assert len(list(patches)) == 3 plt.close("all") - ax = cat.boxenplot("g", "y", "h", data=self.df) + ax = cat.boxenplot(x="g", y="y", hue="h", data=self.df) patches = filter(self.ispatch, ax.collections) - nt.assert_equal(len(list(patches)), 6) + assert len(list(patches)) == 6 plt.close("all") def test_box_colors(self): - ax = cat.boxenplot("g", "y", data=self.df, saturation=1) + ax = cat.boxenplot(x="g", y="y", data=self.df, saturation=1) pal = palettes.color_palette(n_colors=3) for patch, color in zip(ax.artists, pal): - nt.assert_equal(patch.get_facecolor()[:3], color) + assert patch.get_facecolor()[:3] == color plt.close("all") - ax = cat.boxenplot("g", "y", "h", data=self.df, saturation=1) + ax = cat.boxenplot(x="g", y="y", hue="h", data=self.df, saturation=1) pal = palettes.color_palette(n_colors=2) for patch, color in zip(ax.artists, pal * 2): - nt.assert_equal(patch.get_facecolor()[:3], color) + assert patch.get_facecolor()[:3] == color plt.close("all") def test_draw_missing_boxes(self): - ax = cat.boxenplot("g", "y", data=self.df, + ax = cat.boxenplot(x="g", y="y", data=self.df, order=["a", "b", "c", "d"]) patches = filter(self.ispatch, ax.collections) - nt.assert_equal(len(list(patches)), 3) + assert len(list(patches)) == 3 plt.close("all") + def test_unaligned_index(self): + + f, (ax1, ax2) = plt.subplots(2) + cat.boxenplot(x=self.g, y=self.y, ax=ax1) + cat.boxenplot(x=self.g, y=self.y_perm, ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert np.array_equal(l1.get_xydata(), l2.get_xydata()) + + f, (ax1, ax2) = plt.subplots(2) + hue_order = self.h.unique() + cat.boxenplot(x=self.g, y=self.y, hue=self.h, + hue_order=hue_order, ax=ax1) + cat.boxenplot(x=self.g, y=self.y_perm, hue=self.h, + hue_order=hue_order, ax=ax2) + for l1, l2 in zip(ax1.lines, ax2.lines): + assert 
np.array_equal(l1.get_xydata(), l2.get_xydata()) + def test_missing_data(self): x = ["a", "a", "b", "b", "c", "c", "d", "d"] @@ -2607,14 +3175,14 @@ def test_missing_data(self): y = self.rs.randn(8) y[-2:] = np.nan - ax = cat.boxenplot(x, y) - nt.assert_equal(len(ax.lines), 3) + ax = cat.boxenplot(x=x, y=y) + assert len(ax.lines) == 3 plt.close("all") y[-1] = 0 - ax = cat.boxenplot(x, y, h) - nt.assert_equal(len(ax.lines), 7) + ax = cat.boxenplot(x=x, y=y, hue=h) + assert len(ax.lines) == 7 plt.close("all") @@ -2622,51 +3190,63 @@ def test_boxenplots(self): # Smoke test the high level boxenplot options - cat.boxenplot("y", data=self.df) + cat.boxenplot(x="y", data=self.df) plt.close("all") cat.boxenplot(y="y", data=self.df) plt.close("all") - cat.boxenplot("g", "y", data=self.df) + cat.boxenplot(x="g", y="y", data=self.df) plt.close("all") - cat.boxenplot("y", "g", data=self.df, orient="h") + cat.boxenplot(x="y", y="g", data=self.df, orient="h") plt.close("all") - cat.boxenplot("g", "y", "h", data=self.df) + cat.boxenplot(x="g", y="y", hue="h", data=self.df) plt.close("all") - cat.boxenplot("g", "y", "h", order=list("nabc"), data=self.df) + for scale in ("linear", "area", "exponential"): + cat.boxenplot(x="g", y="y", hue="h", scale=scale, data=self.df) + plt.close("all") + + for depth in ("proportion", "tukey", "trustworthy"): + cat.boxenplot(x="g", y="y", hue="h", k_depth=depth, data=self.df) + plt.close("all") + + order = list("nabc") + cat.boxenplot(x="g", y="y", hue="h", order=order, data=self.df) plt.close("all") - cat.boxenplot("g", "y", "h", hue_order=list("omn"), data=self.df) + order = list("omn") + cat.boxenplot(x="g", y="y", hue="h", hue_order=order, data=self.df) plt.close("all") - cat.boxenplot("y", "g", "h", data=self.df, orient="h") + cat.boxenplot(x="y", y="g", hue="h", data=self.df, orient="h") plt.close("all") - cat.boxenplot("y", "g", "h", data=self.df, orient="h", palette="Set2") + cat.boxenplot(x="y", y="g", hue="h", data=self.df, orient="h", + palette="Set2") plt.close("all") - cat.boxenplot("y", "g", "h", data=self.df, orient="h", color="b") + cat.boxenplot(x="y", y="g", hue="h", data=self.df, + orient="h", color="b") plt.close("all") def test_axes_annotation(self): - ax = cat.boxenplot("g", "y", data=self.df) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") - nt.assert_equal(ax.get_xlim(), (-.5, 2.5)) + ax = cat.boxenplot(x="g", y="y", data=self.df) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" + assert ax.get_xlim() == (-.5, 2.5) npt.assert_array_equal(ax.get_xticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_xticklabels()], ["a", "b", "c"]) plt.close("all") - ax = cat.boxenplot("g", "y", "h", data=self.df) - nt.assert_equal(ax.get_xlabel(), "g") - nt.assert_equal(ax.get_ylabel(), "y") + ax = cat.boxenplot(x="g", y="y", hue="h", data=self.df) + assert ax.get_xlabel() == "g" + assert ax.get_ylabel() == "y" npt.assert_array_equal(ax.get_xticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_xticklabels()], ["a", "b", "c"]) @@ -2675,27 +3255,101 @@ def test_axes_annotation(self): plt.close("all") - with plt.rc_context(rc={"axes.labelsize": "large"}): - ax = cat.boxenplot("g", "y", "h", data=self.df) - - plt.close("all") - - ax = cat.boxenplot("y", "g", data=self.df, orient="h") - nt.assert_equal(ax.get_xlabel(), "y") - nt.assert_equal(ax.get_ylabel(), "g") - nt.assert_equal(ax.get_ylim(), (2.5, -.5)) + ax = cat.boxenplot(x="y", y="g", data=self.df, orient="h") + assert 
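# The pattern behind most hunks in this file: data variables are now passed
# by keyword, since newer seaborn releases deprecate the positional form
# (enforced by the `_deprecate_positional_args` decorator tested later in
# this diff). A quick before/after, using the bundled "tips" dataset:
import seaborn as sns

tips = sns.load_dataset("tips")

# old, deprecated positional style:
#   sns.boxenplot(tips["day"], tips["total_bill"])
# updated keyword style used throughout the tests:
sns.boxenplot(x="day", y="total_bill", hue="smoker", data=tips)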
ax.get_xlabel() == "y" + assert ax.get_ylabel() == "g" + assert ax.get_ylim() == (2.5, -.5) npt.assert_array_equal(ax.get_yticks(), [0, 1, 2]) npt.assert_array_equal([l.get_text() for l in ax.get_yticklabels()], ["a", "b", "c"]) plt.close("all") - def test_lvplot(self): + @pytest.mark.parametrize("size", ["large", "medium", "small", 22, 12]) + def test_legend_titlesize(self, size): - with pytest.warns(UserWarning): - ax = cat.lvplot("g", "y", data=self.df) + rc_ctx = {"legend.title_fontsize": size} + exp = mpl.font_manager.FontProperties(size=size).get_size() - patches = filter(self.ispatch, ax.collections) - nt.assert_equal(len(list(patches)), 3) + with plt.rc_context(rc=rc_ctx): + ax = cat.boxenplot(x="g", y="y", hue="h", data=self.df) + obs = ax.get_legend().get_title().get_fontproperties().get_size() + assert obs == exp + + plt.close("all") + + @pytest.mark.skipif( + LooseVersion(pd.__version__) < "1.2", + reason="Test requires pandas>=1.2") + def test_Float64_input(self): + data = pd.DataFrame( + {"x": np.random.choice(["a", "b"], 20), "y": np.random.random(20)} + ) + data['y'] = data['y'].astype(pd.Float64Dtype()) + _ = cat.boxenplot(x="x", y="y", data=data) plt.close("all") + + +class TestBeeswarm: + + def test_could_overlap(self): + + p = Beeswarm() + neighbors = p.could_overlap( + (1, 1, .5), + [(0, 0, .5), + (1, .1, .2), + (.5, .5, .5)] + ) + assert_array_equal(neighbors, [(.5, .5, .5)]) + + def test_position_candidates(self): + + p = Beeswarm() + xy_i = (0, 1, .5) + neighbors = [(0, 1, .5), (0, 1.5, .5)] + candidates = p.position_candidates(xy_i, neighbors) + dx1 = 1.05 + dx2 = np.sqrt(1 - .5 ** 2) * 1.05 + assert_array_equal( + candidates, + [(0, 1, .5), (-dx1, 1, .5), (dx1, 1, .5), (dx2, 1, .5), (-dx2, 1, .5)] + ) + + def test_find_first_non_overlapping_candidate(self): + + p = Beeswarm() + candidates = [(.5, 1, .5), (1, 1, .5), (1.5, 1, .5)] + neighbors = np.array([(0, 1, .5)]) + + first = p.first_non_overlapping_candidate(candidates, neighbors) + assert_array_equal(first, (1, 1, .5)) + + def test_beeswarm(self, long_df): + + p = Beeswarm() + data = long_df["y"] + d = data.diff().mean() * 1.5 + x = np.zeros(data.size) + y = np.sort(data) + r = np.full_like(y, d) + orig_xyr = np.c_[x, y, r] + swarm = p.beeswarm(orig_xyr)[:, :2] + dmat = np.sqrt(np.sum(np.square(swarm[:, np.newaxis] - swarm), axis=-1)) + triu = dmat[np.triu_indices_from(dmat, 1)] + assert_array_less(d, triu) + assert_array_equal(y, swarm[:, 1]) + + def test_add_gutters(self): + + p = Beeswarm(width=1) + + points = np.zeros(10) + assert_array_equal(points, p.add_gutters(points, 0)) + + points = np.array([0, -1, .4, .8]) + msg = r"50.0% of the points cannot be placed.+$" + with pytest.warns(UserWarning, match=msg): + new_points = p.add_gutters(points, 0) + assert_array_equal(new_points, np.array([0, -.5, .4, .5])) diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py new file mode 100644 index 0000000000..7f8173d210 --- /dev/null +++ b/seaborn/tests/test_core.py @@ -0,0 +1,1544 @@ +import itertools +import numpy as np +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +import pytest +from numpy.testing import assert_array_equal +from pandas.testing import assert_frame_equal + +from ..axisgrid import FacetGrid +from .._core import ( + SemanticMapping, + HueMapping, + SizeMapping, + StyleMapping, + VectorPlotter, + variable_type, + infer_orient, + unique_dashes, + unique_markers, + categorical_order, +) + +from ..palettes import color_palette + + +try: + from pandas 
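# The circle geometry behind `position_candidates`: to sit tangent to a
# neighbor whose center is dy away vertically, a point of radius r_i must
# shift sqrt((r_i + r_j)**2 - dy**2) horizontally; the tests include a 5%
# padding factor. `candidate_dx` is an illustrative standalone helper.
import numpy as np

def candidate_dx(r_i, r_j, dy, pad=1.05):
    return np.sqrt((r_i + r_j) ** 2 - dy ** 2) * pad

assert np.isclose(candidate_dx(.5, .5, dy=0), 1.05)                          # dx1
assert np.isclose(candidate_dx(.5, .5, dy=.5), np.sqrt(1 - .5 ** 2) * 1.05)  # dx2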
import NA as PD_NA +except ImportError: + PD_NA = None + + +@pytest.fixture(params=[ + dict(x="x", y="y"), + dict(x="t", y="y"), + dict(x="a", y="y"), + dict(x="x", y="y", hue="y"), + dict(x="x", y="y", hue="a"), + dict(x="x", y="y", size="a"), + dict(x="x", y="y", style="a"), + dict(x="x", y="y", hue="s"), + dict(x="x", y="y", size="s"), + dict(x="x", y="y", style="s"), + dict(x="x", y="y", hue="a", style="a"), + dict(x="x", y="y", hue="a", size="b", style="b"), +]) +def long_variables(request): + return request.param + + +class TestSemanticMapping: + + def test_call_lookup(self): + + m = SemanticMapping(VectorPlotter()) + lookup_table = dict(zip("abc", (1, 2, 3))) + m.lookup_table = lookup_table + for key, val in lookup_table.items(): + assert m(key) == val + + +class TestHueMapping: + + def test_init_from_map(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a") + ) + palette = "Set2" + p = HueMapping.map(p_orig, palette=palette) + assert p is p_orig + assert isinstance(p._hue_map, HueMapping) + assert p._hue_map.palette == palette + + def test_plotter_default_init(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + assert isinstance(p._hue_map, HueMapping) + assert p._hue_map.map_type is None + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) + assert isinstance(p._hue_map, HueMapping) + assert p._hue_map.map_type == p.var_types["hue"] + + def test_plotter_reinit(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) + palette = "muted" + hue_order = ["b", "a", "c"] + p = p_orig.map_hue(palette=palette, order=hue_order) + assert p is p_orig + assert p._hue_map.palette == palette + assert p._hue_map.levels == hue_order + + def test_hue_map_null(self, flat_series, null_series): + + p = VectorPlotter(variables=dict(x=flat_series, hue=null_series)) + m = HueMapping(p) + assert m.levels is None + assert m.map_type is None + assert m.palette is None + assert m.cmap is None + assert m.norm is None + assert m.lookup_table is None + + def test_hue_map_categorical(self, wide_df, long_df): + + p = VectorPlotter(data=wide_df) + m = HueMapping(p) + assert m.levels == wide_df.columns.to_list() + assert m.map_type == "categorical" + assert m.cmap is None + + # Test named palette + palette = "Blues" + expected_colors = color_palette(palette, wide_df.shape[1]) + expected_lookup_table = dict(zip(wide_df.columns, expected_colors)) + m = HueMapping(p, palette=palette) + assert m.palette == "Blues" + assert m.lookup_table == expected_lookup_table + + # Test list palette + palette = color_palette("Reds", wide_df.shape[1]) + expected_lookup_table = dict(zip(wide_df.columns, palette)) + m = HueMapping(p, palette=palette) + assert m.palette == palette + assert m.lookup_table == expected_lookup_table + + # Test dict palette + colors = color_palette("Set1", 8) + palette = dict(zip(wide_df.columns, colors)) + m = HueMapping(p, palette=palette) + assert m.palette == palette + assert m.lookup_table == palette + + # Test dict with missing keys + palette = dict(zip(wide_df.columns[:-1], colors)) + with pytest.raises(ValueError): + HueMapping(p, palette=palette) + + # Test dict with missing keys + palette = dict(zip(wide_df.columns[:-1], colors)) + with pytest.raises(ValueError): + HueMapping(p, palette=palette) + + # Test list with wrong number of colors + palette = colors[:-1] + with pytest.raises(ValueError): + HueMapping(p, palette=palette) + + # 
Test hue order + hue_order = ["a", "c", "d"] + m = HueMapping(p, order=hue_order) + assert m.levels == hue_order + + # Test long data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="a")) + m = HueMapping(p) + assert m.levels == categorical_order(long_df["a"]) + assert m.map_type == "categorical" + assert m.cmap is None + + # Test default palette + m = HueMapping(p) + hue_levels = categorical_order(long_df["a"]) + expected_colors = color_palette(n_colors=len(hue_levels)) + expected_lookup_table = dict(zip(hue_levels, expected_colors)) + assert m.lookup_table == expected_lookup_table + + # Test missing data + m = HueMapping(p) + assert m(np.nan) == (0, 0, 0, 0) + + # Test default palette with many levels + x = y = np.arange(26) + hue = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + p = VectorPlotter(variables=dict(x=x, y=y, hue=hue)) + m = HueMapping(p) + expected_colors = color_palette("husl", n_colors=len(hue)) + expected_lookup_table = dict(zip(hue, expected_colors)) + assert m.lookup_table == expected_lookup_table + + # Test binary data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="c")) + m = HueMapping(p) + assert m.levels == [0, 1] + assert m.map_type == "categorical" + + for val in [0, 1]: + p = VectorPlotter( + data=long_df[long_df["c"] == val], + variables=dict(x="x", y="y", hue="c"), + ) + m = HueMapping(p) + assert m.levels == [val] + assert m.map_type == "categorical" + + # Test Timestamp data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="t")) + m = HueMapping(p) + assert m.levels == [pd.Timestamp(t) for t in long_df["t"].unique()] + assert m.map_type == "datetime" + + # Test excplicit categories + p = VectorPlotter(data=long_df, variables=dict(x="x", hue="a_cat")) + m = HueMapping(p) + assert m.levels == long_df["a_cat"].cat.categories.to_list() + assert m.map_type == "categorical" + + # Test numeric data with category type + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s_cat") + ) + m = HueMapping(p) + assert m.levels == categorical_order(long_df["s_cat"]) + assert m.map_type == "categorical" + assert m.cmap is None + + # Test categorical palette specified for numeric data + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s") + ) + palette = "deep" + levels = categorical_order(long_df["s"]) + expected_colors = color_palette(palette, n_colors=len(levels)) + expected_lookup_table = dict(zip(levels, expected_colors)) + m = HueMapping(p, palette=palette) + assert m.lookup_table == expected_lookup_table + assert m.map_type == "categorical" + + def test_hue_map_numeric(self, long_df): + + # Test default colormap + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s") + ) + hue_levels = list(np.sort(long_df["s"].unique())) + m = HueMapping(p) + assert m.levels == hue_levels + assert m.map_type == "numeric" + assert m.cmap.name == "seaborn_cubehelix" + + # Test named colormap + palette = "Purples" + m = HueMapping(p, palette=palette) + assert m.cmap is mpl.cm.get_cmap(palette) + + # Test colormap object + palette = mpl.cm.get_cmap("Greens") + m = HueMapping(p, palette=palette) + assert m.cmap is mpl.cm.get_cmap(palette) + + # Test cubehelix shorthand + palette = "ch:2,0,light=.2" + m = HueMapping(p, palette=palette) + assert isinstance(m.cmap, mpl.colors.ListedColormap) + + # Test specified hue limits + hue_norm = 1, 4 + m = HueMapping(p, norm=hue_norm) + assert isinstance(m.norm, mpl.colors.Normalize) + assert m.norm.vmin == hue_norm[0] + assert 
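# The categorical-hue convention the tests above encode: levels are paired
# one-to-one with palette colors to build the lookup table, and a length
# mismatch (or a dict with missing keys) raises ValueError.
from seaborn import color_palette

levels = ["a", "b", "c"]
colors = color_palette("deep", n_colors=len(levels))
lookup_table = dict(zip(levels, colors))
assert lookup_table["b"] == colors[1]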
m.norm.vmax == hue_norm[1] + + # Test Normalize object + hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10) + m = HueMapping(p, norm=hue_norm) + assert m.norm is hue_norm + + # Test default colormap values + hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max() + m = HueMapping(p) + assert m.lookup_table[hmin] == pytest.approx(m.cmap(0.0)) + assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0)) + + # Test specified colormap values + hue_norm = hmin - 1, hmax - 1 + m = HueMapping(p, norm=hue_norm) + norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0]) + assert m.lookup_table[hmin] == pytest.approx(m.cmap(norm_min)) + assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0)) + + # Test list of colors + hue_levels = list(np.sort(long_df["s"].unique())) + palette = color_palette("Blues", len(hue_levels)) + m = HueMapping(p, palette=palette) + assert m.lookup_table == dict(zip(hue_levels, palette)) + + palette = color_palette("Blues", len(hue_levels) + 1) + with pytest.raises(ValueError): + HueMapping(p, palette=palette) + + # Test dictionary of colors + palette = dict(zip(hue_levels, color_palette("Reds"))) + m = HueMapping(p, palette=palette) + assert m.lookup_table == palette + + palette.pop(hue_levels[0]) + with pytest.raises(ValueError): + HueMapping(p, palette=palette) + + # Test invalid palette + with pytest.raises(ValueError): + HueMapping(p, palette="not a valid palette") + + # Test bad norm argument + with pytest.raises(ValueError): + HueMapping(p, norm="not a norm") + + +class TestSizeMapping: + + def test_init_from_map(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size="a") + ) + sizes = 1, 6 + p = SizeMapping.map(p_orig, sizes=sizes) + assert p is p_orig + assert isinstance(p._size_map, SizeMapping) + assert min(p._size_map.lookup_table.values()) == sizes[0] + assert max(p._size_map.lookup_table.values()) == sizes[1] + + def test_plotter_default_init(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + assert isinstance(p._size_map, SizeMapping) + assert p._size_map.map_type is None + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size="a"), + ) + assert isinstance(p._size_map, SizeMapping) + assert p._size_map.map_type == p.var_types["size"] + + def test_plotter_reinit(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size="a"), + ) + sizes = [1, 4, 2] + size_order = ["b", "a", "c"] + p = p_orig.map_size(sizes=sizes, order=size_order) + assert p is p_orig + assert p._size_map.lookup_table == dict(zip(size_order, sizes)) + assert p._size_map.levels == size_order + + def test_size_map_null(self, flat_series, null_series): + + p = VectorPlotter(variables=dict(x=flat_series, size=null_series)) + m = HueMapping(p) + assert m.levels is None + assert m.map_type is None + assert m.norm is None + assert m.lookup_table is None + + def test_map_size_numeric(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size="s"), + ) + + # Test default range of keys in the lookup table values + m = SizeMapping(p) + size_values = m.lookup_table.values() + value_range = min(size_values), max(size_values) + assert value_range == p._default_size_range + + # Test specified range of size values + sizes = 1, 5 + m = SizeMapping(p, sizes=sizes) + size_values = m.lookup_table.values() + assert min(size_values), max(size_values) == sizes + + # Test size values with normalization range + norm = 1, 10 + m = 
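# What the numeric-hue checks reduce to: normalize each level into [0, 1]
# and index the colormap, so the extremes map to cmap(0.0) and cmap(1.0).
import matplotlib as mpl

levels = [2.0, 4.0, 8.0]
norm = mpl.colors.Normalize(vmin=min(levels), vmax=max(levels))
cmap = mpl.cm.get_cmap("Purples")

lookup_table = {lvl: cmap(norm(lvl)) for lvl in levels}
assert lookup_table[2.0] == cmap(0.0)
assert lookup_table[8.0] == cmap(1.0)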
SizeMapping(p, sizes=sizes, norm=norm) + normalize = mpl.colors.Normalize(*norm, clip=True) + for key, val in m.lookup_table.items(): + assert val == sizes[0] + (sizes[1] - sizes[0]) * normalize(key) + + # Test size values with normalization object + norm = mpl.colors.LogNorm(1, 10, clip=False) + m = SizeMapping(p, sizes=sizes, norm=norm) + assert m.norm.clip + for key, val in m.lookup_table.items(): + assert val == sizes[0] + (sizes[1] - sizes[0]) * norm(key) + + # Test bad sizes argument + with pytest.raises(ValueError): + SizeMapping(p, sizes="bad_sizes") + + # Test bad sizes argument + with pytest.raises(ValueError): + SizeMapping(p, sizes=(1, 2, 3)) + + # Test bad norm argument + with pytest.raises(ValueError): + SizeMapping(p, norm="bad_norm") + + def test_map_size_categorical(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size="a"), + ) + + # Test specified size order + levels = p.plot_data["size"].unique() + sizes = [1, 4, 6] + order = [levels[1], levels[2], levels[0]] + m = SizeMapping(p, sizes=sizes, order=order) + assert m.lookup_table == dict(zip(order, sizes)) + + # Test list of sizes + order = categorical_order(p.plot_data["size"]) + sizes = list(np.random.rand(len(levels))) + m = SizeMapping(p, sizes=sizes) + assert m.lookup_table == dict(zip(order, sizes)) + + # Test dict of sizes + sizes = dict(zip(levels, np.random.rand(len(levels)))) + m = SizeMapping(p, sizes=sizes) + assert m.lookup_table == sizes + + # Test specified size range + sizes = (2, 5) + m = SizeMapping(p, sizes=sizes) + values = np.linspace(*sizes, len(m.levels))[::-1] + assert m.lookup_table == dict(zip(m.levels, values)) + + # Test explicit categories + p = VectorPlotter(data=long_df, variables=dict(x="x", size="a_cat")) + m = SizeMapping(p) + assert m.levels == long_df["a_cat"].cat.categories.to_list() + assert m.map_type == "categorical" + + # Test sizes list with wrong length + sizes = list(np.random.rand(len(levels) + 1)) + with pytest.raises(ValueError): + SizeMapping(p, sizes=sizes) + + # Test sizes dict with missing levels + sizes = dict(zip(levels, np.random.rand(len(levels) - 1))) + with pytest.raises(ValueError): + SizeMapping(p, sizes=sizes) + + # Test bad sizes argument + with pytest.raises(ValueError): + SizeMapping(p, sizes="bad_size") + + +class TestStyleMapping: + + def test_init_from_map(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", style="a") + ) + markers = ["s", "p", "h"] + p = StyleMapping.map(p_orig, markers=markers) + assert p is p_orig + assert isinstance(p._style_map, StyleMapping) + assert p._style_map(p._style_map.levels, "marker") == markers + + def test_plotter_default_init(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + assert isinstance(p._style_map, StyleMapping) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", style="a"), + ) + assert isinstance(p._style_map, StyleMapping) + + def test_plotter_reinit(self, long_df): + + p_orig = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", style="a"), + ) + markers = ["s", "p", "h"] + style_order = ["b", "a", "c"] + p = p_orig.map_style(markers=markers, order=style_order) + assert p is p_orig + assert p._style_map.levels == style_order + assert p._style_map(style_order, "marker") == markers + + def test_style_map_null(self, flat_series, null_series): + + p = VectorPlotter(variables=dict(x=flat_series, style=null_series)) + m = HueMapping(p) + assert m.levels is None + assert 
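# The numeric size interpolation verified above, stated on its own:
# normalize the data value, then linearly map it into the size range.
import matplotlib as mpl

sizes = (1, 5)
norm = mpl.colors.Normalize(1, 10, clip=True)

def size_for(value):
    return sizes[0] + (sizes[1] - sizes[0]) * norm(value)

assert size_for(1) == 1   # bottom of the norm range -> smallest size
assert size_for(10) == 5  # top of the norm range -> largest size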
m.map_type is None + assert m.lookup_table is None + + def test_map_style(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", style="a"), + ) + + # Test defaults + m = StyleMapping(p, markers=True, dashes=True) + + n = len(m.levels) + for key, dashes in zip(m.levels, unique_dashes(n)): + assert m(key, "dashes") == dashes + + actual_marker_paths = { + k: mpl.markers.MarkerStyle(m(k, "marker")).get_path() + for k in m.levels + } + expected_marker_paths = { + k: mpl.markers.MarkerStyle(m).get_path() + for k, m in zip(m.levels, unique_markers(n)) + } + assert actual_marker_paths == expected_marker_paths + + # Test lists + markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)] + m = StyleMapping(p, markers=markers, dashes=dashes) + for key, mark, dash in zip(m.levels, markers, dashes): + assert m(key, "marker") == mark + assert m(key, "dashes") == dash + + # Test dicts + markers = dict(zip(p.plot_data["style"].unique(), markers)) + dashes = dict(zip(p.plot_data["style"].unique(), dashes)) + m = StyleMapping(p, markers=markers, dashes=dashes) + for key in m.levels: + assert m(key, "marker") == markers[key] + assert m(key, "dashes") == dashes[key] + + # Test excplicit categories + p = VectorPlotter(data=long_df, variables=dict(x="x", style="a_cat")) + m = StyleMapping(p) + assert m.levels == long_df["a_cat"].cat.categories.to_list() + + # Test style order with defaults + order = p.plot_data["style"].unique()[[1, 2, 0]] + m = StyleMapping(p, markers=True, dashes=True, order=order) + n = len(order) + for key, mark, dash in zip(order, unique_markers(n), unique_dashes(n)): + assert m(key, "dashes") == dash + assert m(key, "marker") == mark + obj = mpl.markers.MarkerStyle(mark) + path = obj.get_path().transformed(obj.get_transform()) + assert_array_equal(m(key, "path").vertices, path.vertices) + + # Test too many levels with style lists + with pytest.raises(ValueError): + StyleMapping(p, markers=["o", "s"], dashes=False) + + with pytest.raises(ValueError): + StyleMapping(p, markers=False, dashes=[(2, 1)]) + + # Test too many levels with style dicts + markers, dashes = {"a": "o", "b": "s"}, False + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + + markers, dashes = False, {"a": (1, 0), "b": (2, 1)} + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + + # Test mixture of filled and unfilled markers + markers, dashes = ["o", "x", "s"], None + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + + +class TestVectorPlotter: + + def test_flat_variables(self, flat_data): + + p = VectorPlotter() + p.assign_variables(data=flat_data) + assert p.input_format == "wide" + assert list(p.variables) == ["x", "y"] + assert len(p.plot_data) == len(flat_data) + + try: + expected_x = flat_data.index + expected_x_name = flat_data.index.name + except AttributeError: + expected_x = np.arange(len(flat_data)) + expected_x_name = None + + x = p.plot_data["x"] + assert_array_equal(x, expected_x) + + expected_y = flat_data + expected_y_name = getattr(flat_data, "name", None) + + y = p.plot_data["y"] + assert_array_equal(y, expected_y) + + assert p.variables["x"] == expected_x_name + assert p.variables["y"] == expected_y_name + + def test_long_df(self, long_df, long_variables): + + p = VectorPlotter() + p.assign_variables(data=long_df, variables=long_variables) + assert p.input_format == "long" + assert p.variables == long_variables + + for key, val in long_variables.items(): + 
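# How the dash specs checked above are consumed downstream: an even-length
# tuple of on/off lengths (in points) goes to Line2D's `dashes` property,
# and (1, 0) renders as a solid line. The tuples below come from the tests.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for i, dashes in enumerate([(1, 0), (1, 1), (2, 1, 3, 1)]):
    ax.plot([0, 1], [i, i], dashes=dashes)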
assert_array_equal(p.plot_data[key], long_df[val]) + + def test_long_df_with_index(self, long_df, long_variables): + + p = VectorPlotter() + p.assign_variables( + data=long_df.set_index("a"), + variables=long_variables, + ) + assert p.input_format == "long" + assert p.variables == long_variables + + for key, val in long_variables.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + def test_long_df_with_multiindex(self, long_df, long_variables): + + p = VectorPlotter() + p.assign_variables( + data=long_df.set_index(["a", "x"]), + variables=long_variables, + ) + assert p.input_format == "long" + assert p.variables == long_variables + + for key, val in long_variables.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + def test_long_dict(self, long_dict, long_variables): + + p = VectorPlotter() + p.assign_variables( + data=long_dict, + variables=long_variables, + ) + assert p.input_format == "long" + assert p.variables == long_variables + + for key, val in long_variables.items(): + assert_array_equal(p.plot_data[key], pd.Series(long_dict[val])) + + @pytest.mark.parametrize( + "vector_type", + ["series", "numpy", "list"], + ) + def test_long_vectors(self, long_df, long_variables, vector_type): + + variables = {key: long_df[val] for key, val in long_variables.items()} + if vector_type == "numpy": + variables = {key: val.to_numpy() for key, val in variables.items()} + elif vector_type == "list": + variables = {key: val.to_list() for key, val in variables.items()} + + p = VectorPlotter() + p.assign_variables(variables=variables) + assert p.input_format == "long" + + assert list(p.variables) == list(long_variables) + if vector_type == "series": + assert p.variables == long_variables + + for key, val in long_variables.items(): + assert_array_equal(p.plot_data[key], long_df[val]) + + def test_long_undefined_variables(self, long_df): + + p = VectorPlotter() + + with pytest.raises(ValueError): + p.assign_variables( + data=long_df, variables=dict(x="not_in_df"), + ) + + with pytest.raises(ValueError): + p.assign_variables( + data=long_df, variables=dict(x="x", y="not_in_df"), + ) + + with pytest.raises(ValueError): + p.assign_variables( + data=long_df, variables=dict(x="x", y="y", hue="not_in_df"), + ) + + @pytest.mark.parametrize( + "arg", [[], np.array([]), pd.DataFrame()], + ) + def test_empty_data_input(self, arg): + + p = VectorPlotter() + p.assign_variables(data=arg) + assert not p.variables + + if not isinstance(arg, pd.DataFrame): + p = VectorPlotter() + p.assign_variables(variables=dict(x=arg, y=arg)) + assert not p.variables + + def test_units(self, repeated_df): + + p = VectorPlotter() + p.assign_variables( + data=repeated_df, + variables=dict(x="x", y="y", units="u"), + ) + assert_array_equal(p.plot_data["units"], repeated_df["u"]) + + @pytest.mark.parametrize("name", [3, 4.5]) + def test_long_numeric_name(self, long_df, name): + + long_df[name] = long_df["x"] + p = VectorPlotter() + p.assign_variables(data=long_df, variables={"x": name}) + assert_array_equal(p.plot_data["x"], long_df[name]) + assert p.variables["x"] == name + + def test_long_hierarchical_index(self, rng): + + cols = pd.MultiIndex.from_product([["a"], ["x", "y"]]) + data = rng.uniform(size=(50, 2)) + df = pd.DataFrame(data, columns=cols) + + name = ("a", "y") + var = "y" + + p = VectorPlotter() + p.assign_variables(data=df, variables={var: name}) + assert_array_equal(p.plot_data[var], df[name]) + assert p.variables[var] == name + + def test_long_scalar_and_data(self, long_df): + + val = 22 + p = 
VectorPlotter(data=long_df, variables={"x": "x", "y": val}) + assert (p.plot_data["y"] == val).all() + assert p.variables["y"] is None + + def test_wide_semantic_error(self, wide_df): + + err = "The following variable cannot be assigned with wide-form data: `hue`" + with pytest.raises(ValueError, match=err): + VectorPlotter(data=wide_df, variables={"hue": "a"}) + + def test_long_unknown_error(self, long_df): + + err = "Could not interpret value `what` for parameter `hue`" + with pytest.raises(ValueError, match=err): + VectorPlotter(data=long_df, variables={"x": "x", "hue": "what"}) + + def test_long_unmatched_size_error(self, long_df, flat_array): + + err = "Length of ndarray vectors must match length of `data`" + with pytest.raises(ValueError, match=err): + VectorPlotter(data=long_df, variables={"x": "x", "hue": flat_array}) + + def test_wide_categorical_columns(self, wide_df): + + wide_df.columns = pd.CategoricalIndex(wide_df.columns) + p = VectorPlotter(data=wide_df) + assert_array_equal(p.plot_data["hue"].unique(), ["a", "b", "c"]) + + def test_iter_data_quantitites(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + out = p.iter_data("hue") + assert len(list(out)) == 1 + + var = "a" + n_subsets = len(long_df[var].unique()) + + semantics = ["hue", "size", "style"] + for semantic in semantics: + + p = VectorPlotter( + data=long_df, + variables={"x": "x", "y": "y", semantic: var}, + ) + out = p.iter_data(semantics) + assert len(list(out)) == n_subsets + + var = "a" + n_subsets = len(long_df[var].unique()) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var, style=var), + ) + out = p.iter_data(semantics) + assert len(list(out)) == n_subsets + + # -- + + out = p.iter_data(semantics, reverse=True) + assert len(list(out)) == n_subsets + + # -- + + var1, var2 = "a", "s" + + n_subsets = len(long_df[var1].unique()) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, style=var2), + ) + out = p.iter_data(["hue"]) + assert len(list(out)) == n_subsets + + n_subsets = len(set(list(map(tuple, long_df[[var1, var2]].values)))) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, style=var2), + ) + out = p.iter_data(semantics) + assert len(list(out)) == n_subsets + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2, style=var1), + ) + out = p.iter_data(semantics) + assert len(list(out)) == n_subsets + + # -- + + var1, var2, var3 = "a", "s", "b" + cols = [var1, var2, var3] + n_subsets = len(set(list(map(tuple, long_df[cols].values)))) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2, style=var3), + ) + out = p.iter_data(semantics) + assert len(list(out)) == n_subsets + + def test_iter_data_keys(self, long_df): + + semantics = ["hue", "size", "style"] + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + for sub_vars, _ in p.iter_data("hue"): + assert sub_vars == {} + + # -- + + var = "a" + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var), + ) + for sub_vars, _ in p.iter_data("hue"): + assert list(sub_vars) == ["hue"] + assert sub_vars["hue"] in long_df[var].values + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", size=var), + ) + for sub_vars, _ in p.iter_data("size"): + assert list(sub_vars) == ["size"] + assert sub_vars["size"] in long_df[var].values + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var, 
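# The subset count test_iter_data_quantitites relies on: iterating over
# semantics yields one group per *observed combination* of the grouping
# variables, not the full cartesian product of their levels.
import pandas as pd

df = pd.DataFrame({"a": ["x", "x", "y", "y"], "s": [1, 2, 1, 1]})
n_subsets = len(set(map(tuple, df[["a", "s"]].values)))
assert n_subsets == 3  # ("x", 1), ("x", 2), ("y", 1)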
style=var), + ) + for sub_vars, _ in p.iter_data(semantics): + assert list(sub_vars) == ["hue", "style"] + assert sub_vars["hue"] in long_df[var].values + assert sub_vars["style"] in long_df[var].values + assert sub_vars["hue"] == sub_vars["style"] + + var1, var2 = "a", "s" + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2), + ) + for sub_vars, _ in p.iter_data(semantics): + assert list(sub_vars) == ["hue", "size"] + assert sub_vars["hue"] in long_df[var1].values + assert sub_vars["size"] in long_df[var2].values + + semantics = ["hue", "col", "row"] + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, col=var2), + ) + for sub_vars, _ in p.iter_data("hue"): + assert list(sub_vars) == ["hue", "col"] + assert sub_vars["hue"] in long_df[var1].values + assert sub_vars["col"] in long_df[var2].values + + def test_iter_data_values(self, long_df): + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + + p.sort = True + _, sub_data = next(p.iter_data("hue")) + assert_frame_equal(sub_data, p.plot_data) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) + + for sub_vars, sub_data in p.iter_data("hue"): + rows = p.plot_data["hue"] == sub_vars["hue"] + assert_frame_equal(sub_data, p.plot_data[rows]) + + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", size="s"), + ) + for sub_vars, sub_data in p.iter_data(["hue", "size"]): + rows = p.plot_data["hue"] == sub_vars["hue"] + rows &= p.plot_data["size"] == sub_vars["size"] + assert_frame_equal(sub_data, p.plot_data[rows]) + + def test_iter_data_reverse(self, long_df): + + reversed_order = categorical_order(long_df["a"])[::-1] + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a") + ) + iterator = p.iter_data("hue", reverse=True) + for i, (sub_vars, _) in enumerate(iterator): + assert sub_vars["hue"] == reversed_order[i] + + def test_iter_data_dropna(self, missing_df): + + p = VectorPlotter( + data=missing_df, + variables=dict(x="x", y="y", hue="a") + ) + for _, sub_df in p.iter_data("hue"): + assert not sub_df.isna().any().any() + + some_missing = False + for _, sub_df in p.iter_data("hue", dropna=False): + some_missing |= sub_df.isna().any().any() + assert some_missing + + def test_axis_labels(self, long_df): + + f, ax = plt.subplots() + + p = VectorPlotter(data=long_df, variables=dict(x="a")) + + p._add_axis_labels(ax) + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "" + ax.clear() + + p = VectorPlotter(data=long_df, variables=dict(y="a")) + p._add_axis_labels(ax) + assert ax.get_xlabel() == "" + assert ax.get_ylabel() == "a" + ax.clear() + + p = VectorPlotter(data=long_df, variables=dict(x="a")) + + p._add_axis_labels(ax, default_y="default") + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "default" + ax.clear() + + p = VectorPlotter(data=long_df, variables=dict(y="a")) + p._add_axis_labels(ax, default_x="default", default_y="default") + assert ax.get_xlabel() == "default" + assert ax.get_ylabel() == "a" + ax.clear() + + p = VectorPlotter(data=long_df, variables=dict(x="x", y="a")) + ax.set(xlabel="existing", ylabel="also existing") + p._add_axis_labels(ax) + assert ax.get_xlabel() == "existing" + assert ax.get_ylabel() == "also existing" + + f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y")) + + p._add_axis_labels(ax1) + p._add_axis_labels(ax2) + + assert ax1.get_xlabel() == "x" + assert ax1.get_ylabel() 
== "y" + assert ax1.yaxis.label.get_visible() + + assert ax2.get_xlabel() == "x" + assert ax2.get_ylabel() == "y" + assert not ax2.yaxis.label.get_visible() + + @pytest.mark.parametrize( + "variables", + [ + dict(x="x", y="y"), + dict(x="x"), + dict(y="y"), + dict(x="t", y="y"), + dict(x="x", y="a"), + ] + ) + def test_attach_basics(self, long_df, variables): + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables=variables) + p._attach(ax) + assert p.ax is ax + + def test_attach_disallowed(self, long_df): + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "a"}) + + with pytest.raises(TypeError): + p._attach(ax, allowed_types="numeric") + + with pytest.raises(TypeError): + p._attach(ax, allowed_types=["datetime", "numeric"]) + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x"}) + + with pytest.raises(TypeError): + p._attach(ax, allowed_types="categorical") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"}) + + with pytest.raises(TypeError): + p._attach(ax, allowed_types=["numeric", "categorical"]) + + def test_attach_log_scale(self, long_df): + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x"}) + p._attach(ax, log_scale=True) + assert ax.xaxis.get_scale() == "log" + assert ax.yaxis.get_scale() == "linear" + assert p._log_scaled("x") + assert not p._log_scaled("y") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x"}) + p._attach(ax, log_scale=2) + assert ax.xaxis.get_scale() == "log" + assert ax.yaxis.get_scale() == "linear" + assert p._log_scaled("x") + assert not p._log_scaled("y") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"y": "y"}) + p._attach(ax, log_scale=True) + assert ax.xaxis.get_scale() == "linear" + assert ax.yaxis.get_scale() == "log" + assert not p._log_scaled("x") + assert p._log_scaled("y") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"}) + p._attach(ax, log_scale=True) + assert ax.xaxis.get_scale() == "log" + assert ax.yaxis.get_scale() == "log" + assert p._log_scaled("x") + assert p._log_scaled("y") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"}) + p._attach(ax, log_scale=(True, False)) + assert ax.xaxis.get_scale() == "log" + assert ax.yaxis.get_scale() == "linear" + assert p._log_scaled("x") + assert not p._log_scaled("y") + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"}) + p._attach(ax, log_scale=(False, 2)) + assert ax.xaxis.get_scale() == "linear" + assert ax.yaxis.get_scale() == "log" + assert not p._log_scaled("x") + assert p._log_scaled("y") + + def test_attach_converters(self, long_df): + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"}) + p._attach(ax) + assert ax.xaxis.converter is None + assert isinstance(ax.yaxis.converter, mpl.dates.DateConverter) + + _, ax = plt.subplots() + p = VectorPlotter(data=long_df, variables={"x": "a", "y": "y"}) + p._attach(ax) + assert isinstance(ax.xaxis.converter, mpl.category.StrCategoryConverter) + assert ax.yaxis.converter is None + + def test_attach_facets(self, long_df): + + g = FacetGrid(long_df, col="a") + p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"}) + p._attach(g) + assert p.ax is None + assert p.facets == g + + def test_attach_shared_axes(self, long_df): + + g = FacetGrid(long_df) + p = VectorPlotter(data=long_df, variables={"x": 
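# At the matplotlib level, `_attach(ax, log_scale=...)` amounts to setting
# the axis scale (a numeric value appears to select the log base, which on
# matplotlib>=3.3 is the `base` keyword); downstream, seaborn's
# computational data is the log10 of the raw values (see test_comp_data_log
# below).
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xscale("log")          # log_scale=True
ax.set_yscale("log", base=2)  # log_scale=2
assert ax.xaxis.get_scale() == "log"
assert ax.yaxis.get_scale() == "log"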
"x", "y": "y"}) + p._attach(g) + assert p.converters["x"].nunique() == 1 + + g = FacetGrid(long_df, col="a") + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", sharex=False) + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", sharex=False, col_wrap=2) + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "y", "col": "a"}) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharex=False) + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == len(g.axes.flat) + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharex="col") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == p.plot_data["col"].nunique() + assert p.converters["x"].groupby(p.plot_data["col"]).nunique().max() == 1 + assert p.converters["y"].nunique() == 1 + + g = FacetGrid(long_df, col="a", row="b", sharey="row") + p = VectorPlotter( + data=long_df, variables={"x": "x", "y": "y", "col": "a", "row": "b"}, + ) + p._attach(g) + assert p.converters["x"].nunique() == 1 + assert p.converters["y"].nunique() == p.plot_data["row"].nunique() + assert p.converters["y"].groupby(p.plot_data["row"]).nunique().max() == 1 + + def test_get_axes_single(self, long_df): + + ax = plt.figure().subplots() + p = VectorPlotter(data=long_df, variables={"x": "x", "hue": "a"}) + p._attach(ax) + assert p._get_axes({"hue": "a"}) is ax + + def test_get_axes_facets(self, long_df): + + g = FacetGrid(long_df, col="a") + p = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"}) + p._attach(g) + assert p._get_axes({"col": "b"}) is g.axes_dict["b"] + + g = FacetGrid(long_df, col="a", row="c") + p = VectorPlotter( + data=long_df, variables={"x": "x", "col": "a", "row": "c"} + ) + p._attach(g) + assert p._get_axes({"row": 1, "col": "b"}) is g.axes_dict[(1, "b")] + + def test_comp_data(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"}) + + # We have disabled this check for now, while it remains part of + # the internal API, because it will require updating a number of tests + # with pytest.raises(AttributeError): + # p.comp_data + + _, ax = plt.subplots() + p._attach(ax) + + assert_array_equal(p.comp_data["x"], p.plot_data["x"]) + assert_array_equal( + p.comp_data["y"], ax.yaxis.convert_units(p.plot_data["y"]) + ) + + p = VectorPlotter(data=long_df, variables={"x": "a"}) + + _, ax = plt.subplots() + p._attach(ax) + + assert_array_equal( + p.comp_data["x"], ax.xaxis.convert_units(p.plot_data["x"]) + ) + + def test_comp_data_log(self, long_df): + + p 
= VectorPlotter(data=long_df, variables={"x": "z", "y": "y"}) + _, ax = plt.subplots() + p._attach(ax, log_scale=(True, False)) + + assert_array_equal( + p.comp_data["x"], np.log10(p.plot_data["x"]) + ) + assert_array_equal(p.comp_data["y"], p.plot_data["y"]) + + def test_comp_data_category_order(self): + + s = (pd.Series(["a", "b", "c", "a"], dtype="category") + .cat.set_categories(["b", "c", "a"], ordered=True)) + + p = VectorPlotter(variables={"x": s}) + _, ax = plt.subplots() + p._attach(ax) + assert_array_equal( + p.comp_data["x"], + [2, 0, 1, 2], + ) + + @pytest.fixture( + params=itertools.product( + [None, np.nan, PD_NA], + ["numeric", "category", "datetime"] + ) + ) + @pytest.mark.parametrize( + "NA,var_type", + ) + def comp_data_missing_fixture(self, request): + + # This fixture holds the logic for parameterizing + # the following test (test_comp_data_missing) + + NA, var_type = request.param + + if NA is None: + pytest.skip("No pandas.NA available") + + comp_data = [0, 1, np.nan, 2, np.nan, 1] + if var_type == "numeric": + orig_data = [0, 1, NA, 2, np.inf, 1] + elif var_type == "category": + orig_data = ["a", "b", NA, "c", NA, "b"] + elif var_type == "datetime": + # Use 1-based numbers to avoid issue on matplotlib<3.2 + # Could simplify the test a bit when we roll off that version + comp_data = [1, 2, np.nan, 3, np.nan, 2] + numbers = [1, 2, 3, 2] + + orig_data = mpl.dates.num2date(numbers) + orig_data.insert(2, NA) + orig_data.insert(4, np.inf) + + return orig_data, comp_data + + def test_comp_data_missing(self, comp_data_missing_fixture): + + orig_data, comp_data = comp_data_missing_fixture + p = VectorPlotter(variables={"x": orig_data}) + ax = plt.figure().subplots() + p._attach(ax) + assert_array_equal(p.comp_data["x"], comp_data) + + def test_var_order(self, long_df): + + order = ["c", "b", "a"] + for var in ["hue", "size", "style"]: + p = VectorPlotter(data=long_df, variables={"x": "x", var: "a"}) + + mapper = getattr(p, f"map_{var}") + mapper(order=order) + + assert p.var_levels[var] == order + + def test_scale_native(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "x"}) + with pytest.raises(NotImplementedError): + p.scale_native("x") + + def test_scale_numeric(self, long_df): + + p = VectorPlotter(data=long_df, variables={"y": "y"}) + with pytest.raises(NotImplementedError): + p.scale_numeric("y") + + def test_scale_datetime(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "t"}) + with pytest.raises(NotImplementedError): + p.scale_datetime("x") + + def test_scale_categorical(self, long_df): + + p = VectorPlotter(data=long_df, variables={"x": "x"}) + p.scale_categorical("y") + assert p.variables["y"] is None + assert p.var_types["y"] == "categorical" + assert (p.plot_data["y"] == "").all() + + p = VectorPlotter(data=long_df, variables={"x": "s"}) + p.scale_categorical("x") + assert p.var_types["x"] == "categorical" + assert hasattr(p.plot_data["x"], "str") + assert not p._var_ordered["x"] + assert p.plot_data["x"].is_monotonic_increasing + assert_array_equal(p.var_levels["x"], p.plot_data["x"].unique()) + + p = VectorPlotter(data=long_df, variables={"x": "a"}) + p.scale_categorical("x") + assert not p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], categorical_order(long_df["a"])) + + p = VectorPlotter(data=long_df, variables={"x": "a_cat"}) + p.scale_categorical("x") + assert p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], categorical_order(long_df["a_cat"])) + + p = VectorPlotter(data=long_df, variables={"x": 
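# Why the expected comp data above is [2, 0, 1, 2]: after the categories
# are reordered to ["b", "c", "a"], pandas assigns each value the integer
# position of its category, and that code is what lands on the axis.
import pandas as pd

s = (pd.Series(["a", "b", "c", "a"], dtype="category")
     .cat.set_categories(["b", "c", "a"], ordered=True))
assert list(s.cat.codes) == [2, 0, 1, 2]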
"a"}) + order = np.roll(long_df["a"].unique(), 1) + p.scale_categorical("x", order=order) + assert p._var_ordered["x"] + assert_array_equal(p.var_levels["x"], order) + + p = VectorPlotter(data=long_df, variables={"x": "s"}) + p.scale_categorical("x", formatter=lambda x: f"{x:%}") + assert p.plot_data["x"].str.endswith("%").all() + assert all(s.endswith("%") for s in p.var_levels["x"]) + + +class TestCoreFunc: + + def test_unique_dashes(self): + + n = 24 + dashes = unique_dashes(n) + + assert len(dashes) == n + assert len(set(dashes)) == n + assert dashes[0] == "" + for spec in dashes[1:]: + assert isinstance(spec, tuple) + assert not len(spec) % 2 + + def test_unique_markers(self): + + n = 24 + markers = unique_markers(n) + + assert len(markers) == n + assert len(set(markers)) == n + for m in markers: + assert mpl.markers.MarkerStyle(m).is_filled() + + def test_variable_type(self): + + s = pd.Series([1., 2., 3.]) + assert variable_type(s) == "numeric" + assert variable_type(s.astype(int)) == "numeric" + assert variable_type(s.astype(object)) == "numeric" + assert variable_type(s.to_numpy()) == "numeric" + assert variable_type(s.to_list()) == "numeric" + + s = pd.Series([1, 2, 3, np.nan], dtype=object) + assert variable_type(s) == "numeric" + + s = pd.Series([np.nan, np.nan]) + # s = pd.Series([pd.NA, pd.NA]) + assert variable_type(s) == "numeric" + + s = pd.Series(["1", "2", "3"]) + assert variable_type(s) == "categorical" + assert variable_type(s.to_numpy()) == "categorical" + assert variable_type(s.to_list()) == "categorical" + + s = pd.Series([True, False, False]) + assert variable_type(s) == "numeric" + assert variable_type(s, boolean_type="categorical") == "categorical" + s_cat = s.astype("category") + assert variable_type(s_cat, boolean_type="categorical") == "categorical" + assert variable_type(s_cat, boolean_type="numeric") == "categorical" + + s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)]) + assert variable_type(s) == "datetime" + assert variable_type(s.astype(object)) == "datetime" + assert variable_type(s.to_numpy()) == "datetime" + assert variable_type(s.to_list()) == "datetime" + + def test_infer_orient(self): + + nums = pd.Series(np.arange(6)) + cats = pd.Series(["a", "b"] * 3) + dates = pd.date_range("1999-09-22", "2006-05-14", 6) + + assert infer_orient(cats, nums) == "v" + assert infer_orient(nums, cats) == "h" + + assert infer_orient(cats, dates, require_numeric=False) == "v" + assert infer_orient(dates, cats, require_numeric=False) == "h" + + assert infer_orient(nums, None) == "h" + with pytest.warns(UserWarning, match="Vertical .+ `x`"): + assert infer_orient(nums, None, "v") == "h" + + assert infer_orient(None, nums) == "v" + with pytest.warns(UserWarning, match="Horizontal .+ `y`"): + assert infer_orient(None, nums, "h") == "v" + + infer_orient(cats, None, require_numeric=False) == "h" + with pytest.raises(TypeError, match="Horizontal .+ `x`"): + infer_orient(cats, None) + + infer_orient(cats, None, require_numeric=False) == "v" + with pytest.raises(TypeError, match="Vertical .+ `y`"): + infer_orient(None, cats) + + assert infer_orient(nums, nums, "vert") == "v" + assert infer_orient(nums, nums, "hori") == "h" + + assert infer_orient(cats, cats, "h", require_numeric=False) == "h" + assert infer_orient(cats, cats, "v", require_numeric=False) == "v" + assert infer_orient(cats, cats, require_numeric=False) == "v" + + with pytest.raises(TypeError, match="Vertical .+ `y`"): + infer_orient(cats, cats, "v") + with pytest.raises(TypeError, match="Horizontal .+ `x`"): + 
infer_orient(cats, cats, "h") + with pytest.raises(TypeError, match="Neither"): + infer_orient(cats, cats) + + with pytest.raises(ValueError, match="`orient` must start with"): + infer_orient(cats, nums, orient="bad value") + + def test_categorical_order(self): + + x = ["a", "c", "c", "b", "a", "d"] + y = [3, 2, 5, 1, 4] + order = ["a", "b", "c", "d"] + + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(x, order) + assert out == order + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + out = categorical_order(np.array(x)) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(pd.Series(x)) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(y) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(np.array(y)) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(pd.Series(y)) + assert out == [1, 2, 3, 4, 5] + + x = pd.Categorical(x, order) + out = categorical_order(x) + assert out == list(x.categories) + + x = pd.Series(x) + out = categorical_order(x) + assert out == list(x.cat.categories) + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + x = ["a", np.nan, "c", "c", "b", "a", "d"] + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] diff --git a/seaborn/tests/test_decorators.py b/seaborn/tests/test_decorators.py new file mode 100644 index 0000000000..ab9ebada9c --- /dev/null +++ b/seaborn/tests/test_decorators.py @@ -0,0 +1,108 @@ +import inspect +import pytest +from .._decorators import ( + _deprecate_positional_args, + share_init_params_with_map, +) + + +# This test was adapted from scikit-learn +# github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/tests/test_validation.py +def test_deprecate_positional_args_warns_for_function(): + + @_deprecate_positional_args + def f1(a, b, *, c=1, d=1): + return a, b, c, d + + with pytest.warns( + FutureWarning, + match=r"Pass the following variable as a keyword arg: c\." + ): + assert f1(1, 2, 3) == (1, 2, 3, 1) + + with pytest.warns( + FutureWarning, + match=r"Pass the following variables as keyword args: c, d\." + ): + assert f1(1, 2, 3, 4) == (1, 2, 3, 4) + + @_deprecate_positional_args + def f2(a=1, *, b=1, c=1, d=1): + return a, b, c, d + + with pytest.warns( + FutureWarning, + match=r"Pass the following variable as a keyword arg: b\.", + ): + assert f2(1, 2) == (1, 2, 1, 1) + + # The * is placed before a keyword only argument without a default value + @_deprecate_positional_args + def f3(a, *, b, c=1, d=1): + return a, b, c, d + + with pytest.warns( + FutureWarning, + match=r"Pass the following variable as a keyword arg: b\.", + ): + assert f3(1, 2) == (1, 2, 1, 1) + + +def test_deprecate_positional_args_warns_for_class(): + + class A1: + @_deprecate_positional_args + def __init__(self, a, b, *, c=1, d=1): + self.a = a, b, c, d + + with pytest.warns( + FutureWarning, + match=r"Pass the following variable as a keyword arg: c\." + ): + assert A1(1, 2, 3).a == (1, 2, 3, 1) + + with pytest.warns( + FutureWarning, + match=r"Pass the following variables as keyword args: c, d\." 
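# A minimal sketch of the ordering rules test_categorical_order encodes:
# an explicit order wins; categorical dtypes use their categories; numeric
# data is sorted; everything else keeps order of appearance with NA dropped.
# (`order_of` is a standalone simplification, not seaborn's implementation.)
import pandas as pd

def order_of(vector, order=None):
    if order is not None:
        return list(order)
    vector = pd.Series(vector)
    if vector.dtype.name == "category":
        return list(vector.cat.categories)
    levels = list(filter(pd.notnull, vector.unique()))
    if pd.api.types.is_numeric_dtype(vector):
        levels.sort()
    return levels

assert order_of(["a", "c", "c", "b", "a", "d"]) == ["a", "c", "b", "d"]
assert order_of([3, 2, 5, 1, 4]) == [1, 2, 3, 4, 5]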
+ ): + assert A1(1, 2, 3, 4).a == (1, 2, 3, 4) + + class A2: + @_deprecate_positional_args + def __init__(self, a=1, b=1, *, c=1, d=1): + self.a = a, b, c, d + + with pytest.warns( + FutureWarning, + match=r"Pass the following variable as a keyword arg: c\.", + ): + assert A2(1, 2, 3).a == (1, 2, 3, 1) + + with pytest.warns( + FutureWarning, + match=r"Pass the following variables as keyword args: c, d\.", + ): + assert A2(1, 2, 3, 4).a == (1, 2, 3, 4) + + +def test_share_init_params_with_map(): + + @share_init_params_with_map + class Thingie: + + def map(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def __init__(self, a, b=1): + """Make a new thingie.""" + self.a = a + self.b = b + + thingie = Thingie.map(1, b=2) + assert thingie.a == 1 + assert thingie.b == 2 + + assert "a" in inspect.signature(Thingie.map).parameters + assert "b" in inspect.signature(Thingie.map).parameters + + assert Thingie.map.__doc__ == Thingie.__init__.__doc__ diff --git a/seaborn/tests/test_distributions.py b/seaborn/tests/test_distributions.py index 4cdd6a746b..a11b2aa971 100644 --- a/seaborn/tests/test_distributions.py +++ b/seaborn/tests/test_distributions.py @@ -1,224 +1,2319 @@ +import itertools +from distutils.version import LooseVersion + import numpy as np -import pandas as pd +import matplotlib as mpl import matplotlib.pyplot as plt +from matplotlib.colors import to_rgb, to_rgba import pytest -import nose.tools as nt -import numpy.testing as npt +from numpy.testing import assert_array_equal, assert_array_almost_equal from .. import distributions as dist - -try: - import statsmodels.nonparametric.api - assert statsmodels.nonparametric.api - _no_statsmodels = False -except ImportError: - _no_statsmodels = True +from ..palettes import ( + color_palette, + light_palette, +) +from .._core import ( + categorical_order, +) +from .._statistics import ( + KDE, + Histogram, + _no_scipy, +) +from ..distributions import ( + _DistributionPlotter, + displot, + distplot, + histplot, + ecdfplot, + kdeplot, + rugplot, +) +from ..axisgrid import FacetGrid +from .._testing import ( + assert_plots_equal, + assert_legends_equal, + assert_colors_equal, +) -class TestKDE(object): +class TestDistPlot(object): rs = np.random.RandomState(0) - x = rs.randn(50) - y = rs.randn(50) - kernel = "gau" - bw = "scott" - gridsize = 128 - clip = (-np.inf, np.inf) - cut = 3 - - def test_scipy_univariate_kde(self): - """Test the univariate KDE estimation with scipy.""" - grid, y = dist._scipy_univariate_kde(self.x, self.bw, self.gridsize, - self.cut, self.clip) - nt.assert_equal(len(grid), self.gridsize) - nt.assert_equal(len(y), self.gridsize) - for bw in ["silverman", .2]: - dist._scipy_univariate_kde(self.x, bw, self.gridsize, - self.cut, self.clip) - - @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") - def test_statsmodels_univariate_kde(self): - """Test the univariate KDE estimation with statsmodels.""" - grid, y = dist._statsmodels_univariate_kde(self.x, self.kernel, - self.bw, self.gridsize, - self.cut, self.clip) - nt.assert_equal(len(grid), self.gridsize) - nt.assert_equal(len(y), self.gridsize) - for bw in ["silverman", .2]: - dist._statsmodels_univariate_kde(self.x, self.kernel, bw, - self.gridsize, self.cut, - self.clip) - - def test_scipy_bivariate_kde(self): - """Test the bivariate KDE estimation with scipy.""" - clip = [self.clip, self.clip] - x, y, z = dist._scipy_bivariate_kde(self.x, self.y, self.bw, - self.gridsize, self.cut, clip) - nt.assert_equal(x.shape, (self.gridsize, self.gridsize)) - 
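# A rough sketch of what a decorator like `_deprecate_positional_args` must
# do to satisfy the tests above: count positional arguments that overflow
# past the `*`, warn with the exact message the tests match, and forward
# the overflow as keywords. Illustrative only, not seaborn's code.
import inspect
import warnings
from functools import wraps

def deprecate_positional_args(f):
    sig = inspect.signature(f)
    pos_names = [n for n, p in sig.parameters.items()
                 if p.kind == p.POSITIONAL_OR_KEYWORD]
    kwonly = [n for n, p in sig.parameters.items() if p.kind == p.KEYWORD_ONLY]

    @wraps(f)
    def wrapper(*args, **kwargs):
        n_extra = len(args) - len(pos_names)
        if n_extra > 0:
            moved = kwonly[:n_extra]
            if len(moved) == 1:
                msg = f"Pass the following variable as a keyword arg: {moved[0]}."
            else:
                msg = ("Pass the following variables as keyword args: "
                       + ", ".join(moved) + ".")
            warnings.warn(msg, FutureWarning)
            kwargs.update(zip(moved, args[len(pos_names):]))
            args = args[:len(pos_names)]
        return f(*args, **kwargs)

    return wrapper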
nt.assert_equal(y.shape, (self.gridsize, self.gridsize)) - nt.assert_equal(len(z), self.gridsize) - - # Test a specific bandwidth - clip = [self.clip, self.clip] - x, y, z = dist._scipy_bivariate_kde(self.x, self.y, 1, - self.gridsize, self.cut, clip) - - # Test that we get an error with an invalid bandwidth - with nt.assert_raises(ValueError): - dist._scipy_bivariate_kde(self.x, self.y, (1, 2), - self.gridsize, self.cut, clip) - - @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") - def test_statsmodels_bivariate_kde(self): - """Test the bivariate KDE estimation with statsmodels.""" - clip = [self.clip, self.clip] - x, y, z = dist._statsmodels_bivariate_kde(self.x, self.y, self.bw, - self.gridsize, - self.cut, clip) - nt.assert_equal(x.shape, (self.gridsize, self.gridsize)) - nt.assert_equal(y.shape, (self.gridsize, self.gridsize)) - nt.assert_equal(len(z), self.gridsize) - - @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") - def test_statsmodels_kde_cumulative(self): - """Test computation of cumulative KDE.""" - grid, y = dist._statsmodels_univariate_kde(self.x, self.kernel, - self.bw, self.gridsize, - self.cut, self.clip, - cumulative=True) - nt.assert_equal(len(grid), self.gridsize) - nt.assert_equal(len(y), self.gridsize) - # make sure y is monotonically increasing - npt.assert_((np.diff(y) > 0).all()) - - def test_kde_cummulative_2d(self): - """Check error if args indicate bivariate KDE and cumulative.""" - with npt.assert_raises(TypeError): - dist.kdeplot(self.x, data2=self.y, cumulative=True) - - def test_kde_singular(self): + x = rs.randn(100) + + def test_hist_bins(self): + + fd_edges = np.histogram_bin_edges(self.x, "fd") + with pytest.warns(FutureWarning): + ax = distplot(self.x) + for edge, bar in zip(fd_edges, ax.patches): + assert pytest.approx(edge) == bar.get_x() + + plt.close(ax.figure) + n = 25 + n_edges = np.histogram_bin_edges(self.x, n) + with pytest.warns(FutureWarning): + ax = distplot(self.x, bins=n) + for edge, bar in zip(n_edges, ax.patches): + assert pytest.approx(edge) == bar.get_x() + + def test_elements(self): + + with pytest.warns(FutureWarning): + + n = 10 + ax = distplot(self.x, bins=n, + hist=True, kde=False, rug=False, fit=None) + assert len(ax.patches) == 10 + assert len(ax.lines) == 0 + assert len(ax.collections) == 0 + + plt.close(ax.figure) + ax = distplot(self.x, + hist=False, kde=True, rug=False, fit=None) + assert len(ax.patches) == 0 + assert len(ax.lines) == 1 + assert len(ax.collections) == 0 + + plt.close(ax.figure) + ax = distplot(self.x, + hist=False, kde=False, rug=True, fit=None) + assert len(ax.patches) == 0 + assert len(ax.lines) == 0 + assert len(ax.collections) == 1 + + class Norm: + """Dummy object that looks like a scipy RV""" + def fit(self, x): + return () + + def pdf(self, x, *params): + return np.zeros_like(x) + + plt.close(ax.figure) + ax = distplot( + self.x, hist=False, kde=False, rug=False, fit=Norm()) + assert len(ax.patches) == 0 + assert len(ax.lines) == 1 + assert len(ax.collections) == 0 + + def test_distplot_with_nans(self): + + f, (ax1, ax2) = plt.subplots(2) + x_null = np.append(self.x, [np.nan]) + + with pytest.warns(FutureWarning): + distplot(self.x, ax=ax1) + distplot(x_null, ax=ax2) + + line1 = ax1.lines[0] + line2 = ax2.lines[0] + assert np.array_equal(line1.get_xydata(), line2.get_xydata()) + + for bar1, bar2 in zip(ax1.patches, ax2.patches): + assert bar1.get_xy() == bar2.get_xy() + assert bar1.get_height() == bar2.get_height() + + +class SharedAxesLevelTests: + + def test_color(self, 
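# The bin-edge check in test_hist_bins: with no `bins` argument the bars
# follow numpy's Freedman-Diaconis rule, so each bar's left edge lines up
# with np.histogram_bin_edges(x, "fd").
import numpy as np

rs = np.random.RandomState(0)
x = rs.randn(100)

fd_edges = np.histogram_bin_edges(x, "fd")
counts, edges = np.histogram(x, bins=fd_edges)
assert len(counts) == len(edges) - 1  # one bar per adjacent pair of edges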
long_df, **kwargs): + + ax = plt.figure().subplots() + self.func(data=long_df, x="y", ax=ax, **kwargs) + assert_colors_equal(self.get_last_color(ax, **kwargs), "C0", check_alpha=False) + + ax = plt.figure().subplots() + self.func(data=long_df, x="y", ax=ax, **kwargs) + self.func(data=long_df, x="y", ax=ax, **kwargs) + assert_colors_equal(self.get_last_color(ax, **kwargs), "C1", check_alpha=False) + + ax = plt.figure().subplots() + self.func(data=long_df, x="y", color="C2", ax=ax, **kwargs) + assert_colors_equal(self.get_last_color(ax, **kwargs), "C2", check_alpha=False) + + +class TestRugPlot(SharedAxesLevelTests): + + func = staticmethod(rugplot) + + def get_last_color(self, ax, **kwargs): + + return ax.collections[-1].get_color() + + def assert_rug_equal(self, a, b): + + assert_array_equal(a.get_segments(), b.get_segments()) + + @pytest.mark.parametrize("variable", ["x", "y"]) + def test_long_data(self, long_df, variable): + + vector = long_df[variable] + vectors = [ + variable, vector, np.asarray(vector), vector.to_list(), + ] + + f, ax = plt.subplots() + for vector in vectors: + rugplot(data=long_df, **{variable: vector}) + + for a, b in itertools.product(ax.collections, ax.collections): + self.assert_rug_equal(a, b) + + def test_bivariate_data(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + rugplot(data=long_df, x="x", y="y", ax=ax1) + rugplot(data=long_df, x="x", ax=ax2) + rugplot(data=long_df, y="y", ax=ax2) + + self.assert_rug_equal(ax1.collections[0], ax2.collections[0]) + self.assert_rug_equal(ax1.collections[1], ax2.collections[1]) + + def test_wide_vs_long_data(self, wide_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + rugplot(data=wide_df, ax=ax1) + for col in wide_df: + rugplot(data=wide_df, x=col, ax=ax2) + + wide_segments = np.sort( + np.array(ax1.collections[0].get_segments()) + ) + long_segments = np.sort( + np.concatenate([c.get_segments() for c in ax2.collections]) + ) + + assert_array_equal(wide_segments, long_segments) + + def test_flat_vector(self, long_df): + + f, ax = plt.subplots() + rugplot(data=long_df["x"]) + rugplot(x=long_df["x"]) + self.assert_rug_equal(*ax.collections) + + def test_datetime_data(self, long_df): + + ax = rugplot(data=long_df["t"]) + vals = np.stack(ax.collections[0].get_segments())[:, 0, 0] + assert_array_equal(vals, mpl.dates.date2num(long_df["t"])) + + def test_empty_data(self): + + ax = rugplot(x=[]) + assert not ax.collections + + def test_a_deprecation(self, flat_series): + + f, ax = plt.subplots() + + with pytest.warns(FutureWarning): + rugplot(a=flat_series) + rugplot(x=flat_series) + + self.assert_rug_equal(*ax.collections) + + @pytest.mark.parametrize("variable", ["x", "y"]) + def test_axis_deprecation(self, flat_series, variable): + + f, ax = plt.subplots() + + with pytest.warns(FutureWarning): + rugplot(flat_series, axis=variable) + rugplot(**{variable: flat_series}) + + self.assert_rug_equal(*ax.collections) + + def test_vertical_deprecation(self, flat_series): + + f, ax = plt.subplots() + + with pytest.warns(FutureWarning): + rugplot(flat_series, vertical=True) + rugplot(y=flat_series) + + self.assert_rug_equal(*ax.collections) + + def test_rug_data(self, flat_array): + + height = .05 + ax = rugplot(x=flat_array, height=height) + segments = np.stack(ax.collections[0].get_segments()) + + n = flat_array.size + assert_array_equal(segments[:, 0, 1], np.zeros(n)) + assert_array_equal(segments[:, 1, 1], np.full(n, height)) + assert_array_equal(segments[:, 1, 0], flat_array) + + def test_rug_colors(self, long_df): + + ax 
= rugplot(data=long_df, x="x", hue="a") + + order = categorical_order(long_df["a"]) + palette = color_palette() + + expected_colors = np.ones((len(long_df), 4)) + for i, val in enumerate(long_df["a"]): + expected_colors[i, :3] = palette[order.index(val)] + + assert_array_equal(ax.collections[0].get_color(), expected_colors) + + def test_expand_margins(self, flat_array): + + f, ax = plt.subplots() + x1, y1 = ax.margins() + rugplot(x=flat_array, expand_margins=False) + x2, y2 = ax.margins() + assert x1 == x2 + assert y1 == y2 + + f, ax = plt.subplots() + x1, y1 = ax.margins() + height = .05 + rugplot(x=flat_array, height=height) + x2, y2 = ax.margins() + assert x1 == x2 + assert y1 + height * 2 == pytest.approx(y2) + + def test_matplotlib_kwargs(self, flat_series): + + lw = 2 + alpha = .2 + ax = rugplot(y=flat_series, linewidth=lw, alpha=alpha) + rug = ax.collections[0] + assert np.all(rug.get_alpha() == alpha) + assert np.all(rug.get_linewidth() == lw) + + def test_axis_labels(self, flat_series): + + ax = rugplot(x=flat_series) + assert ax.get_xlabel() == flat_series.name + assert not ax.get_ylabel() + + def test_log_scale(self, long_df): + + ax1, ax2 = plt.figure().subplots(2) + + ax2.set_xscale("log") + + rugplot(data=long_df, x="z", ax=ax1) + rugplot(data=long_df, x="z", ax=ax2) + + rug1 = np.stack(ax1.collections[0].get_segments()) + rug2 = np.stack(ax2.collections[0].get_segments()) + + assert_array_almost_equal(rug1, rug2) + + +class TestKDEPlotUnivariate(SharedAxesLevelTests): + + func = staticmethod(kdeplot) + + def get_last_color(self, ax, fill=True): + + if fill: + return ax.collections[-1].get_facecolor() + else: + return ax.lines[-1].get_color() + + @pytest.mark.parametrize("fill", [True, False]) + def test_color(self, long_df, fill): + + super().test_color(long_df, fill=fill) + + if fill: + + ax = plt.figure().subplots() + self.func(data=long_df, x="y", facecolor="C3", fill=True, ax=ax) + assert_colors_equal(self.get_last_color(ax), "C3", check_alpha=False) + + ax = plt.figure().subplots() + self.func(data=long_df, x="y", fc="C4", fill=True, ax=ax) + assert_colors_equal(self.get_last_color(ax), "C4", check_alpha=False) + + @pytest.mark.parametrize( + "variable", ["x", "y"], + ) + def test_long_vectors(self, long_df, variable): + + vector = long_df[variable] + vectors = [ + variable, vector, vector.to_numpy(), vector.to_list(), + ] + + f, ax = plt.subplots() + for vector in vectors: + kdeplot(data=long_df, **{variable: vector}) + + xdata = [l.get_xdata() for l in ax.lines] + for a, b in itertools.product(xdata, xdata): + assert_array_equal(a, b) + + ydata = [l.get_ydata() for l in ax.lines] + for a, b in itertools.product(ydata, ydata): + assert_array_equal(a, b) + + def test_wide_vs_long_data(self, wide_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + kdeplot(data=wide_df, ax=ax1, common_norm=False, common_grid=False) + for col in wide_df: + kdeplot(data=wide_df, x=col, ax=ax2) + + for l1, l2 in zip(ax1.lines[::-1], ax2.lines): + assert_array_equal(l1.get_xydata(), l2.get_xydata()) + + def test_flat_vector(self, long_df): + + f, ax = plt.subplots() + kdeplot(data=long_df["x"]) + kdeplot(x=long_df["x"]) + assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata()) + + def test_empty_data(self): + + ax = kdeplot(x=[]) + assert not ax.lines + + def test_singular_data(self): with pytest.warns(UserWarning): - ax = dist.kdeplot(np.ones(10)) - line = ax.lines[0] - assert not line.get_xydata().size + ax = kdeplot(x=np.ones(10)) + assert not ax.lines with 
pytest.warns(UserWarning): - ax = dist.kdeplot(np.ones(10) * np.nan) - line = ax.lines[1] - assert not line.get_xydata().size + ax = kdeplot(x=[5]) + assert not ax.lines + + def test_variable_assignment(self, long_df): - def test_bivariate_kde_series(self): - df = pd.DataFrame({'x': self.x, 'y': self.y}) + f, ax = plt.subplots() + kdeplot(data=long_df, x="x", fill=True) + kdeplot(data=long_df, y="x", fill=True) - ax_series = dist.kdeplot(df.x, df.y) - ax_values = dist.kdeplot(df.x.values, df.y.values) + v0 = ax.collections[0].get_paths()[0].vertices + v1 = ax.collections[1].get_paths()[0].vertices[:, [1, 0]] - nt.assert_equal(len(ax_series.collections), - len(ax_values.collections)) - nt.assert_equal(ax_series.collections[0].get_paths(), - ax_values.collections[0].get_paths()) + assert_array_equal(v0, v1) - def test_bivariate_kde_colorbar(self): + def test_vertical_deprecation(self, long_df): f, ax = plt.subplots() - dist.kdeplot(self.x, self.y, - cbar=True, cbar_kws=dict(label="density"), - ax=ax) - nt.assert_equal(len(f.axes), 2) - nt.assert_equal(f.axes[1].get_ylabel(), "density") + kdeplot(data=long_df, y="x") - def test_legend(self): + with pytest.warns(FutureWarning): + kdeplot(data=long_df, x="x", vertical=True) - f, ax = plt.subplots() - dist.kdeplot(self.x, self.y, label="test1") - line = ax.lines[-1] - assert line.get_label() == "test1" + assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata()) + + def test_bw_deprecation(self, long_df): f, ax = plt.subplots() - dist.kdeplot(self.x, self.y, shade=True, label="test2") - fill = ax.collections[-1] - assert fill.get_label() == "test2" + kdeplot(data=long_df, x="x", bw_method="silverman") + + with pytest.warns(FutureWarning): + kdeplot(data=long_df, x="x", bw="silverman") - def test_contour_color(self): + assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata()) + + def test_kernel_deprecation(self, long_df): - rgb = (.1, .5, .7) f, ax = plt.subplots() + kdeplot(data=long_df, x="x") + + with pytest.warns(UserWarning): + kdeplot(data=long_df, x="x", kernel="epi") + + assert_array_equal(ax.lines[0].get_xydata(), ax.lines[1].get_xydata()) - dist.kdeplot(self.x, self.y, color=rgb) - contour = ax.collections[-1] - assert np.array_equal(contour.get_color()[0, :3], rgb) - low = ax.collections[0].get_color().mean() - high = ax.collections[-1].get_color().mean() - assert low < high + def test_shade_deprecation(self, long_df): f, ax = plt.subplots() - dist.kdeplot(self.x, self.y, shade=True, color=rgb) - contour = ax.collections[-1] - low = ax.collections[0].get_facecolor().mean() - high = ax.collections[-1].get_facecolor().mean() - assert low > high + kdeplot(data=long_df, x="x", shade=True) + kdeplot(data=long_df, x="x", fill=True) + fill1, fill2 = ax.collections + assert_array_equal( + fill1.get_paths()[0].vertices, fill2.get_paths()[0].vertices + ) + @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"]) + def test_hue_colors(self, long_df, multiple): -class TestRugPlot(object): + ax = kdeplot( + data=long_df, x="x", hue="a", + multiple=multiple, + fill=True, legend=False + ) - @pytest.fixture - def list_data(self): - return np.random.randn(20).tolist() + # Note that hue order is reversed in the plot + lines = ax.lines[::-1] + fills = ax.collections[::-1] - @pytest.fixture - def array_data(self): - return np.random.randn(20) + palette = color_palette() - @pytest.fixture - def series_data(self): - return pd.Series(np.random.randn(20)) + for line, fill, color in zip(lines, fills, palette): + 
assert_colors_equal(line.get_color(), color) + assert_colors_equal(fill.get_facecolor(), to_rgba(color, .25)) - def test_rugplot(self, list_data, array_data, series_data): + def test_hue_stacking(self, long_df): - h = .1 + f, (ax1, ax2) = plt.subplots(ncols=2) - for data in [list_data, array_data, series_data]: + kdeplot( + data=long_df, x="x", hue="a", + multiple="layer", common_grid=True, + legend=False, ax=ax1, + ) + kdeplot( + data=long_df, x="x", hue="a", + multiple="stack", fill=False, + legend=False, ax=ax2, + ) - f, ax = plt.subplots() - dist.rugplot(data, h) - rug, = ax.collections - segments = np.array(rug.get_segments()) + layered_densities = np.stack([ + l.get_ydata() for l in ax1.lines + ]) + stacked_densities = np.stack([ + l.get_ydata() for l in ax2.lines + ]) - assert len(segments) == len(data) - assert np.array_equal(segments[:, 0, 0], data) - assert np.array_equal(segments[:, 1, 0], data) - assert np.array_equal(segments[:, 0, 1], np.zeros_like(data)) - assert np.array_equal(segments[:, 1, 1], np.ones_like(data) * h) + assert_array_equal(layered_densities.cumsum(axis=0), stacked_densities) - plt.close(f) + def test_hue_filling(self, long_df): - f, ax = plt.subplots() - dist.rugplot(data, h, axis="y") - rug, = ax.collections - segments = np.array(rug.get_segments()) + f, (ax1, ax2) = plt.subplots(ncols=2) + + kdeplot( + data=long_df, x="x", hue="a", + multiple="layer", common_grid=True, + legend=False, ax=ax1, + ) + kdeplot( + data=long_df, x="x", hue="a", + multiple="fill", fill=False, + legend=False, ax=ax2, + ) + + layered = np.stack([l.get_ydata() for l in ax1.lines]) + filled = np.stack([l.get_ydata() for l in ax2.lines]) - assert len(segments) == len(data) - assert np.array_equal(segments[:, 0, 1], data) - assert np.array_equal(segments[:, 1, 1], data) - assert np.array_equal(segments[:, 0, 0], np.zeros_like(data)) - assert np.array_equal(segments[:, 1, 0], np.ones_like(data) * h) + assert_array_almost_equal( + (layered / layered.sum(axis=0)).cumsum(axis=0), + filled, + ) - plt.close(f) + @pytest.mark.parametrize("multiple", ["stack", "fill"]) + def test_fill_default(self, long_df, multiple): + + ax = kdeplot( + data=long_df, x="x", hue="a", multiple=multiple, fill=None + ) + + assert len(ax.collections) > 0 + + @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"]) + def test_fill_nondefault(self, long_df, multiple): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + kws = dict(data=long_df, x="x", hue="a") + kdeplot(**kws, multiple=multiple, fill=False, ax=ax1) + kdeplot(**kws, multiple=multiple, fill=True, ax=ax2) + + assert len(ax1.collections) == 0 + assert len(ax2.collections) > 0 + + def test_color_cycle_interaction(self, flat_series): + + color = (.2, 1, .6) f, ax = plt.subplots() - dist.rugplot(data, axis="y") - dist.rugplot(data, vertical=True) - c1, c2 = ax.collections - assert np.array_equal(c1.get_segments(), c2.get_segments()) + kdeplot(flat_series) + kdeplot(flat_series) + assert_colors_equal(ax.lines[0].get_color(), "C0") + assert_colors_equal(ax.lines[1].get_color(), "C1") plt.close(f) f, ax = plt.subplots() - dist.rugplot(data) - dist.rugplot(data, lw=2) - dist.rugplot(data, linewidth=3, alpha=.5) - for c, lw in zip(ax.collections, [1, 2, 3]): - assert np.squeeze(c.get_linewidth()).item() == lw - assert c.get_alpha() == .5 + kdeplot(flat_series, color=color) + kdeplot(flat_series) + assert_colors_equal(ax.lines[0].get_color(), color) + assert_colors_equal(ax.lines[1].get_color(), "C0") plt.close(f) + + f, ax = plt.subplots() + 
kdeplot(flat_series, fill=True) + kdeplot(flat_series, fill=True) + assert_colors_equal(ax.collections[0].get_facecolor(), to_rgba("C0", .25)) + assert_colors_equal(ax.collections[1].get_facecolor(), to_rgba("C1", .25)) + plt.close(f) + + @pytest.mark.parametrize("fill", [True, False]) + def test_artist_color(self, long_df, fill): + + color = (.2, 1, .6) + alpha = .5 + + f, ax = plt.subplots() + + kdeplot(long_df["x"], fill=fill, color=color) + if fill: + artist_color = ax.collections[-1].get_facecolor().squeeze() + else: + artist_color = ax.lines[-1].get_color() + default_alpha = .25 if fill else 1 + assert_colors_equal(artist_color, to_rgba(color, default_alpha)) + + kdeplot(long_df["x"], fill=fill, color=color, alpha=alpha) + if fill: + artist_color = ax.collections[-1].get_facecolor().squeeze() + else: + artist_color = ax.lines[-1].get_color() + assert_colors_equal(artist_color, to_rgba(color, alpha)) + + def test_datetime_scale(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + kdeplot(x=long_df["t"], fill=True, ax=ax1) + kdeplot(x=long_df["t"], fill=False, ax=ax2) + assert ax1.get_xlim() == ax2.get_xlim() + + def test_multiple_argument_check(self, long_df): + + with pytest.raises(ValueError, match="`multiple` must be"): + kdeplot(data=long_df, x="x", hue="a", multiple="bad_input") + + def test_cut(self, rng): + + x = rng.normal(0, 3, 1000) + + f, ax = plt.subplots() + kdeplot(x=x, cut=0, legend=False) + + xdata_0 = ax.lines[0].get_xdata() + assert xdata_0.min() == x.min() + assert xdata_0.max() == x.max() + + kdeplot(x=x, cut=2, legend=False) + + xdata_2 = ax.lines[1].get_xdata() + assert xdata_2.min() < xdata_0.min() + assert xdata_2.max() > xdata_0.max() + + assert len(xdata_0) == len(xdata_2) + + def test_clip(self, rng): + + x = rng.normal(0, 3, 1000) + + clip = -1, 1 + ax = kdeplot(x=x, clip=clip) + + xdata = ax.lines[0].get_xdata() + + assert xdata.min() >= clip[0] + assert xdata.max() <= clip[1] + + def test_line_is_density(self, long_df): + + ax = kdeplot(data=long_df, x="x", cut=5) + x, y = ax.lines[0].get_xydata().T + assert integrate(y, x) == pytest.approx(1) + + @pytest.mark.skipif(_no_scipy, reason="Test requires scipy") + def test_cumulative(self, long_df): + + ax = kdeplot(data=long_df, x="x", cut=5, cumulative=True) + y = ax.lines[0].get_ydata() + assert y[0] == pytest.approx(0) + assert y[-1] == pytest.approx(1) + + @pytest.mark.skipif(not _no_scipy, reason="Test requires scipy's absence") + def test_cumulative_requires_scipy(self, long_df): + + with pytest.raises(RuntimeError): + kdeplot(data=long_df, x="x", cut=5, cumulative=True) + + def test_common_norm(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + kdeplot( + data=long_df, x="x", hue="c", common_norm=True, cut=10, ax=ax1 + ) + kdeplot( + data=long_df, x="x", hue="c", common_norm=False, cut=10, ax=ax2 + ) + + total_area = 0 + for line in ax1.lines: + xdata, ydata = line.get_xydata().T + total_area += integrate(ydata, xdata) + assert total_area == pytest.approx(1) + + for line in ax2.lines: + xdata, ydata = line.get_xydata().T + assert integrate(ydata, xdata) == pytest.approx(1) + + def test_common_grid(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + order = "a", "b", "c" + + kdeplot( + data=long_df, x="x", hue="a", hue_order=order, + common_grid=False, cut=0, ax=ax1, + ) + kdeplot( + data=long_df, x="x", hue="a", hue_order=order, + common_grid=True, cut=0, ax=ax2, + ) + + for line, level in zip(ax1.lines[::-1], order): + xdata = line.get_xdata() + assert xdata.min() == 
long_df.loc[long_df["a"] == level, "x"].min() + assert xdata.max() == long_df.loc[long_df["a"] == level, "x"].max() + + for line in ax2.lines: + xdata = line.get_xdata().T + assert xdata.min() == long_df["x"].min() + assert xdata.max() == long_df["x"].max() + + def test_bw_method(self, long_df): + + f, ax = plt.subplots() + kdeplot(data=long_df, x="x", bw_method=0.2, legend=False) + kdeplot(data=long_df, x="x", bw_method=1.0, legend=False) + kdeplot(data=long_df, x="x", bw_method=3.0, legend=False) + + l1, l2, l3 = ax.lines + + assert ( + np.abs(np.diff(l1.get_ydata())).mean() + > np.abs(np.diff(l2.get_ydata())).mean() + ) + + assert ( + np.abs(np.diff(l2.get_ydata())).mean() + > np.abs(np.diff(l3.get_ydata())).mean() + ) + + def test_bw_adjust(self, long_df): + + f, ax = plt.subplots() + kdeplot(data=long_df, x="x", bw_adjust=0.2, legend=False) + kdeplot(data=long_df, x="x", bw_adjust=1.0, legend=False) + kdeplot(data=long_df, x="x", bw_adjust=3.0, legend=False) + + l1, l2, l3 = ax.lines + + assert ( + np.abs(np.diff(l1.get_ydata())).mean() + > np.abs(np.diff(l2.get_ydata())).mean() + ) + + assert ( + np.abs(np.diff(l2.get_ydata())).mean() + > np.abs(np.diff(l3.get_ydata())).mean() + ) + + def test_log_scale_implicit(self, rng): + + x = rng.lognormal(0, 1, 100) + + f, (ax1, ax2) = plt.subplots(ncols=2) + ax1.set_xscale("log") + + kdeplot(x=x, ax=ax1) + kdeplot(x=x, ax=ax1) + + xdata_log = ax1.lines[0].get_xdata() + assert (xdata_log > 0).all() + assert (np.diff(xdata_log, 2) > 0).all() + assert np.allclose(np.diff(np.log(xdata_log), 2), 0) + + f, ax = plt.subplots() + ax.set_yscale("log") + kdeplot(y=x, ax=ax) + assert_array_equal(ax.lines[0].get_xdata(), ax1.lines[0].get_ydata()) + + def test_log_scale_explicit(self, rng): + + x = rng.lognormal(0, 1, 100) + + f, (ax1, ax2, ax3) = plt.subplots(ncols=3) + + ax1.set_xscale("log") + kdeplot(x=x, ax=ax1) + kdeplot(x=x, log_scale=True, ax=ax2) + kdeplot(x=x, log_scale=10, ax=ax3) + + for ax in f.axes: + assert ax.get_xscale() == "log" + + supports = [ax.lines[0].get_xdata() for ax in f.axes] + for a, b in itertools.product(supports, supports): + assert_array_equal(a, b) + + densities = [ax.lines[0].get_ydata() for ax in f.axes] + for a, b in itertools.product(densities, densities): + assert_array_equal(a, b) + + f, ax = plt.subplots() + kdeplot(y=x, log_scale=True, ax=ax) + assert ax.get_yscale() == "log" + + def test_log_scale_with_hue(self, rng): + + data = rng.lognormal(0, 1, 50), rng.lognormal(0, 2, 100) + ax = kdeplot(data=data, log_scale=True, common_grid=True) + assert_array_equal(ax.lines[0].get_xdata(), ax.lines[1].get_xdata()) + + def test_log_scale_normalization(self, rng): + + x = rng.lognormal(0, 1, 100) + ax = kdeplot(x=x, log_scale=True, cut=10) + xdata, ydata = ax.lines[0].get_xydata().T + integral = integrate(ydata, np.log10(xdata)) + assert integral == pytest.approx(1) + + def test_weights(self): + + x = [1, 2] + weights = [2, 1] + + ax = kdeplot(x=x, weights=weights, bw_method=.1) + + xdata, ydata = ax.lines[0].get_xydata().T + + y1 = ydata[np.abs(xdata - 1).argmin()] + y2 = ydata[np.abs(xdata - 2).argmin()] + + assert y1 == pytest.approx(2 * y2) + + def test_sticky_edges(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + kdeplot(data=long_df, x="x", fill=True, ax=ax1) + assert ax1.collections[0].sticky_edges.y[:] == [0, np.inf] + + kdeplot( + data=long_df, x="x", hue="a", multiple="fill", fill=True, ax=ax2 + ) + assert ax2.collections[0].sticky_edges.y[:] == [0, 1] + + def test_line_kws(self, flat_array): + + lw 
= 3 + color = (.2, .5, .8) + ax = kdeplot(x=flat_array, linewidth=lw, color=color) + line, = ax.lines + assert line.get_linewidth() == lw + assert_colors_equal(line.get_color(), color) + + def test_input_checking(self, long_df): + + err = "The x variable is categorical," + with pytest.raises(TypeError, match=err): + kdeplot(data=long_df, x="a") + + def test_axis_labels(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + kdeplot(data=long_df, x="x", ax=ax1) + assert ax1.get_xlabel() == "x" + assert ax1.get_ylabel() == "Density" + + kdeplot(data=long_df, y="y", ax=ax2) + assert ax2.get_xlabel() == "Density" + assert ax2.get_ylabel() == "y" + + def test_legend(self, long_df): + + ax = kdeplot(data=long_df, x="x", hue="a") + + assert ax.legend_.get_title().get_text() == "a" + + legend_labels = ax.legend_.get_texts() + order = categorical_order(long_df["a"]) + for label, level in zip(legend_labels, order): + assert label.get_text() == level + + legend_artists = ax.legend_.findobj(mpl.lines.Line2D)[::2] + palette = color_palette() + for artist, color in zip(legend_artists, palette): + assert_colors_equal(artist.get_color(), color) + + ax.clear() + + kdeplot(data=long_df, x="x", hue="a", legend=False) + + assert ax.legend_ is None + + +class TestKDEPlotBivariate: + + def test_long_vectors(self, long_df): + + ax1 = kdeplot(data=long_df, x="x", y="y") + + x = long_df["x"] + x_values = [x, x.to_numpy(), x.to_list()] + + y = long_df["y"] + y_values = [y, y.to_numpy(), y.to_list()] + + for x, y in zip(x_values, y_values): + f, ax2 = plt.subplots() + kdeplot(x=x, y=y, ax=ax2) + + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_offsets(), c2.get_offsets()) + + def test_singular_data(self): + + with pytest.warns(UserWarning): + ax = dist.kdeplot(x=np.ones(10), y=np.arange(10)) + assert not ax.lines + + with pytest.warns(UserWarning): + ax = dist.kdeplot(x=[5], y=[6]) + assert not ax.lines + + def test_fill_artists(self, long_df): + + for fill in [True, False]: + f, ax = plt.subplots() + kdeplot(data=long_df, x="x", y="y", hue="c", fill=fill) + for c in ax.collections: + if fill: + assert isinstance(c, mpl.collections.PathCollection) + else: + assert isinstance(c, mpl.collections.LineCollection) + + def test_common_norm(self, rng): + + hue = np.repeat(["a", "a", "a", "b"], 40) + x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], len(hue)).T + x[hue == "a"] -= 2 + x[hue == "b"] += 2 + + f, (ax1, ax2) = plt.subplots(ncols=2) + kdeplot(x=x, y=y, hue=hue, common_norm=True, ax=ax1) + kdeplot(x=x, y=y, hue=hue, common_norm=False, ax=ax2) + + n_seg_1 = sum([len(c.get_segments()) > 0 for c in ax1.collections]) + n_seg_2 = sum([len(c.get_segments()) > 0 for c in ax2.collections]) + assert n_seg_2 > n_seg_1 + + def test_log_scale(self, rng): + + x = rng.lognormal(0, 1, 100) + y = rng.uniform(0, 1, 100) + + levels = .2, .5, 1 + + f, ax = plt.subplots() + kdeplot(x=x, y=y, log_scale=True, levels=levels, ax=ax) + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "log" + + f, (ax1, ax2) = plt.subplots(ncols=2) + kdeplot(x=x, y=y, log_scale=(10, False), levels=levels, ax=ax1) + assert ax1.get_xscale() == "log" + assert ax1.get_yscale() == "linear" + + p = _DistributionPlotter() + kde = KDE() + density, (xx, yy) = kde(np.log10(x), y) + levels = p._quantile_to_level(density, levels) + ax2.contour(10 ** xx, yy, density, levels=levels) + + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_segments(), c2.get_segments()) + + def 
test_bandwidth(self, rng): + + n = 100 + x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], n).T + + f, (ax1, ax2) = plt.subplots(ncols=2) + + kdeplot(x=x, y=y, ax=ax1) + kdeplot(x=x, y=y, bw_adjust=2, ax=ax2) + + for c1, c2 in zip(ax1.collections, ax2.collections): + seg1, seg2 = c1.get_segments(), c2.get_segments() + if seg1 + seg2: + x1 = seg1[0][:, 0] + x2 = seg2[0][:, 0] + assert np.abs(x2).max() > np.abs(x1).max() + + def test_weights(self, rng): + + import warnings + warnings.simplefilter("error", np.VisibleDeprecationWarning) + + n = 100 + x, y = rng.multivariate_normal([1, 3], [(.2, .5), (.5, 2)], n).T + hue = np.repeat([0, 1], n // 2) + weights = rng.uniform(0, 1, n) + + f, (ax1, ax2) = plt.subplots(ncols=2) + kdeplot(x=x, y=y, hue=hue, ax=ax1) + kdeplot(x=x, y=y, hue=hue, weights=weights, ax=ax2) + + for c1, c2 in zip(ax1.collections, ax2.collections): + if c1.get_segments() and c2.get_segments(): + seg1 = np.concatenate(c1.get_segments(), axis=0) + seg2 = np.concatenate(c2.get_segments(), axis=0) + assert not np.array_equal(seg1, seg2) + + def test_hue_ignores_cmap(self, long_df): + + with pytest.warns(UserWarning, match="cmap parameter ignored"): + ax = kdeplot(data=long_df, x="x", y="y", hue="c", cmap="viridis") + + assert_colors_equal(ax.collections[0].get_color(), "C0") + + def test_contour_line_colors(self, long_df): + + color = (.2, .9, .8, 1) + ax = kdeplot(data=long_df, x="x", y="y", color=color) + + for c in ax.collections: + assert_colors_equal(c.get_color(), color) + + def test_contour_fill_colors(self, long_df): + + n = 6 + color = (.2, .9, .8, 1) + ax = kdeplot( + data=long_df, x="x", y="y", fill=True, color=color, levels=n, + ) + + cmap = light_palette(color, reverse=True, as_cmap=True) + lut = cmap(np.linspace(0, 1, 256)) + for c in ax.collections: + color = c.get_facecolor().squeeze() + assert color in lut + + def test_colorbar(self, long_df): + + ax = kdeplot(data=long_df, x="x", y="y", fill=True, cbar=True) + assert len(ax.figure.axes) == 2 + + def test_levels_and_thresh(self, long_df): + + f, (ax1, ax2) = plt.subplots(ncols=2) + + n = 8 + thresh = .1 + plot_kws = dict(data=long_df, x="x", y="y") + kdeplot(**plot_kws, levels=n, thresh=thresh, ax=ax1) + kdeplot(**plot_kws, levels=np.linspace(thresh, 1, n), ax=ax2) + + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_segments(), c2.get_segments()) + + with pytest.raises(ValueError): + kdeplot(**plot_kws, levels=[0, 1, 2]) + + ax1.clear() + ax2.clear() + + kdeplot(**plot_kws, levels=n, thresh=None, ax=ax1) + kdeplot(**plot_kws, levels=n, thresh=0, ax=ax2) + + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_segments(), c2.get_segments()) + for c1, c2 in zip(ax1.collections, ax2.collections): + assert_array_equal(c1.get_facecolors(), c2.get_facecolors()) + + def test_quantile_to_level(self, rng): + + x = rng.uniform(0, 1, 100000) + isoprop = np.linspace(.1, 1, 6) + + levels = _DistributionPlotter()._quantile_to_level(x, isoprop) + for h, p in zip(levels, isoprop): + assert (x[x <= h].sum() / x.sum()) == pytest.approx(p, abs=1e-4) + + def test_input_checking(self, long_df): + + with pytest.raises(TypeError, match="The x variable is categorical,"): + kdeplot(data=long_df, x="a", y="y") + + +class TestHistPlotUnivariate(SharedAxesLevelTests): + + func = staticmethod(histplot) + + def get_last_color(self, ax, element="bars", fill=True): + + if element == "bars": + if fill: + return ax.patches[-1].get_facecolor() + else: + return 
ax.patches[-1].get_edgecolor() + else: + if fill: + artist = ax.collections[-1] + facecolor = artist.get_facecolor() + edgecolor = artist.get_edgecolor() + assert_colors_equal(facecolor, edgecolor, check_alpha=False) + return facecolor + else: + return ax.lines[-1].get_color() + + @pytest.mark.parametrize( + "element,fill", + itertools.product(["bars", "step", "poly"], [True, False]), + ) + def test_color(self, long_df, element, fill): + + super().test_color(long_df, element=element, fill=fill) + + @pytest.mark.parametrize( + "variable", ["x", "y"], + ) + def test_long_vectors(self, long_df, variable): + + vector = long_df[variable] + vectors = [ + variable, vector, vector.to_numpy(), vector.to_list(), + ] + + f, axs = plt.subplots(3) + for vector, ax in zip(vectors, axs): + histplot(data=long_df, ax=ax, **{variable: vector}) + + bars = [ax.patches for ax in axs] + for a_bars, b_bars in itertools.product(bars, bars): + for a, b in zip(a_bars, b_bars): + assert_array_equal(a.get_height(), b.get_height()) + assert_array_equal(a.get_xy(), b.get_xy()) + + def test_wide_vs_long_data(self, wide_df): + + f, (ax1, ax2) = plt.subplots(2) + + histplot(data=wide_df, ax=ax1, common_bins=False) + + for col in wide_df.columns[::-1]: + histplot(data=wide_df, x=col, ax=ax2) + + for a, b in zip(ax1.patches, ax2.patches): + assert a.get_height() == b.get_height() + assert a.get_xy() == b.get_xy() + + def test_flat_vector(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + histplot(data=long_df["x"], ax=ax1) + histplot(data=long_df, x="x", ax=ax2) + + for a, b in zip(ax1.patches, ax2.patches): + assert a.get_height() == b.get_height() + assert a.get_xy() == b.get_xy() + + def test_empty_data(self): + + ax = histplot(x=[]) + assert not ax.patches + + def test_variable_assignment(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + histplot(data=long_df, x="x", ax=ax1) + histplot(data=long_df, y="x", ax=ax2) + + for a, b in zip(ax1.patches, ax2.patches): + assert a.get_height() == b.get_width() + + @pytest.mark.parametrize("element", ["bars", "step", "poly"]) + @pytest.mark.parametrize("multiple", ["layer", "dodge", "stack", "fill"]) + def test_hue_fill_colors(self, long_df, multiple, element): + + ax = histplot( + data=long_df, x="x", hue="a", + multiple=multiple, bins=1, + fill=True, element=element, legend=False, + ) + + palette = color_palette() + + if multiple == "layer": + if element == "bars": + a = .5 + else: + a = .25 + else: + a = .75 + + for bar, color in zip(ax.patches[::-1], palette): + assert_colors_equal(bar.get_facecolor(), to_rgba(color, a)) + + for poly, color in zip(ax.collections[::-1], palette): + assert_colors_equal(poly.get_facecolor(), to_rgba(color, a)) + + def test_hue_stack(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + n = 10 + + kws = dict(data=long_df, x="x", hue="a", bins=n, element="bars") + + histplot(**kws, multiple="layer", ax=ax1) + histplot(**kws, multiple="stack", ax=ax2) + + layer_heights = np.reshape([b.get_height() for b in ax1.patches], (-1, n)) + stack_heights = np.reshape([b.get_height() for b in ax2.patches], (-1, n)) + assert_array_equal(layer_heights, stack_heights) + + stack_xys = np.reshape([b.get_xy() for b in ax2.patches], (-1, n, 2)) + assert_array_equal( + stack_xys[..., 1] + stack_heights, + stack_heights.cumsum(axis=0), + ) + + def test_hue_fill(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + n = 10 + + kws = dict(data=long_df, x="x", hue="a", bins=n, element="bars") + + histplot(**kws, multiple="layer", ax=ax1) + histplot(**kws, 
multiple="fill", ax=ax2) + + layer_heights = np.reshape([b.get_height() for b in ax1.patches], (-1, n)) + stack_heights = np.reshape([b.get_height() for b in ax2.patches], (-1, n)) + assert_array_almost_equal( + layer_heights / layer_heights.sum(axis=0), stack_heights + ) + + stack_xys = np.reshape([b.get_xy() for b in ax2.patches], (-1, n, 2)) + assert_array_almost_equal( + (stack_xys[..., 1] + stack_heights) / stack_heights.sum(axis=0), + stack_heights.cumsum(axis=0), + ) + + def test_hue_dodge(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + bw = 2 + + kws = dict(data=long_df, x="x", hue="c", binwidth=bw, element="bars") + + histplot(**kws, multiple="layer", ax=ax1) + histplot(**kws, multiple="dodge", ax=ax2) + + layer_heights = [b.get_height() for b in ax1.patches] + dodge_heights = [b.get_height() for b in ax2.patches] + assert_array_equal(layer_heights, dodge_heights) + + layer_xs = np.reshape([b.get_x() for b in ax1.patches], (2, -1)) + dodge_xs = np.reshape([b.get_x() for b in ax2.patches], (2, -1)) + assert_array_almost_equal(layer_xs[1], dodge_xs[1]) + assert_array_almost_equal(layer_xs[0], dodge_xs[0] - bw / 2) + + def test_hue_as_numpy_dodged(self, long_df): + # https://github.com/mwaskom/seaborn/issues/2452 + + ax = histplot( + long_df, + x="y", hue=long_df["a"].to_numpy(), + multiple="dodge", bins=1, + ) + # Note hue order reversal + assert ax.patches[1].get_x() < ax.patches[0].get_x() + + def test_multiple_input_check(self, flat_series): + + with pytest.raises(ValueError, match="`multiple` must be"): + histplot(flat_series, multiple="invalid") + + def test_element_input_check(self, flat_series): + + with pytest.raises(ValueError, match="`element` must be"): + histplot(flat_series, element="invalid") + + def test_count_stat(self, flat_series): + + ax = histplot(flat_series, stat="count") + bar_heights = [b.get_height() for b in ax.patches] + assert sum(bar_heights) == len(flat_series) + + def test_density_stat(self, flat_series): + + ax = histplot(flat_series, stat="density") + bar_heights = [b.get_height() for b in ax.patches] + bar_widths = [b.get_width() for b in ax.patches] + assert np.multiply(bar_heights, bar_widths).sum() == pytest.approx(1) + + def test_density_stat_common_norm(self, long_df): + + ax = histplot( + data=long_df, x="x", hue="a", + stat="density", common_norm=True, element="bars", + ) + bar_heights = [b.get_height() for b in ax.patches] + bar_widths = [b.get_width() for b in ax.patches] + assert np.multiply(bar_heights, bar_widths).sum() == pytest.approx(1) + + def test_density_stat_unique_norm(self, long_df): + + n = 10 + ax = histplot( + data=long_df, x="x", hue="a", + stat="density", bins=n, common_norm=False, element="bars", + ) + + bar_groups = ax.patches[:n], ax.patches[-n:] + + for bars in bar_groups: + bar_heights = [b.get_height() for b in bars] + bar_widths = [b.get_width() for b in bars] + bar_areas = np.multiply(bar_heights, bar_widths) + assert bar_areas.sum() == pytest.approx(1) + + def test_probability_stat(self, flat_series): + + ax = histplot(flat_series, stat="probability") + bar_heights = [b.get_height() for b in ax.patches] + assert sum(bar_heights) == pytest.approx(1) + + def test_probability_stat_common_norm(self, long_df): + + ax = histplot( + data=long_df, x="x", hue="a", + stat="probability", common_norm=True, element="bars", + ) + bar_heights = [b.get_height() for b in ax.patches] + assert sum(bar_heights) == pytest.approx(1) + + def test_probability_stat_unique_norm(self, long_df): + + n = 10 + ax = histplot( + 
data=long_df, x="x", hue="a", + stat="probability", bins=n, common_norm=False, element="bars", + ) + + bar_groups = ax.patches[:n], ax.patches[-n:] + + for bars in bar_groups: + bar_heights = [b.get_height() for b in bars] + assert sum(bar_heights) == pytest.approx(1) + + def test_percent_stat(self, flat_series): + + ax = histplot(flat_series, stat="percent") + bar_heights = [b.get_height() for b in ax.patches] + assert sum(bar_heights) == 100 + + def test_common_bins(self, long_df): + + n = 10 + ax = histplot( + long_df, x="x", hue="a", common_bins=True, bins=n, element="bars", + ) + + bar_groups = ax.patches[:n], ax.patches[-n:] + assert_array_equal( + [b.get_xy() for b in bar_groups[0]], + [b.get_xy() for b in bar_groups[1]] + ) + + def test_unique_bins(self, wide_df): + + ax = histplot(wide_df, common_bins=False, bins=10, element="bars") + + bar_groups = np.split(np.array(ax.patches), len(wide_df.columns)) + + for i, col in enumerate(wide_df.columns[::-1]): + bars = bar_groups[i] + start = bars[0].get_x() + stop = bars[-1].get_x() + bars[-1].get_width() + assert start == wide_df[col].min() + assert stop == wide_df[col].max() + + def test_weights_with_missing(self, missing_df): + + ax = histplot(missing_df, x="x", weights="s", bins=5) + + bar_heights = [bar.get_height() for bar in ax.patches] + total_weight = missing_df[["x", "s"]].dropna()["s"].sum() + assert sum(bar_heights) == pytest.approx(total_weight) + + def test_discrete(self, long_df): + + ax = histplot(long_df, x="s", discrete=True) + + data_min = long_df["s"].min() + data_max = long_df["s"].max() + assert len(ax.patches) == (data_max - data_min + 1) + + for i, bar in enumerate(ax.patches): + assert bar.get_width() == 1 + assert bar.get_x() == (data_min + i - .5) + + def test_discrete_categorical_default(self, long_df): + + ax = histplot(long_df, x="a") + for i, bar in enumerate(ax.patches): + assert bar.get_width() == 1 + + def test_categorical_yaxis_inversion(self, long_df): + + ax = histplot(long_df, y="a") + ymax, ymin = ax.get_ylim() + assert ymax > ymin + + def test_discrete_requires_bars(self, long_df): + + with pytest.raises(ValueError, match="`element` must be 'bars'"): + histplot(long_df, x="s", discrete=True, element="poly") + + @pytest.mark.skipif( + LooseVersion(np.__version__) < "1.17", + reason="Histogram over datetime64 requires numpy >= 1.17", + ) + def test_datetime_scale(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + histplot(x=long_df["t"], fill=True, ax=ax1) + histplot(x=long_df["t"], fill=False, ax=ax2) + assert ax1.get_xlim() == ax2.get_xlim() + + @pytest.mark.parametrize("stat", ["count", "density", "probability"]) + def test_kde(self, flat_series, stat): + + ax = histplot( + flat_series, kde=True, stat=stat, kde_kws={"cut": 10} + ) + + bar_widths = [b.get_width() for b in ax.patches] + bar_heights = [b.get_height() for b in ax.patches] + hist_area = np.multiply(bar_widths, bar_heights).sum() + + density, = ax.lines + kde_area = integrate(density.get_ydata(), density.get_xdata()) + + assert kde_area == pytest.approx(hist_area) + + @pytest.mark.parametrize("multiple", ["layer", "dodge"]) + @pytest.mark.parametrize("stat", ["count", "density", "probability"]) + def test_kde_with_hue(self, long_df, stat, multiple): + + n = 10 + ax = histplot( + long_df, x="x", hue="c", multiple=multiple, + kde=True, stat=stat, element="bars", + kde_kws={"cut": 10}, bins=n, + ) + + bar_groups = ax.patches[:n], ax.patches[-n:] + + for i, bars in enumerate(bar_groups): + bar_widths = [b.get_width() for b in bars] + 
bar_heights = [b.get_height() for b in bars] + hist_area = np.multiply(bar_widths, bar_heights).sum() + + x, y = ax.lines[i].get_xydata().T + kde_area = integrate(y, x) + + if multiple == "layer": + assert kde_area == pytest.approx(hist_area) + elif multiple == "dodge": + assert kde_area == pytest.approx(hist_area * 2) + + def test_kde_default_cut(self, flat_series): + + ax = histplot(flat_series, kde=True) + support = ax.lines[0].get_xdata() + assert support.min() == flat_series.min() + assert support.max() == flat_series.max() + + def test_kde_hue(self, long_df): + + n = 10 + ax = histplot(data=long_df, x="x", hue="a", kde=True, bins=n) + + for bar, line in zip(ax.patches[::n], ax.lines): + assert_colors_equal( + bar.get_facecolor(), line.get_color(), check_alpha=False + ) + + def test_kde_yaxis(self, flat_series): + + f, ax = plt.subplots() + histplot(x=flat_series, kde=True) + histplot(y=flat_series, kde=True) + + x, y = ax.lines + assert_array_equal(x.get_xdata(), y.get_ydata()) + assert_array_equal(x.get_ydata(), y.get_xdata()) + + def test_kde_line_kws(self, flat_series): + + lw = 5 + ax = histplot(flat_series, kde=True, line_kws=dict(lw=lw)) + assert ax.lines[0].get_linewidth() == lw + + def test_kde_singular_data(self): + + with pytest.warns(UserWarning): + ax = histplot(x=np.ones(10), kde=True) + assert not ax.lines + + with pytest.warns(UserWarning): + ax = histplot(x=[5], kde=True) + assert not ax.lines + + def test_element_default(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + histplot(long_df, x="x", ax=ax1) + histplot(long_df, x="x", ax=ax2, element="bars") + assert len(ax1.patches) == len(ax2.patches) + + f, (ax1, ax2) = plt.subplots(2) + histplot(long_df, x="x", hue="a", ax=ax1) + histplot(long_df, x="x", hue="a", ax=ax2, element="bars") + assert len(ax1.patches) == len(ax2.patches) + + def test_bars_no_fill(self, flat_series): + + alpha = .5 + ax = histplot(flat_series, element="bars", fill=False, alpha=alpha) + for bar in ax.patches: + assert bar.get_facecolor() == (0, 0, 0, 0) + assert bar.get_edgecolor()[-1] == alpha + + def test_step_fill(self, flat_series): + + f, (ax1, ax2) = plt.subplots(2) + + n = 10 + histplot(flat_series, element="bars", fill=True, bins=n, ax=ax1) + histplot(flat_series, element="step", fill=True, bins=n, ax=ax2) + + bar_heights = [b.get_height() for b in ax1.patches] + bar_widths = [b.get_width() for b in ax1.patches] + bar_edges = [b.get_x() for b in ax1.patches] + + fill = ax2.collections[0] + x, y = fill.get_paths()[0].vertices[::-1].T + + assert_array_equal(x[1:2 * n:2], bar_edges) + assert_array_equal(y[1:2 * n:2], bar_heights) + + assert x[n * 2] == bar_edges[-1] + bar_widths[-1] + assert y[n * 2] == bar_heights[-1] + + def test_poly_fill(self, flat_series): + + f, (ax1, ax2) = plt.subplots(2) + + n = 10 + histplot(flat_series, element="bars", fill=True, bins=n, ax=ax1) + histplot(flat_series, element="poly", fill=True, bins=n, ax=ax2) + + bar_heights = np.array([b.get_height() for b in ax1.patches]) + bar_widths = np.array([b.get_width() for b in ax1.patches]) + bar_edges = np.array([b.get_x() for b in ax1.patches]) + + fill = ax2.collections[0] + x, y = fill.get_paths()[0].vertices[::-1].T + + assert_array_equal(x[1:n + 1], bar_edges + bar_widths / 2) + assert_array_equal(y[1:n + 1], bar_heights) + + def test_poly_no_fill(self, flat_series): + + f, (ax1, ax2) = plt.subplots(2) + + n = 10 + histplot(flat_series, element="bars", fill=False, bins=n, ax=ax1) + histplot(flat_series, element="poly", fill=False, bins=n, ax=ax2) + + 
bar_heights = np.array([b.get_height() for b in ax1.patches]) + bar_widths = np.array([b.get_width() for b in ax1.patches]) + bar_edges = np.array([b.get_x() for b in ax1.patches]) + + x, y = ax2.lines[0].get_xydata().T + + assert_array_equal(x, bar_edges + bar_widths / 2) + assert_array_equal(y, bar_heights) + + def test_step_no_fill(self, flat_series): + + f, (ax1, ax2) = plt.subplots(2) + + histplot(flat_series, element="bars", fill=False, ax=ax1) + histplot(flat_series, element="step", fill=False, ax=ax2) + + bar_heights = [b.get_height() for b in ax1.patches] + bar_widths = [b.get_width() for b in ax1.patches] + bar_edges = [b.get_x() for b in ax1.patches] + + x, y = ax2.lines[0].get_xydata().T + + assert_array_equal(x[:-1], bar_edges) + assert_array_equal(y[:-1], bar_heights) + assert x[-1] == bar_edges[-1] + bar_widths[-1] + assert y[-1] == y[-2] + + def test_step_fill_xy(self, flat_series): + + f, ax = plt.subplots() + + histplot(x=flat_series, element="step", fill=True) + histplot(y=flat_series, element="step", fill=True) + + xverts = ax.collections[0].get_paths()[0].vertices + yverts = ax.collections[1].get_paths()[0].vertices + + assert_array_equal(xverts, yverts[:, ::-1]) + + def test_step_no_fill_xy(self, flat_series): + + f, ax = plt.subplots() + + histplot(x=flat_series, element="step", fill=False) + histplot(y=flat_series, element="step", fill=False) + + xline, yline = ax.lines + + assert_array_equal(xline.get_xdata(), yline.get_ydata()) + assert_array_equal(xline.get_ydata(), yline.get_xdata()) + + def test_weighted_histogram(self): + + ax = histplot(x=[0, 1, 2], weights=[1, 2, 3], discrete=True) + + bar_heights = [b.get_height() for b in ax.patches] + assert bar_heights == [1, 2, 3] + + def test_weights_with_auto_bins(self, long_df): + + with pytest.warns(UserWarning): + ax = histplot(long_df, x="x", weights="f") + assert len(ax.patches) == 10 + + def test_shrink(self, long_df): + + f, (ax1, ax2) = plt.subplots(2) + + bw = 2 + shrink = .4 + + histplot(long_df, x="x", binwidth=bw, ax=ax1) + histplot(long_df, x="x", binwidth=bw, shrink=shrink, ax=ax2) + + for p1, p2 in zip(ax1.patches, ax2.patches): + + w1, w2 = p1.get_width(), p2.get_width() + assert w2 == pytest.approx(shrink * w1) + + x1, x2 = p1.get_x(), p2.get_x() + assert (x2 + w2 / 2) == pytest.approx(x1 + w1 / 2) + + def test_log_scale_explicit(self, rng): + + x = rng.lognormal(0, 2, 1000) + ax = histplot(x, log_scale=True, binwidth=1) + + bar_widths = [b.get_width() for b in ax.patches] + steps = np.divide(bar_widths[1:], bar_widths[:-1]) + assert np.allclose(steps, 10) + + def test_log_scale_implicit(self, rng): + + x = rng.lognormal(0, 2, 1000) + + f, ax = plt.subplots() + ax.set_xscale("log") + histplot(x, binwidth=1, ax=ax) + + bar_widths = [b.get_width() for b in ax.patches] + steps = np.divide(bar_widths[1:], bar_widths[:-1]) + assert np.allclose(steps, 10) + + @pytest.mark.parametrize( + "fill", [True, False], + ) + def test_auto_linewidth(self, flat_series, fill): + + get_lw = lambda ax: ax.patches[0].get_linewidth() # noqa: E731 + + kws = dict(element="bars", fill=fill) + + f, (ax1, ax2) = plt.subplots(2) + histplot(flat_series, **kws, bins=10, ax=ax1) + histplot(flat_series, **kws, bins=100, ax=ax2) + assert get_lw(ax1) > get_lw(ax2) + + f, ax1 = plt.subplots(figsize=(10, 5)) + f, ax2 = plt.subplots(figsize=(2, 5)) + histplot(flat_series, **kws, bins=30, ax=ax1) + histplot(flat_series, **kws, bins=30, ax=ax2) + assert get_lw(ax1) > get_lw(ax2) + + f, ax1 = plt.subplots(figsize=(4, 5)) + f, ax2 = 
plt.subplots(figsize=(4, 5)) + histplot(flat_series, **kws, bins=30, ax=ax1) + histplot(10 ** flat_series, **kws, bins=30, log_scale=True, ax=ax2) + assert get_lw(ax1) == pytest.approx(get_lw(ax2)) + + f, ax1 = plt.subplots(figsize=(4, 5)) + f, ax2 = plt.subplots(figsize=(4, 5)) + histplot(y=[0, 1, 1], **kws, discrete=True, ax=ax1) + histplot(y=["a", "b", "b"], **kws, ax=ax2) + assert get_lw(ax1) == pytest.approx(get_lw(ax2)) + + def test_bar_kwargs(self, flat_series): + + lw = 2 + ec = (1, .2, .9, .5) + ax = histplot(flat_series, binwidth=1, ec=ec, lw=lw) + for bar in ax.patches: + assert_colors_equal(bar.get_edgecolor(), ec) + assert bar.get_linewidth() == lw + + def test_step_fill_kwargs(self, flat_series): + + lw = 2 + ec = (1, .2, .9, .5) + ax = histplot(flat_series, element="step", ec=ec, lw=lw) + poly = ax.collections[0] + assert_colors_equal(poly.get_edgecolor(), ec) + assert poly.get_linewidth() == lw + + def test_step_line_kwargs(self, flat_series): + + lw = 2 + ls = "--" + ax = histplot(flat_series, element="step", fill=False, lw=lw, ls=ls) + line = ax.lines[0] + assert line.get_linewidth() == lw + assert line.get_linestyle() == ls + + +class TestHistPlotBivariate: + + def test_mesh(self, long_df): + + hist = Histogram() + counts, (x_edges, y_edges) = hist(long_df["x"], long_df["y"]) + + ax = histplot(long_df, x="x", y="y") + mesh = ax.collections[0] + mesh_data = mesh.get_array() + + assert_array_equal(mesh_data.data, counts.T.flat) + assert_array_equal(mesh_data.mask, counts.T.flat == 0) + + edges = itertools.product(y_edges[:-1], x_edges[:-1]) + for i, (y, x) in enumerate(edges): + path = mesh.get_paths()[i] + assert path.vertices[0, 0] == x + assert path.vertices[0, 1] == y + + def test_mesh_with_hue(self, long_df): + + ax = histplot(long_df, x="x", y="y", hue="c") + + hist = Histogram() + hist.define_bin_edges(long_df["x"], long_df["y"]) + + for i, sub_df in long_df.groupby("c"): + + mesh = ax.collections[i] + mesh_data = mesh.get_array() + + counts, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"]) + + assert_array_equal(mesh_data.data, counts.T.flat) + assert_array_equal(mesh_data.mask, counts.T.flat == 0) + + edges = itertools.product(y_edges[:-1], x_edges[:-1]) + for i, (y, x) in enumerate(edges): + path = mesh.get_paths()[i] + assert path.vertices[0, 0] == x + assert path.vertices[0, 1] == y + + def test_mesh_with_hue_unique_bins(self, long_df): + + ax = histplot(long_df, x="x", y="y", hue="c", common_bins=False) + + for i, sub_df in long_df.groupby("c"): + + hist = Histogram() + + mesh = ax.collections[i] + mesh_data = mesh.get_array() + + counts, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"]) + + assert_array_equal(mesh_data.data, counts.T.flat) + assert_array_equal(mesh_data.mask, counts.T.flat == 0) + + edges = itertools.product(y_edges[:-1], x_edges[:-1]) + for i, (y, x) in enumerate(edges): + path = mesh.get_paths()[i] + assert path.vertices[0, 0] == x + assert path.vertices[0, 1] == y + + def test_mesh_log_scale(self, rng): + + x, y = rng.lognormal(0, 1, (2, 1000)) + hist = Histogram() + counts, (x_edges, y_edges) = hist(np.log10(x), np.log10(y)) + + ax = histplot(x=x, y=y, log_scale=True) + mesh = ax.collections[0] + mesh_data = mesh.get_array() + + assert_array_equal(mesh_data.data, counts.T.flat) + + edges = itertools.product(y_edges[:-1], x_edges[:-1]) + for i, (y_i, x_i) in enumerate(edges): + path = mesh.get_paths()[i] + assert path.vertices[0, 0] == 10 ** x_i + assert path.vertices[0, 1] == 10 ** y_i + + def test_mesh_thresh(self, long_df): + + 
hist = Histogram() + counts, (x_edges, y_edges) = hist(long_df["x"], long_df["y"]) + + thresh = 5 + ax = histplot(long_df, x="x", y="y", thresh=thresh) + mesh = ax.collections[0] + mesh_data = mesh.get_array() + + assert_array_equal(mesh_data.data, counts.T.flat) + assert_array_equal(mesh_data.mask, (counts <= thresh).T.flat) + + def test_mesh_sticky_edges(self, long_df): + + ax = histplot(long_df, x="x", y="y", thresh=None) + mesh = ax.collections[0] + assert mesh.sticky_edges.x == [long_df["x"].min(), long_df["x"].max()] + assert mesh.sticky_edges.y == [long_df["y"].min(), long_df["y"].max()] + + ax.clear() + ax = histplot(long_df, x="x", y="y") + mesh = ax.collections[0] + assert not mesh.sticky_edges.x + assert not mesh.sticky_edges.y + + def test_mesh_common_norm(self, long_df): + + stat = "density" + ax = histplot( + long_df, x="x", y="y", hue="c", common_norm=True, stat=stat, + ) + + hist = Histogram(stat="density") + hist.define_bin_edges(long_df["x"], long_df["y"]) + + for i, sub_df in long_df.groupby("c"): + + mesh = ax.collections[i] + mesh_data = mesh.get_array() + + density, (x_edges, y_edges) = hist(sub_df["x"], sub_df["y"]) + + scale = len(sub_df) / len(long_df) + assert_array_equal(mesh_data.data, (density * scale).T.flat) + + def test_mesh_unique_norm(self, long_df): + + stat = "density" + ax = histplot( + long_df, x="x", y="y", hue="c", common_norm=False, stat=stat, + ) + + hist = Histogram() + hist.define_bin_edges(long_df["x"], long_df["y"]) + + for i, sub_df in long_df.groupby("c"): + + sub_hist = Histogram(bins=hist.bin_edges, stat=stat) + + mesh = ax.collections[i] + mesh_data = mesh.get_array() + + density, (x_edges, y_edges) = sub_hist(sub_df["x"], sub_df["y"]) + assert_array_equal(mesh_data.data, density.T.flat) + + @pytest.mark.parametrize("stat", ["probability", "percent"]) + def test_mesh_normalization(self, long_df, stat): + + ax = histplot( + long_df, x="x", y="y", stat=stat, + ) + + mesh_data = ax.collections[0].get_array() + expected_sum = {"probability": 1, "percent": 100}[stat] + assert mesh_data.data.sum() == expected_sum + + def test_mesh_colors(self, long_df): + + color = "r" + f, ax = plt.subplots() + histplot( + long_df, x="x", y="y", color=color, + ) + mesh = ax.collections[0] + assert_array_equal( + mesh.get_cmap().colors, + _DistributionPlotter()._cmap_from_color(color).colors, + ) + + f, ax = plt.subplots() + histplot( + long_df, x="x", y="y", hue="c", + ) + colors = color_palette() + for i, mesh in enumerate(ax.collections): + assert_array_equal( + mesh.get_cmap().colors, + _DistributionPlotter()._cmap_from_color(colors[i]).colors, + ) + + def test_color_limits(self, long_df): + + f, (ax1, ax2, ax3) = plt.subplots(3) + kws = dict(data=long_df, x="x", y="y") + hist = Histogram() + counts, _ = hist(long_df["x"], long_df["y"]) + + histplot(**kws, ax=ax1) + assert ax1.collections[0].get_clim() == (0, counts.max()) + + vmax = 10 + histplot(**kws, vmax=vmax, ax=ax2) + counts, _ = hist(long_df["x"], long_df["y"]) + assert ax2.collections[0].get_clim() == (0, vmax) + + pmax = .8 + pthresh = .1 + f = _DistributionPlotter()._quantile_to_level + + histplot(**kws, pmax=pmax, pthresh=pthresh, ax=ax3) + counts, _ = hist(long_df["x"], long_df["y"]) + mesh = ax3.collections[0] + assert mesh.get_clim() == (0, f(counts, pmax)) + assert_array_equal( + mesh.get_array().mask, + (counts <= f(counts, pthresh)).T.flat, + ) + + def test_hue_color_limits(self, long_df): + + _, (ax1, ax2, ax3, ax4) = plt.subplots(4) + kws = dict(data=long_df, x="x", y="y", hue="c", 
bins=4) + + hist = Histogram(bins=kws["bins"]) + hist.define_bin_edges(long_df["x"], long_df["y"]) + full_counts, _ = hist(long_df["x"], long_df["y"]) + + sub_counts = [] + for _, sub_df in long_df.groupby(kws["hue"]): + c, _ = hist(sub_df["x"], sub_df["y"]) + sub_counts.append(c) + + pmax = .8 + pthresh = .05 + f = _DistributionPlotter()._quantile_to_level + + histplot(**kws, common_norm=True, ax=ax1) + for i, mesh in enumerate(ax1.collections): + assert mesh.get_clim() == (0, full_counts.max()) + + histplot(**kws, common_norm=False, ax=ax2) + for i, mesh in enumerate(ax2.collections): + assert mesh.get_clim() == (0, sub_counts[i].max()) + + histplot(**kws, common_norm=True, pmax=pmax, pthresh=pthresh, ax=ax3) + for i, mesh in enumerate(ax3.collections): + assert mesh.get_clim() == (0, f(full_counts, pmax)) + assert_array_equal( + mesh.get_array().mask, + (sub_counts[i] <= f(full_counts, pthresh)).T.flat, + ) + + histplot(**kws, common_norm=False, pmax=pmax, pthresh=pthresh, ax=ax4) + for i, mesh in enumerate(ax4.collections): + assert mesh.get_clim() == (0, f(sub_counts[i], pmax)) + assert_array_equal( + mesh.get_array().mask, + (sub_counts[i] <= f(sub_counts[i], pthresh)).T.flat, + ) + + def test_colorbar(self, long_df): + + f, ax = plt.subplots() + histplot(long_df, x="x", y="y", cbar=True, ax=ax) + assert len(ax.figure.axes) == 2 + + f, (ax, cax) = plt.subplots(2) + histplot(long_df, x="x", y="y", cbar=True, cbar_ax=cax, ax=ax) + assert len(ax.figure.axes) == 2 + + +class TestECDFPlotUnivariate(SharedAxesLevelTests): + + func = staticmethod(ecdfplot) + + def get_last_color(self, ax): + + return to_rgb(ax.lines[-1].get_color()) + + @pytest.mark.parametrize("variable", ["x", "y"]) + def test_long_vectors(self, long_df, variable): + + vector = long_df[variable] + vectors = [ + variable, vector, vector.to_numpy(), vector.to_list(), + ] + + f, ax = plt.subplots() + for vector in vectors: + ecdfplot(data=long_df, ax=ax, **{variable: vector}) + + xdata = [l.get_xdata() for l in ax.lines] + for a, b in itertools.product(xdata, xdata): + assert_array_equal(a, b) + + ydata = [l.get_ydata() for l in ax.lines] + for a, b in itertools.product(ydata, ydata): + assert_array_equal(a, b) + + def test_hue(self, long_df): + + ax = ecdfplot(long_df, x="x", hue="a") + + for line, color in zip(ax.lines[::-1], color_palette()): + assert_colors_equal(line.get_color(), color) + + def test_line_kwargs(self, long_df): + + color = "r" + ls = "--" + lw = 3 + ax = ecdfplot(long_df, x="x", color=color, ls=ls, lw=lw) + + for line in ax.lines: + assert_colors_equal(line.get_color(), color) + assert line.get_linestyle() == ls + assert line.get_linewidth() == lw + + @pytest.mark.parametrize("data_var", ["x", "y"]) + def test_drawstyle(self, flat_series, data_var): + + ax = ecdfplot(**{data_var: flat_series}) + drawstyles = dict(x="steps-post", y="steps-pre") + assert ax.lines[0].get_drawstyle() == drawstyles[data_var] + + @pytest.mark.parametrize( + "data_var,stat_var", [["x", "y"], ["y", "x"]], + ) + def test_proportion_limits(self, flat_series, data_var, stat_var): + + ax = ecdfplot(**{data_var: flat_series}) + data = getattr(ax.lines[0], f"get_{stat_var}data")() + assert data[0] == 0 + assert data[-1] == 1 + sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var) + assert sticky_edges[:] == [0, 1] + + @pytest.mark.parametrize( + "data_var,stat_var", [["x", "y"], ["y", "x"]], + ) + def test_proportion_limits_complementary(self, flat_series, data_var, stat_var): + + ax = ecdfplot(**{data_var: flat_series}, 
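complementary=True)
+        # complementary=True plots 1 - ECDF (the survival function), so the
+        # curve should start at 1 and fall to 0.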
+        data = getattr(ax.lines[0], f"get_{stat_var}data")()
+        assert data[0] == 1
+        assert data[-1] == 0
+        sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var)
+        assert sticky_edges[:] == [0, 1]
+
+    @pytest.mark.parametrize(
+        "data_var,stat_var", [["x", "y"], ["y", "x"]],
+    )
+    def test_proportion_count(self, flat_series, data_var, stat_var):
+
+        n = len(flat_series)
+        ax = ecdfplot(**{data_var: flat_series}, stat="count")
+        data = getattr(ax.lines[0], f"get_{stat_var}data")()
+        assert data[0] == 0
+        assert data[-1] == n
+        sticky_edges = getattr(ax.lines[0].sticky_edges, stat_var)
+        assert sticky_edges[:] == [0, n]
+
+    def test_weights(self):
+
+        ax = ecdfplot(x=[1, 2, 3], weights=[1, 1, 2])
+        y = ax.lines[0].get_ydata()
+        assert_array_equal(y, [0, .25, .5, 1])
+
+    def test_bivariate_error(self, long_df):
+
+        with pytest.raises(NotImplementedError, match="Bivariate ECDF plots"):
+            ecdfplot(data=long_df, x="x", y="y")
+
+    def test_log_scale(self, long_df):
+
+        ax1, ax2 = plt.figure().subplots(2)
+
+        ecdfplot(data=long_df, x="z", ax=ax1)
+        ecdfplot(data=long_df, x="z", log_scale=True, ax=ax2)
+
+        # Ignore the first point, which is either -inf (linear) or 0 (log)
+        line1 = ax1.lines[0].get_xydata()[1:]
+        line2 = ax2.lines[0].get_xydata()[1:]
+
+        assert_array_almost_equal(line1, line2)
+
+
+class TestDisPlot:
+
+    # TODO probably good to move these utility attributes/methods somewhere else
+    @pytest.mark.parametrize(
+        "kwargs", [
+            dict(),
+            dict(x="x"),
+            dict(x="t"),
+            dict(x="a"),
+            dict(x="z", log_scale=True),
+            dict(x="x", binwidth=4),
+            dict(x="x", weights="f", bins=5),
+            dict(x="x", color="green", linewidth=2, binwidth=4),
+            dict(x="x", hue="a", fill=False),
+            dict(x="y", hue="a", fill=False),
+            dict(x="x", hue="a", multiple="stack"),
+            dict(x="x", hue="a", element="step"),
+            dict(x="x", hue="a", palette="muted"),
+            dict(x="x", hue="a", kde=True),
+            dict(x="x", hue="a", stat="density", common_norm=False),
+            dict(x="x", y="y"),
+        ],
+    )
+    def test_versus_single_histplot(self, long_df, kwargs):
+
+        ax = histplot(long_df, **kwargs)
+        g = displot(long_df, **kwargs)
+        assert_plots_equal(ax, g.ax)
+
+        if ax.legend_ is not None:
+            assert_legends_equal(ax.legend_, g._legend)
+
+        if kwargs:
+            long_df["_"] = "_"
+            g2 = displot(long_df, col="_", **kwargs)
+            assert_plots_equal(ax, g2.ax)
+
+    @pytest.mark.parametrize(
+        "kwargs", [
+            dict(),
+            dict(x="x"),
+            dict(x="t"),
+            dict(x="z", log_scale=True),
+            dict(x="x", bw_adjust=.5),
+            dict(x="x", weights="f"),
+            dict(x="x", color="green", linewidth=2),
+            dict(x="x", hue="a", multiple="stack"),
+            dict(x="x", hue="a", fill=True),
+            dict(x="y", hue="a", fill=False),
+            dict(x="x", hue="a", palette="muted"),
+            dict(x="x", y="y"),
+        ],
+    )
+    def test_versus_single_kdeplot(self, long_df, kwargs):
+
+        ax = kdeplot(data=long_df, **kwargs)
+        g = displot(long_df, kind="kde", **kwargs)
+        assert_plots_equal(ax, g.ax)
+
+        if ax.legend_ is not None:
+            assert_legends_equal(ax.legend_, g._legend)
+
+        if kwargs:
+            long_df["_"] = "_"
+            g2 = displot(long_df, kind="kde", col="_", **kwargs)
+            assert_plots_equal(ax, g2.ax)
+
+    @pytest.mark.parametrize(
+        "kwargs", [
+            dict(),
+            dict(x="x"),
+            dict(x="t"),
+            dict(x="z", log_scale=True),
+            dict(x="x", weights="f"),
+            dict(y="x"),
+            dict(x="x", color="green", linewidth=2),
+            dict(x="x", hue="a", complementary=True),
+            dict(x="x", hue="a", stat="count"),
+            dict(x="x", hue="a", palette="muted"),
+        ],
+    )
+    def test_versus_single_ecdfplot(self, long_df, kwargs):
+
+        ax = ecdfplot(data=long_df, **kwargs)
+        g =
displot(long_df, kind="ecdf", **kwargs) + assert_plots_equal(ax, g.ax) + + if ax.legend_ is not None: + assert_legends_equal(ax.legend_, g._legend) + + if kwargs: + long_df["_"] = "_" + g2 = displot(long_df, kind="ecdf", col="_", **kwargs) + assert_plots_equal(ax, g2.ax) + + @pytest.mark.parametrize( + "kwargs", [ + dict(x="x"), + dict(x="x", y="y"), + dict(x="x", hue="a"), + ] + ) + def test_with_rug(self, long_df, kwargs): + + ax = plt.figure().subplots() + histplot(data=long_df, **kwargs, ax=ax) + rugplot(data=long_df, **kwargs, ax=ax) + + g = displot(long_df, rug=True, **kwargs) + + assert_plots_equal(ax, g.ax, labels=False) + + long_df["_"] = "_" + g2 = displot(long_df, col="_", rug=True, **kwargs) + + assert_plots_equal(ax, g2.ax, labels=False) + + @pytest.mark.parametrize( + "facet_var", ["col", "row"], + ) + def test_facets(self, long_df, facet_var): + + kwargs = {facet_var: "a"} + ax = kdeplot(data=long_df, x="x", hue="a") + g = displot(long_df, x="x", kind="kde", **kwargs) + + legend_texts = ax.legend_.get_texts() + + for i, line in enumerate(ax.lines[::-1]): + facet_ax = g.axes.flat[i] + facet_line = facet_ax.lines[0] + assert_array_equal(line.get_xydata(), facet_line.get_xydata()) + + text = legend_texts[i].get_text() + assert text in facet_ax.get_title() + + @pytest.mark.parametrize("multiple", ["dodge", "stack", "fill"]) + def test_facet_multiple(self, long_df, multiple): + + bins = np.linspace(0, 20, 5) + ax = histplot( + data=long_df[long_df["c"] == 0], + x="x", hue="a", hue_order=["a", "b", "c"], + multiple=multiple, bins=bins, + ) + + g = displot( + data=long_df, x="x", hue="a", col="c", hue_order=["a", "b", "c"], + multiple=multiple, bins=bins, + ) + + assert_plots_equal(ax, g.axes_dict[0]) + + def test_ax_warning(self, long_df): + + ax = plt.figure().subplots() + with pytest.warns(UserWarning, match="`displot` is a figure-level"): + displot(long_df, x="x", ax=ax) + + @pytest.mark.parametrize("key", ["col", "row"]) + def test_array_faceting(self, long_df, key): + + a = long_df["a"].to_numpy() + vals = categorical_order(a) + g = displot(long_df, x="x", **{key: a}) + assert len(g.axes.flat) == len(vals) + for ax, val in zip(g.axes.flat, vals): + assert val in ax.get_title() + + def test_legend(self, long_df): + + g = displot(long_df, x="x", hue="a") + assert g._legend is not None + + def test_empty(self): + + g = displot(x=[], y=[]) + assert isinstance(g, FacetGrid) + + def test_bivariate_ecdf_error(self, long_df): + + with pytest.raises(NotImplementedError): + displot(long_df, x="x", y="y", kind="ecdf") + + def test_bivariate_kde_norm(self, rng): + + x, y = rng.normal(0, 1, (2, 100)) + z = [0] * 80 + [1] * 20 + + g = displot(x=x, y=y, col=z, kind="kde", levels=10) + l1 = sum(bool(c.get_segments()) for c in g.axes.flat[0].collections) + l2 = sum(bool(c.get_segments()) for c in g.axes.flat[1].collections) + assert l1 > l2 + + g = displot(x=x, y=y, col=z, kind="kde", levels=10, common_norm=False) + l1 = sum(bool(c.get_segments()) for c in g.axes.flat[0].collections) + l2 = sum(bool(c.get_segments()) for c in g.axes.flat[1].collections) + assert l1 == l2 + + def test_bivariate_hist_norm(self, rng): + + x, y = rng.normal(0, 1, (2, 100)) + z = [0] * 80 + [1] * 20 + + g = displot(x=x, y=y, col=z, kind="hist") + clim1 = g.axes.flat[0].collections[0].get_clim() + clim2 = g.axes.flat[1].collections[0].get_clim() + assert clim1 == clim2 + + g = displot(x=x, y=y, col=z, kind="hist", common_norm=False) + clim1 = g.axes.flat[0].collections[0].get_clim() + clim2 = 
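The bivariate normalization tests here rely on common_norm rescaling each facet's density by its share of the data; a hedged usage sketch of the behavior under test (assumes seaborn is importable as sns):

import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
x, y = rng.normal(0, 1, (2, 100))
z = [0] * 80 + [1] * 20

# With the default common_norm=True, the smaller facet's density surface is
# scaled down; common_norm=False normalizes each facet independently.
g = sns.displot(x=x, y=y, col=z, kind="kde", common_norm=False)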
g.axes.flat[1].collections[0].get_clim() + assert clim1[1] > clim2[1] + + +def integrate(y, x): + """Simple numerical integration for testing KDE code.""" + y = np.asarray(y) + x = np.asarray(x) + dx = np.diff(x) + return (dx * y[:-1] + dx * y[1:]).sum() / 2 diff --git a/seaborn/tests/test_docstrings.py b/seaborn/tests/test_docstrings.py new file mode 100644 index 0000000000..ae78d9d5fb --- /dev/null +++ b/seaborn/tests/test_docstrings.py @@ -0,0 +1,58 @@ +from .._docstrings import DocstringComponents + + +EXAMPLE_DICT = dict( + param_a=""" +a : str + The first parameter. + """, +) + + +class ExampleClass: + def example_method(self): + """An example method. + + Parameters + ---------- + a : str + A method parameter. + + """ + + +def example_func(): + """An example function. + + Parameters + ---------- + a : str + A function parameter. + + """ + + +class TestDocstringComponents: + + def test_from_dict(self): + + obj = DocstringComponents(EXAMPLE_DICT) + assert obj.param_a == "a : str\n The first parameter." + + def test_from_nested_components(self): + + obj_inner = DocstringComponents(EXAMPLE_DICT) + obj_outer = DocstringComponents.from_nested_components(inner=obj_inner) + assert obj_outer.inner.param_a == "a : str\n The first parameter." + + def test_from_function(self): + + obj = DocstringComponents.from_function_params(example_func) + assert obj.a == "a : str\n A function parameter." + + def test_from_method(self): + + obj = DocstringComponents.from_function_params( + ExampleClass.example_method + ) + assert obj.a == "a : str\n A method parameter." diff --git a/seaborn/tests/test_matrix.py b/seaborn/tests/test_matrix.py index d3cf890615..74a2e2c5af 100644 --- a/seaborn/tests/test_matrix.py +++ b/seaborn/tests/test_matrix.py @@ -1,35 +1,38 @@ -import itertools import tempfile +import copy import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt import pandas as pd -from scipy.spatial import distance -from scipy.cluster import hierarchy -import nose.tools as nt -import numpy.testing as npt try: - import pandas.testing as pdt + from scipy.spatial import distance + from scipy.cluster import hierarchy + _no_scipy = False except ImportError: - import pandas.util.testing as pdt -import pytest - -from .. import matrix as mat -from .. import color_palette -from ..external.six.moves import range + _no_scipy = True try: import fastcluster - assert fastcluster _no_fastcluster = False except ImportError: _no_fastcluster = True +import numpy.testing as npt +try: + import pandas.testing as pdt +except ImportError: + import pandas.util.testing as pdt +import pytest + +from .. import matrix as mat +from ..
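The integrate helper above is a plain trapezoid rule, so it can be sanity-checked against numpy's built-in implementation; a small sketch (the inputs are illustrative):

import numpy as np

x = np.linspace(0, 1, 101)
y = x ** 2
dx = np.diff(x)
result = (dx * y[:-1] + dx * y[1:]).sum() / 2  # same arithmetic as integrate()
# Both approximate 1/3, the exact value of the integral of x**2 on [0, 1]
assert np.isclose(result, np.trapz(y, x))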
import color_palette +from .._testing import assert_colors_equal + -class TestHeatmap(object): +class TestHeatmap: rs = np.random.RandomState(sum(map(ord, "heatmap"))) x_norm = rs.randn(4, 8) @@ -52,8 +55,8 @@ def test_ndarray_input(self): npt.assert_array_equal(p.xticklabels, np.arange(8)) npt.assert_array_equal(p.yticklabels, np.arange(4)) - nt.assert_equal(p.xlabel, "") - nt.assert_equal(p.ylabel, "") + assert p.xlabel == "" + assert p.ylabel == "" def test_df_input(self): @@ -64,8 +67,8 @@ def test_df_input(self): npt.assert_array_equal(p.xticklabels, np.arange(8)) npt.assert_array_equal(p.yticklabels, self.letters.values) - nt.assert_equal(p.xlabel, "") - nt.assert_equal(p.ylabel, "letters") + assert p.xlabel == "" + assert p.ylabel == "letters" def test_df_multindex_input(self): @@ -80,28 +83,49 @@ def test_df_multindex_input(self): combined_tick_labels = ["A-1", "B-2", "C-3", "D-4"] npt.assert_array_equal(p.yticklabels, combined_tick_labels) - nt.assert_equal(p.ylabel, "letter-number") + assert p.ylabel == "letter-number" p = mat._HeatMapper(df.T, **self.default_kws) npt.assert_array_equal(p.xticklabels, combined_tick_labels) - nt.assert_equal(p.xlabel, "letter-number") + assert p.xlabel == "letter-number" - def test_mask_input(self): + @pytest.mark.parametrize("dtype", [float, np.int64, object]) + def test_mask_input(self, dtype): kws = self.default_kws.copy() mask = self.x_norm > 0 kws['mask'] = mask - p = mat._HeatMapper(self.x_norm, **kws) - plot_data = np.ma.masked_where(mask, self.x_norm) + data = self.x_norm.astype(dtype) + p = mat._HeatMapper(data, **kws) + plot_data = np.ma.masked_where(mask, data) npt.assert_array_equal(p.plot_data, plot_data) + def test_mask_limits(self): + """Make sure masked cells are not used to calculate extremes""" + + kws = self.default_kws.copy() + + mask = self.x_norm > 0 + kws['mask'] = mask + p = mat._HeatMapper(self.x_norm, **kws) + + assert p.vmax == np.ma.array(self.x_norm, mask=mask).max() + assert p.vmin == np.ma.array(self.x_norm, mask=mask).min() + + mask = self.x_norm < 0 + kws['mask'] = mask + p = mat._HeatMapper(self.x_norm, **kws) + + assert p.vmin == np.ma.array(self.x_norm, mask=mask).min() + assert p.vmax == np.ma.array(self.x_norm, mask=mask).max() + def test_default_vlims(self): p = mat._HeatMapper(self.df_unif, **self.default_kws) - nt.assert_equal(p.vmin, self.x_unif.min()) - nt.assert_equal(p.vmax, self.x_unif.max()) + assert p.vmin == self.x_unif.min() + assert p.vmax == self.x_unif.max() def test_robust_vlims(self): @@ -109,8 +133,8 @@ def test_robust_vlims(self): kws["robust"] = True p = mat._HeatMapper(self.df_unif, **kws) - nt.assert_equal(p.vmin, np.percentile(self.x_unif, 2)) - nt.assert_equal(p.vmax, np.percentile(self.x_unif, 98)) + assert p.vmin == np.percentile(self.x_unif, 2) + assert p.vmax == np.percentile(self.x_unif, 98) def test_custom_sequential_vlims(self): @@ -119,8 +143,8 @@ def test_custom_sequential_vlims(self): kws["vmax"] = 1 p = mat._HeatMapper(self.df_unif, **kws) - nt.assert_equal(p.vmin, 0) - nt.assert_equal(p.vmax, 1) + assert p.vmin == 0 + assert p.vmax == 1 def test_custom_diverging_vlims(self): @@ -130,8 +154,8 @@ def test_custom_diverging_vlims(self): kws["center"] = 0 p = mat._HeatMapper(self.df_norm, **kws) - nt.assert_equal(p.vmin, -4) - nt.assert_equal(p.vmax, 5) + assert p.vmin == -4 + assert p.vmax == 5 def test_array_with_nans(self): @@ -142,8 +166,8 @@ def test_array_with_nans(self): m1 = mat._HeatMapper(x1, **self.default_kws) m2 = mat._HeatMapper(x2, **self.default_kws) - 
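test_mask_limits above leans on numpy masked arrays excluding masked cells from reductions, which is what keeps vmin/vmax tied to the visible data; in miniature (the values are illustrative):

import numpy as np

a = np.array([[1.0, 10.0], [2.0, 3.0]])
m = np.ma.array(a, mask=a > 5)
# Masked cells are ignored by the reductions, so the extremes come only
# from the unmasked values:
assert m.min() == 1.0
assert m.max() == 3.0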
nt.assert_equal(m1.vmin, m2.vmin) - nt.assert_equal(m1.vmax, m2.vmax) + assert m1.vmin == m2.vmin + assert m1.vmax == m2.vmax def test_mask(self): @@ -164,7 +188,7 @@ def test_custom_cmap(self): kws = self.default_kws.copy() kws["cmap"] = "BuGn" p = mat._HeatMapper(self.df_unif, **kws) - nt.assert_equal(p.cmap, mpl.cm.BuGn) + assert p.cmap == mpl.cm.BuGn def test_centered_vlims(self): @@ -173,8 +197,8 @@ def test_centered_vlims(self): p = mat._HeatMapper(self.df_unif, **kws) - nt.assert_equal(p.vmin, self.df_unif.values.min()) - nt.assert_equal(p.vmax, self.df_unif.values.max()) + assert p.vmin == self.df_unif.values.min() + assert p.vmax == self.df_unif.values.max() def test_default_colors(self): @@ -201,13 +225,52 @@ def test_custom_center_colors(self): fc = ax.collections[0].get_facecolors() npt.assert_array_almost_equal(fc, cmap(vals), 2) + def test_cmap_with_properties(self): + + kws = self.default_kws.copy() + cmap = copy.copy(mpl.cm.get_cmap("BrBG")) + cmap.set_bad("red") + kws["cmap"] = cmap + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal( + cmap(np.ma.masked_invalid([np.nan])), + hm.cmap(np.ma.masked_invalid([np.nan]))) + + kws["center"] = 0.5 + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal( + cmap(np.ma.masked_invalid([np.nan])), + hm.cmap(np.ma.masked_invalid([np.nan]))) + + kws = self.default_kws.copy() + cmap = copy.copy(mpl.cm.get_cmap("BrBG")) + cmap.set_under("red") + kws["cmap"] = cmap + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf)) + + kws["center"] = .5 + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf)) + + kws = self.default_kws.copy() + cmap = copy.copy(mpl.cm.get_cmap("BrBG")) + cmap.set_over("red") + kws["cmap"] = cmap + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf)) + + kws["center"] = .5 + hm = mat._HeatMapper(self.df_unif, **kws) + npt.assert_array_equal(cmap(np.inf), hm.cmap(np.inf)) + def test_tickabels_off(self): kws = self.default_kws.copy() kws['xticklabels'] = False kws['yticklabels'] = False p = mat._HeatMapper(self.df_norm, **kws) - nt.assert_equal(p.xticklabels, []) - nt.assert_equal(p.yticklabels, []) + assert p.xticklabels == [] + assert p.yticklabels == [] def test_custom_ticklabels(self): kws = self.default_kws.copy() @@ -216,8 +279,8 @@ def test_custom_ticklabels(self): kws['xticklabels'] = xticklabels kws['yticklabels'] = yticklabels p = mat._HeatMapper(self.df_norm, **kws) - nt.assert_equal(p.xticklabels, xticklabels) - nt.assert_equal(p.yticklabels, yticklabels) + assert p.xticklabels == xticklabels + assert p.yticklabels == yticklabels def test_custom_ticklabel_interval(self): @@ -240,8 +303,8 @@ def test_heatmap_annotation(self): ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f", annot_kws={"fontsize": 14}) for val, text in zip(self.x_norm.flat, ax.texts): - nt.assert_equal(text.get_text(), "{:.1f}".format(val)) - nt.assert_equal(text.get_fontsize(), 14) + assert text.get_text() == "{:.1f}".format(val) + assert text.get_fontsize() == 14 def test_heatmap_annotation_overwrite_kws(self): @@ -249,9 +312,9 @@ def test_heatmap_annotation_overwrite_kws(self): ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f", annot_kws=annot_kws) for text in ax.texts: - nt.assert_equal(text.get_color(), "0.3") - nt.assert_equal(text.get_ha(), "left") - nt.assert_equal(text.get_va(), "bottom") + assert text.get_color() == "0.3" + assert text.get_ha() == "left" + 
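test_cmap_with_properties above checks that user customizations of a colormap survive seaborn's handling; the matplotlib pattern being exercised, sketched on its own (assumes an mpl version where cm.get_cmap is available):

import copy

import matplotlib as mpl
import numpy as np

# Copy before mutating, since registered colormaps are shared objects:
cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
cmap.set_bad("red")  # color used for masked/NaN cells
rgba = cmap(np.ma.masked_invalid([np.nan]))  # now red, not the default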
assert text.get_va() == "bottom" def test_heatmap_annotation_with_mask(self): @@ -261,15 +324,15 @@ def test_heatmap_annotation_with_mask(self): mask = np.isnan(df.values) df_masked = np.ma.masked_where(mask, df) ax = mat.heatmap(df, annot=True, fmt='.1f', mask=mask) - nt.assert_equal(len(df_masked.compressed()), len(ax.texts)) + assert len(df_masked.compressed()) == len(ax.texts) for val, text in zip(df_masked.compressed(), ax.texts): - nt.assert_equal("{:.1f}".format(val), text.get_text()) + assert "{:.1f}".format(val) == text.get_text() def test_heatmap_annotation_mesh_colors(self): ax = mat.heatmap(self.df_norm, annot=True) mesh = ax.collections[0] - nt.assert_equal(len(mesh.get_facecolors()), self.df_norm.values.size) + assert len(mesh.get_facecolors()) == self.df_norm.values.size plt.close("all") @@ -280,30 +343,30 @@ def test_heatmap_annotation_other_data(self): annot_kws={"fontsize": 14}) for val, text in zip(annot_data.values.flat, ax.texts): - nt.assert_equal(text.get_text(), "{:.1f}".format(val)) - nt.assert_equal(text.get_fontsize(), 14) + assert text.get_text() == "{:.1f}".format(val) + assert text.get_fontsize() == 14 def test_heatmap_annotation_with_limited_ticklabels(self): ax = mat.heatmap(self.df_norm, fmt=".2f", annot=True, xticklabels=False, yticklabels=False) for val, text in zip(self.x_norm.flat, ax.texts): - nt.assert_equal(text.get_text(), "{:.2f}".format(val)) + assert text.get_text() == "{:.2f}".format(val) def test_heatmap_cbar(self): f = plt.figure() mat.heatmap(self.df_norm) - nt.assert_equal(len(f.axes), 2) + assert len(f.axes) == 2 plt.close(f) f = plt.figure() mat.heatmap(self.df_norm, cbar=False) - nt.assert_equal(len(f.axes), 1) + assert len(f.axes) == 1 plt.close(f) f, (ax1, ax2) = plt.subplots(2) mat.heatmap(self.df_norm, ax=ax1, cbar_ax=ax2) - nt.assert_equal(len(f.axes), 2) + assert len(f.axes) == 2 plt.close(f) @pytest.mark.xfail(mpl.__version__ == "3.1.1", @@ -313,15 +376,15 @@ def test_heatmap_axes(self): ax = mat.heatmap(self.df_norm) xtl = [int(l.get_text()) for l in ax.get_xticklabels()] - nt.assert_equal(xtl, list(self.df_norm.columns)) + assert xtl == list(self.df_norm.columns) ytl = [l.get_text() for l in ax.get_yticklabels()] - nt.assert_equal(ytl, list(self.df_norm.index)) + assert ytl == list(self.df_norm.index) - nt.assert_equal(ax.get_xlabel(), "") - nt.assert_equal(ax.get_ylabel(), "letters") + assert ax.get_xlabel() == "" + assert ax.get_ylabel() == "letters" - nt.assert_equal(ax.get_xlim(), (0, 8)) - nt.assert_equal(ax.get_ylim(), (4, 0)) + assert ax.get_xlim() == (0, 8) + assert ax.get_ylim() == (4, 0) def test_heatmap_ticklabel_rotation(self): @@ -329,10 +392,10 @@ def test_heatmap_ticklabel_rotation(self): mat.heatmap(self.df_norm, xticklabels=1, yticklabels=1, ax=ax) for t in ax.get_xticklabels(): - nt.assert_equal(t.get_rotation(), 0) + assert t.get_rotation() == 0 for t in ax.get_yticklabels(): - nt.assert_equal(t.get_rotation(), 90) + assert t.get_rotation() == 90 plt.close(f) @@ -344,10 +407,10 @@ def test_heatmap_ticklabel_rotation(self): mat.heatmap(df, xticklabels=1, yticklabels=1, ax=ax) for t in ax.get_xticklabels(): - nt.assert_equal(t.get_rotation(), 90) + assert t.get_rotation() == 90 for t in ax.get_yticklabels(): - nt.assert_equal(t.get_rotation(), 0) + assert t.get_rotation() == 0 plt.close(f) @@ -356,31 +419,34 @@ def test_heatmap_inner_lines(self): c = (0, 0, 1, 1) ax = mat.heatmap(self.df_norm, linewidths=2, linecolor=c) mesh = ax.collections[0] - nt.assert_equal(mesh.get_linewidths()[0], 2) - 
nt.assert_equal(tuple(mesh.get_edgecolor()[0]), c) + assert mesh.get_linewidths()[0] == 2 + assert tuple(mesh.get_edgecolor()[0]) == c def test_square_aspect(self): ax = mat.heatmap(self.df_norm, square=True) - nt.assert_equal(ax.get_aspect(), "equal") + obs_aspect = ax.get_aspect() + # mpl>3.3 returns 1 for setting "equal" aspect + # so test for the two possible equal outcomes + assert obs_aspect == "equal" or obs_aspect == 1 def test_mask_validation(self): mask = mat._matrix_mask(self.df_norm, None) - nt.assert_equal(mask.shape, self.df_norm.shape) - nt.assert_equal(mask.values.sum(), 0) + assert mask.shape == self.df_norm.shape + assert mask.values.sum() == 0 - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): bad_array_mask = self.rs.randn(3, 6) > 0 mat._matrix_mask(self.df_norm, bad_array_mask) - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): bad_df_mask = pd.DataFrame(self.rs.randn(4, 8) > 0) mat._matrix_mask(self.df_norm, bad_df_mask) def test_missing_data_mask(self): - data = pd.DataFrame(np.arange(4, dtype=np.float).reshape(2, 2)) + data = pd.DataFrame(np.arange(4, dtype=float).reshape(2, 2)) data.loc[0, 0] = np.nan mask = mat._matrix_mask(data, None) npt.assert_array_equal(mask, [[True, False], [False, False]]) @@ -397,31 +463,34 @@ def test_cbar_ticks(self): assert len(ax2.collections) == 2 -class TestDendrogram(object): +@pytest.mark.skipif(_no_scipy, reason="Test requires scipy") +class TestDendrogram: + rs = np.random.RandomState(sum(map(ord, "dendrogram"))) + default_kws = dict(linkage=None, metric='euclidean', method='single', + axis=1, label=True, rotate=False) + x_norm = rs.randn(4, 8) + np.arange(8) x_norm = (x_norm.T + np.arange(4)).T letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"], name="letters") df_norm = pd.DataFrame(x_norm, columns=letters) - try: - import fastcluster - x_norm_linkage = fastcluster.linkage_vector(x_norm.T, - metric='euclidean', - method='single') - except ImportError: - x_norm_distances = distance.pdist(x_norm.T, metric='euclidean') - x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single') - x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True, - color_threshold=-np.inf) - x_norm_leaves = x_norm_dendrogram['leaves'] - df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves]) + if not _no_scipy: + if _no_fastcluster: + x_norm_distances = distance.pdist(x_norm.T, metric='euclidean') + x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single') + else: + x_norm_linkage = fastcluster.linkage_vector(x_norm.T, + metric='euclidean', + method='single') - default_kws = dict(linkage=None, metric='euclidean', method='single', - axis=1, label=True, rotate=False) + x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True, + color_threshold=-np.inf) + x_norm_leaves = x_norm_dendrogram['leaves'] + df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves]) def test_ndarray_input(self): p = mat._DendrogramPlotter(self.x_norm, **self.default_kws) @@ -429,15 +498,15 @@ def test_ndarray_input(self): pdt.assert_frame_equal(p.data.T, pd.DataFrame(self.x_norm)) npt.assert_array_equal(p.linkage, self.x_norm_linkage) - nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram) + assert p.dendrogram == self.x_norm_dendrogram npt.assert_array_equal(p.reordered_ind, self.x_norm_leaves) npt.assert_array_equal(p.xticklabels, self.x_norm_leaves) npt.assert_array_equal(p.yticklabels, []) - nt.assert_equal(p.xlabel, None) - nt.assert_equal(p.ylabel, '') + assert p.xlabel 
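The aspect assertion in test_square_aspect above accommodates a matplotlib API change: newer releases report a numeric aspect where older ones returned the string that was set. A minimal illustration of the difference:

import matplotlib.pyplot as plt

ax = plt.figure().subplots()
ax.set_aspect("equal")
# Older matplotlib returns "equal"; from around 3.3 it returns 1.0
assert ax.get_aspect() in ("equal", 1)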
is None + assert p.ylabel == '' def test_df_input(self): p = mat._DendrogramPlotter(self.df_norm, **self.default_kws) @@ -445,15 +514,15 @@ def test_df_input(self): pdt.assert_frame_equal(p.data.T, self.df_norm) npt.assert_array_equal(p.linkage, self.x_norm_linkage) - nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram) + assert p.dendrogram == self.x_norm_dendrogram npt.assert_array_equal(p.xticklabels, np.asarray(self.df_norm.columns)[ self.x_norm_leaves]) npt.assert_array_equal(p.yticklabels, []) - nt.assert_equal(p.xlabel, 'letters') - nt.assert_equal(p.ylabel, '') + assert p.xlabel == 'letters' + assert p.ylabel == '' def test_df_multindex_input(self): @@ -472,7 +541,7 @@ def test_df_multindex_input(self): xticklabels = [xticklabels[i] for i in p.reordered_ind] npt.assert_array_equal(p.xticklabels, xticklabels) npt.assert_array_equal(p.yticklabels, []) - nt.assert_equal(p.xlabel, "letter-number") + assert p.xlabel == "letter-number" def test_axis0_input(self): kws = self.default_kws.copy() @@ -483,13 +552,13 @@ def test_axis0_input(self): pdt.assert_frame_equal(p.data, self.df_norm.T) npt.assert_array_equal(p.linkage, self.x_norm_linkage) - nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram) + assert p.dendrogram == self.x_norm_dendrogram npt.assert_array_equal(p.xticklabels, self.df_norm_leaves) npt.assert_array_equal(p.yticklabels, []) - nt.assert_equal(p.xlabel, 'letters') - nt.assert_equal(p.ylabel, '') + assert p.xlabel == 'letters' + assert p.ylabel == '' def test_rotate_input(self): kws = self.default_kws.copy() @@ -501,8 +570,8 @@ def test_rotate_input(self): npt.assert_array_equal(p.xticklabels, []) npt.assert_array_equal(p.yticklabels, self.df_norm_leaves) - nt.assert_equal(p.xlabel, '') - nt.assert_equal(p.ylabel, 'letters') + assert p.xlabel == '' + assert p.ylabel == 'letters' def test_rotate_axis0_input(self): kws = self.default_kws.copy() @@ -529,18 +598,18 @@ def test_custom_linkage(self): p = mat._DendrogramPlotter(self.df_norm, **kws) npt.assert_array_equal(p.linkage, linkage) - nt.assert_dict_equal(p.dendrogram, dendrogram) + assert p.dendrogram == dendrogram def test_label_false(self): kws = self.default_kws.copy() kws['label'] = False p = mat._DendrogramPlotter(self.df_norm, **kws) - nt.assert_equal(p.xticks, []) - nt.assert_equal(p.yticks, []) - nt.assert_equal(p.xticklabels, []) - nt.assert_equal(p.yticklabels, []) - nt.assert_equal(p.xlabel, "") - nt.assert_equal(p.ylabel, "") + assert p.xticks == [] + assert p.yticks == [] + assert p.xticklabels == [] + assert p.yticklabels == [] + assert p.xlabel == "" + assert p.ylabel == "" def test_linkage_scipy(self): p = mat._DendrogramPlotter(self.x_norm, **self.default_kws) @@ -587,11 +656,10 @@ def test_dendrogram_plot(self): # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy xmax = len(d.reordered_ind) * 10 - nt.assert_equal(xlim[0], 0) - nt.assert_equal(xlim[1], xmax) + assert xlim[0] == 0 + assert xlim[1] == xmax - nt.assert_equal(len(ax.collections[0].get_paths()), - len(d.dependent_coord)) + assert len(ax.collections[0].get_paths()) == len(d.dependent_coord) @pytest.mark.xfail(mpl.__version__ == "3.1.1", reason="matplotlib 3.1.1 bug") @@ -609,15 +677,15 @@ def test_dendrogram_rotate(self): # Since y axis is inverted, ylim is (80, 0) # and therefore not (0, 80) as usual: - nt.assert_equal(ylim[1], 0) - nt.assert_equal(ylim[0], ymax) + assert ylim[1] == 0 + assert ylim[0] == ymax def test_dendrogram_ticklabel_rotation(self): f, ax = plt.subplots(figsize=(2, 2)) mat.dendrogram(self.df_norm, 
ax=ax) for t in ax.get_xticklabels(): - nt.assert_equal(t.get_rotation(), 0) + assert t.get_rotation() == 0 plt.close(f) @@ -629,18 +697,20 @@ def test_dendrogram_ticklabel_rotation(self): mat.dendrogram(df, ax=ax) for t in ax.get_xticklabels(): - nt.assert_equal(t.get_rotation(), 90) + assert t.get_rotation() == 90 plt.close(f) f, ax = plt.subplots(figsize=(2, 2)) mat.dendrogram(df.T, axis=0, rotate=True) for t in ax.get_yticklabels(): - nt.assert_equal(t.get_rotation(), 0) + assert t.get_rotation() == 0 plt.close(f) -class TestClustermap(object): +@pytest.mark.skipif(_no_scipy, reason="Test requires scipy") +class TestClustermap: + rs = np.random.RandomState(sum(map(ord, "clustermap"))) x_norm = rs.randn(4, 8) + np.arange(8) @@ -649,41 +719,45 @@ class TestClustermap(object): name="letters") df_norm = pd.DataFrame(x_norm, columns=letters) - try: - import fastcluster - - x_norm_linkage = fastcluster.linkage_vector(x_norm.T, - metric='euclidean', - method='single') - except ImportError: - x_norm_distances = distance.pdist(x_norm.T, metric='euclidean') - x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single') - x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True, - color_threshold=-np.inf) - x_norm_leaves = x_norm_dendrogram['leaves'] - df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves]) default_kws = dict(pivot_kws=None, z_score=None, standard_scale=None, - figsize=None, row_colors=None, col_colors=None) + figsize=(10, 10), row_colors=None, col_colors=None, + dendrogram_ratio=.2, colors_ratio=.03, + cbar_pos=(0, .8, .05, .2)) default_plot_kws = dict(metric='euclidean', method='average', colorbar_kws=None, row_cluster=True, col_cluster=True, - row_linkage=None, col_linkage=None) + row_linkage=None, col_linkage=None, + tree_kws=None) row_colors = color_palette('Set2', df_norm.shape[0]) col_colors = color_palette('Dark2', df_norm.shape[1]) + if not _no_scipy: + if _no_fastcluster: + x_norm_distances = distance.pdist(x_norm.T, metric='euclidean') + x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single') + else: + x_norm_linkage = fastcluster.linkage_vector(x_norm.T, + metric='euclidean', + method='single') + + x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True, + color_threshold=-np.inf) + x_norm_leaves = x_norm_dendrogram['leaves'] + df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves]) + def test_ndarray_input(self): - cm = mat.ClusterGrid(self.x_norm, **self.default_kws) - pdt.assert_frame_equal(cm.data, pd.DataFrame(self.x_norm)) - nt.assert_equal(len(cm.fig.axes), 4) - nt.assert_equal(cm.ax_row_colors, None) - nt.assert_equal(cm.ax_col_colors, None) + cg = mat.ClusterGrid(self.x_norm, **self.default_kws) + pdt.assert_frame_equal(cg.data, pd.DataFrame(self.x_norm)) + assert len(cg.fig.axes) == 4 + assert cg.ax_row_colors is None + assert cg.ax_col_colors is None def test_df_input(self): - cm = mat.ClusterGrid(self.df_norm, **self.default_kws) - pdt.assert_frame_equal(cm.data, self.df_norm) + cg = mat.ClusterGrid(self.df_norm, **self.default_kws) + pdt.assert_frame_equal(cg.data, self.df_norm) def test_corr_df_input(self): df = self.df_norm.corr() @@ -700,9 +774,9 @@ def test_pivot_input(self): kws = self.default_kws.copy() kws['pivot_kws'] = dict(index='numbers', columns='letters', values='value') - cm = mat.ClusterGrid(df_long, **kws) + cg = mat.ClusterGrid(df_long, **kws) - pdt.assert_frame_equal(cm.data2d, df_norm) + pdt.assert_frame_equal(cg.data2d, df_norm) def test_colors_input(self): kws = 
self.default_kws.copy() @@ -710,11 +784,31 @@ def test_colors_input(self): kws['row_colors'] = self.row_colors kws['col_colors'] = self.col_colors - cm = mat.ClusterGrid(self.df_norm, **kws) - npt.assert_array_equal(cm.row_colors, self.row_colors) - npt.assert_array_equal(cm.col_colors, self.col_colors) + cg = mat.ClusterGrid(self.df_norm, **kws) + npt.assert_array_equal(cg.row_colors, self.row_colors) + npt.assert_array_equal(cg.col_colors, self.col_colors) + + assert len(cg.fig.axes) == 6 + + def test_categorical_colors_input(self): + kws = self.default_kws.copy() + + row_colors = pd.Series(self.row_colors, dtype="category") + col_colors = pd.Series( + self.col_colors, dtype="category", index=self.df_norm.columns + ) + + kws['row_colors'] = row_colors + kws['col_colors'] = col_colors - nt.assert_equal(len(cm.fig.axes), 6) + exp_row_colors = list(map(mpl.colors.to_rgb, row_colors)) + exp_col_colors = list(map(mpl.colors.to_rgb, col_colors)) + + cg = mat.ClusterGrid(self.df_norm, **kws) + npt.assert_array_equal(cg.row_colors, exp_row_colors) + npt.assert_array_equal(cg.col_colors, exp_col_colors) + + assert len(cg.fig.axes) == 6 def test_nested_colors_input(self): kws = self.default_kws.copy() @@ -728,7 +822,7 @@ def test_nested_colors_input(self): npt.assert_array_equal(cm.row_colors, row_colors) npt.assert_array_equal(cm.col_colors, col_colors) - nt.assert_equal(len(cm.fig.axes), 6) + assert len(cm.fig.axes) == 6 def test_colors_input_custom_cmap(self): kws = self.default_kws.copy() @@ -737,11 +831,11 @@ def test_colors_input_custom_cmap(self): kws['row_colors'] = self.row_colors kws['col_colors'] = self.col_colors - cm = mat.clustermap(self.df_norm, **kws) - npt.assert_array_equal(cm.row_colors, self.row_colors) - npt.assert_array_equal(cm.col_colors, self.col_colors) + cg = mat.clustermap(self.df_norm, **kws) + npt.assert_array_equal(cg.row_colors, self.row_colors) + npt.assert_array_equal(cg.col_colors, self.col_colors) - nt.assert_equal(len(cm.fig.axes), 6) + assert len(cg.fig.axes) == 6 def test_z_score(self): df = self.df_norm.copy() @@ -749,8 +843,8 @@ def test_z_score(self): kws = self.default_kws.copy() kws['z_score'] = 1 - cm = mat.ClusterGrid(self.df_norm, **kws) - pdt.assert_frame_equal(cm.data2d, df) + cg = mat.ClusterGrid(self.df_norm, **kws) + pdt.assert_frame_equal(cg.data2d, df) def test_z_score_axis0(self): df = self.df_norm.copy() @@ -760,8 +854,8 @@ def test_z_score_axis0(self): kws = self.default_kws.copy() kws['z_score'] = 0 - cm = mat.ClusterGrid(self.df_norm, **kws) - pdt.assert_frame_equal(cm.data2d, df) + cg = mat.ClusterGrid(self.df_norm, **kws) + pdt.assert_frame_equal(cg.data2d, df) def test_standard_scale(self): df = self.df_norm.copy() @@ -769,8 +863,8 @@ def test_standard_scale(self): kws = self.default_kws.copy() kws['standard_scale'] = 1 - cm = mat.ClusterGrid(self.df_norm, **kws) - pdt.assert_frame_equal(cm.data2d, df) + cg = mat.ClusterGrid(self.df_norm, **kws) + pdt.assert_frame_equal(cg.data2d, df) def test_standard_scale_axis0(self): df = self.df_norm.copy() @@ -780,75 +874,65 @@ def test_standard_scale_axis0(self): kws = self.default_kws.copy() kws['standard_scale'] = 0 - cm = mat.ClusterGrid(self.df_norm, **kws) - pdt.assert_frame_equal(cm.data2d, df) + cg = mat.ClusterGrid(self.df_norm, **kws) + pdt.assert_frame_equal(cg.data2d, df) def test_z_score_standard_scale(self): kws = self.default_kws.copy() kws['z_score'] = True kws['standard_scale'] = True - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): 
mat.ClusterGrid(self.df_norm, **kws) def test_color_list_to_matrix_and_cmap(self): + # Note this uses the attribute named col_colors but tests row colors matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap( - self.col_colors, self.x_norm_leaves) - - colors_set = set(self.col_colors) - col_to_value = dict((col, i) for i, col in enumerate(colors_set)) - matrix_test = np.array([col_to_value[col] for col in - self.col_colors])[self.x_norm_leaves] - shape = len(self.col_colors), 1 - matrix_test = matrix_test.reshape(shape) - cmap_test = mpl.colors.ListedColormap(colors_set) - npt.assert_array_equal(matrix, matrix_test) - npt.assert_array_equal(cmap.colors, cmap_test.colors) + self.col_colors, self.x_norm_leaves, axis=0) + + for i, leaf in enumerate(self.x_norm_leaves): + color = self.col_colors[leaf] + assert_colors_equal(cmap(matrix[i, 0]), color) def test_nested_color_list_to_matrix_and_cmap(self): - colors = [self.col_colors, self.col_colors] + # Note this uses the attribute named col_colors but tests row colors + colors = [self.col_colors, self.col_colors[::-1]] matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap( - colors, self.x_norm_leaves) - - all_colors = set(itertools.chain(*colors)) - color_to_value = dict((col, i) for i, col in enumerate(all_colors)) - matrix_test = np.array( - [color_to_value[c] for color in colors for c in color]) - shape = len(colors), len(colors[0]) - matrix_test = matrix_test.reshape(shape) - matrix_test = matrix_test[:, self.x_norm_leaves] - matrix_test = matrix_test.T + colors, self.x_norm_leaves, axis=0) - cmap_test = mpl.colors.ListedColormap(all_colors) - npt.assert_array_equal(matrix, matrix_test) - npt.assert_array_equal(cmap.colors, cmap_test.colors) + for i, leaf in enumerate(self.x_norm_leaves): + for j, color_row in enumerate(colors): + color = color_row[leaf] + assert_colors_equal(cmap(matrix[i, j]), color) def test_color_list_to_matrix_and_cmap_axis1(self): matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap( self.col_colors, self.x_norm_leaves, axis=1) - colors_set = set(self.col_colors) - col_to_value = dict((col, i) for i, col in enumerate(colors_set)) - matrix_test = np.array([col_to_value[col] for col in - self.col_colors])[self.x_norm_leaves] - shape = 1, len(self.col_colors) - matrix_test = matrix_test.reshape(shape) - cmap_test = mpl.colors.ListedColormap(colors_set) - npt.assert_array_equal(matrix, matrix_test) - npt.assert_array_equal(cmap.colors, cmap_test.colors) + for j, leaf in enumerate(self.x_norm_leaves): + color = self.col_colors[leaf] + assert_colors_equal(cmap(matrix[0, j]), color) + + def test_color_list_to_matrix_and_cmap_different_sizes(self): + colors = [self.col_colors, self.col_colors * 2] + with pytest.raises(ValueError): + matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap( + colors, self.x_norm_leaves, axis=1) def test_savefig(self): # Not sure if this is the right way to test.... 
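The rewritten color_list_to_matrix_and_cmap tests above verify a round trip from a list of colors to an integer matrix plus a ListedColormap; the idea behind that round trip, sketched independently of seaborn internals:

import matplotlib as mpl
import numpy as np

colors = ["r", "g", "b", "g"]
unique = list(dict.fromkeys(colors))  # deduplicate, keeping first-seen order
matrix = np.array([unique.index(c) for c in colors]).reshape(-1, 1)
cmap = mpl.colors.ListedColormap(unique)
# Integer lookups into the colormap recover the original colors:
for i, c in enumerate(colors):
    assert cmap(matrix[i, 0])[:3] == mpl.colors.to_rgb(c)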
- cm = mat.ClusterGrid(self.df_norm, **self.default_kws) - cm.plot(**self.default_plot_kws) - cm.savefig(tempfile.NamedTemporaryFile(), format='png') + cg = mat.ClusterGrid(self.df_norm, **self.default_kws) + cg.plot(**self.default_plot_kws) + cg.savefig(tempfile.NamedTemporaryFile(), format='png') def test_plot_dendrograms(self): cm = mat.clustermap(self.df_norm, **self.default_kws) - nt.assert_equal(len(cm.ax_row_dendrogram.collections[0].get_paths()), - len(cm.dendrogram_row.independent_coord)) - nt.assert_equal(len(cm.ax_col_dendrogram.collections[0].get_paths()), - len(cm.dendrogram_col.independent_coord)) + assert len(cm.ax_row_dendrogram.collections[0].get_paths()) == len( + cm.dendrogram_row.independent_coord + ) + assert len(cm.ax_col_dendrogram.collections[0].get_paths()) == len( + cm.dendrogram_col.independent_coord + ) data2d = self.df_norm.iloc[cm.dendrogram_row.reordered_ind, cm.dendrogram_col.reordered_ind] pdt.assert_frame_equal(cm.data2d, data2d) @@ -859,13 +943,13 @@ def test_cluster_false(self): kws['col_cluster'] = False cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(len(cm.ax_row_dendrogram.lines), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.lines), 0) + assert len(cm.ax_row_dendrogram.lines) == 0 + assert len(cm.ax_col_dendrogram.lines) == 0 - nt.assert_equal(len(cm.ax_row_dendrogram.get_xticks()), 0) - nt.assert_equal(len(cm.ax_row_dendrogram.get_yticks()), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.get_xticks()), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.get_yticks()), 0) + assert len(cm.ax_row_dendrogram.get_xticks()) == 0 + assert len(cm.ax_row_dendrogram.get_yticks()) == 0 + assert len(cm.ax_col_dendrogram.get_xticks()) == 0 + assert len(cm.ax_col_dendrogram.get_yticks()) == 0 pdt.assert_frame_equal(cm.data2d, self.df_norm) @@ -876,8 +960,8 @@ def test_row_col_colors(self): cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(len(cm.ax_row_colors.collections), 1) - nt.assert_equal(len(cm.ax_col_colors.collections), 1) + assert len(cm.ax_row_colors.collections) == 1 + assert len(cm.ax_col_colors.collections) == 1 def test_cluster_false_row_col_colors(self): kws = self.default_kws.copy() @@ -887,15 +971,15 @@ def test_cluster_false_row_col_colors(self): kws['col_colors'] = self.col_colors cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(len(cm.ax_row_dendrogram.lines), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.lines), 0) + assert len(cm.ax_row_dendrogram.lines) == 0 + assert len(cm.ax_col_dendrogram.lines) == 0 - nt.assert_equal(len(cm.ax_row_dendrogram.get_xticks()), 0) - nt.assert_equal(len(cm.ax_row_dendrogram.get_yticks()), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.get_xticks()), 0) - nt.assert_equal(len(cm.ax_col_dendrogram.get_yticks()), 0) - nt.assert_equal(len(cm.ax_row_colors.collections), 1) - nt.assert_equal(len(cm.ax_col_colors.collections), 1) + assert len(cm.ax_row_dendrogram.get_xticks()) == 0 + assert len(cm.ax_row_dendrogram.get_yticks()) == 0 + assert len(cm.ax_col_dendrogram.get_xticks()) == 0 + assert len(cm.ax_col_dendrogram.get_yticks()) == 0 + assert len(cm.ax_row_colors.collections) == 1 + assert len(cm.ax_col_colors.collections) == 1 pdt.assert_frame_equal(cm.data2d, self.df_norm) @@ -914,13 +998,13 @@ def test_row_col_colors_df(self): row_labels = [l.get_text() for l in cm.ax_row_colors.get_xticklabels()] - nt.assert_equal(cm.row_color_labels, ['row_1', 'row_2']) - nt.assert_equal(row_labels, cm.row_color_labels) + assert cm.row_color_labels == ['row_1', 'row_2'] + assert row_labels == 
cm.row_color_labels col_labels = [l.get_text() for l in cm.ax_col_colors.get_yticklabels()] - nt.assert_equal(cm.col_color_labels, ['col_1', 'col_2']) - nt.assert_equal(col_labels, cm.col_color_labels) + assert cm.col_color_labels == ['col_1', 'col_2'] + assert col_labels == cm.col_color_labels def test_row_col_colors_df_shuffled(self): # Tests if colors are properly matched, even if given in wrong order @@ -942,8 +1026,8 @@ def test_row_col_colors_df_shuffled(self): kws['col_colors'] = col_colors.loc[shuffled_cols] cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(list(cm.col_colors)[0], list(self.col_colors)) - nt.assert_equal(list(cm.row_colors)[0], list(self.row_colors)) + assert list(cm.col_colors)[0] == list(self.col_colors) + assert list(cm.row_colors)[0] == list(self.row_colors) def test_row_col_colors_df_missing(self): kws = self.default_kws.copy() @@ -957,10 +1041,8 @@ def test_row_col_colors_df_missing(self): cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(list(cm.col_colors)[0], - [(1.0, 1.0, 1.0)] + list(self.col_colors[1:])) - nt.assert_equal(list(cm.row_colors)[0], - [(1.0, 1.0, 1.0)] + list(self.row_colors[1:])) + assert list(cm.col_colors)[0] == [(1.0, 1.0, 1.0)] + list(self.col_colors[1:]) + assert list(cm.row_colors)[0] == [(1.0, 1.0, 1.0)] + list(self.row_colors[1:]) def test_row_col_colors_df_one_axis(self): # Test case with only row annotation. @@ -974,10 +1056,10 @@ def test_row_col_colors_df_one_axis(self): row_labels = [l.get_text() for l in cm1.ax_row_colors.get_xticklabels()] - nt.assert_equal(cm1.row_color_labels, ['row_1', 'row_2']) - nt.assert_equal(row_labels, cm1.row_color_labels) + assert cm1.row_color_labels == ['row_1', 'row_2'] + assert row_labels == cm1.row_color_labels - # Test case with onl col annotation. + # Test case with only col annotation. 
kws2 = self.default_kws.copy() + kws2['col_colors'] = pd.DataFrame({'col_1': list(self.col_colors), 'col_2': list(self.col_colors)}, @@ -988,8 +1070,8 @@ def test_row_col_colors_df_one_axis(self): col_labels = [l.get_text() for l in cm2.ax_col_colors.get_yticklabels()] - nt.assert_equal(cm2.col_color_labels, ['col_1', 'col_2']) - nt.assert_equal(col_labels, cm2.col_color_labels) + assert cm2.col_color_labels == ['col_1', 'col_2'] + assert col_labels == cm2.col_color_labels def test_row_col_colors_series(self): kws = self.default_kws.copy() @@ -1000,15 +1082,13 @@ def test_row_col_colors_series(self): cm = mat.clustermap(self.df_norm, **kws) - row_labels = [l.get_text() for l in - cm.ax_row_colors.get_xticklabels()] - nt.assert_equal(cm.row_color_labels, ['row_annot']) - nt.assert_equal(row_labels, cm.row_color_labels) + row_labels = [l.get_text() for l in cm.ax_row_colors.get_xticklabels()] + assert cm.row_color_labels == ['row_annot'] + assert row_labels == cm.row_color_labels - col_labels = [l.get_text() for l in - cm.ax_col_colors.get_yticklabels()] - nt.assert_equal(cm.col_color_labels, ['col_annot']) - nt.assert_equal(col_labels, cm.col_color_labels) + col_labels = [l.get_text() for l in cm.ax_col_colors.get_yticklabels()] + assert cm.col_color_labels == ['col_annot'] + assert col_labels == cm.col_color_labels def test_row_col_colors_series_shuffled(self): # Tests if colors are properly matched, even if given in wrong order @@ -1031,8 +1111,8 @@ def test_row_col_colors_series_shuffled(self): cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(list(cm.col_colors), list(self.col_colors)) - nt.assert_equal(list(cm.row_colors), list(self.row_colors)) + assert list(cm.col_colors) == list(self.col_colors) + assert list(cm.row_colors) == list(self.row_colors) def test_row_col_colors_series_missing(self): kws = self.default_kws.copy() @@ -1045,10 +1125,8 @@ def test_row_col_colors_series_missing(self): kws['col_colors'] = col_colors.drop(self.df_norm.columns[0]) cm = mat.clustermap(self.df_norm, **kws) - nt.assert_equal(list(cm.col_colors), - [(1.0, 1.0, 1.0)] + list(self.col_colors[1:])) - nt.assert_equal(list(cm.row_colors), - [(1.0, 1.0, 1.0)] + list(self.row_colors[1:])) + assert list(cm.col_colors) == [(1.0, 1.0, 1.0)] + list(self.col_colors[1:]) + assert list(cm.row_colors) == [(1.0, 1.0, 1.0)] + list(self.row_colors[1:]) def test_row_col_colors_ignore_heatmap_kwargs(self): @@ -1069,6 +1147,22 @@ def test_row_col_colors_ignore_heatmap_kwargs(self): g.ax_col_colors.collections[0].get_facecolors()[:, :3] ) + def test_row_col_colors_raise_on_mixed_index_types(self): + + row_colors = pd.Series( + list(self.row_colors), name="row_annot", index=self.df_norm.index + ) + + col_colors = pd.Series( + list(self.col_colors), name="col_annot", index=self.df_norm.columns + ) + + with pytest.raises(TypeError): + mat.clustermap(self.x_norm, row_colors=row_colors) + + with pytest.raises(TypeError): + mat.clustermap(self.x_norm, col_colors=col_colors) + def test_mask_reorganization(self): kws = self.default_kws.copy() @@ -1114,5 +1208,129 @@ def test_noticklabels(self): xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()] ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()] - nt.assert_equal(xtl_actual, []) - nt.assert_equal(ytl_actual, []) + assert xtl_actual == [] + assert ytl_actual == [] + + def test_size_ratios(self): + + # Because of the way that wspace/hspace work in GridSpec, the mapping from input + # ratio to actual width/height of each axes is complicated, so this + #
test is just going to assert comparative relationships + + kws1 = self.default_kws.copy() + kws1.update(dendrogram_ratio=.2, colors_ratio=.03, + col_colors=self.col_colors, row_colors=self.row_colors) + + kws2 = kws1.copy() + kws2.update(dendrogram_ratio=.3, colors_ratio=.05) + + g1 = mat.clustermap(self.df_norm, **kws1) + g2 = mat.clustermap(self.df_norm, **kws2) + + assert (g2.ax_col_dendrogram.get_position().height + > g1.ax_col_dendrogram.get_position().height) + + assert (g2.ax_col_colors.get_position().height + > g1.ax_col_colors.get_position().height) + + assert (g2.ax_heatmap.get_position().height + < g1.ax_heatmap.get_position().height) + + assert (g2.ax_row_dendrogram.get_position().width + > g1.ax_row_dendrogram.get_position().width) + + assert (g2.ax_row_colors.get_position().width + > g1.ax_row_colors.get_position().width) + + assert (g2.ax_heatmap.get_position().width + < g1.ax_heatmap.get_position().width) + + kws1 = self.default_kws.copy() + kws1.update(col_colors=self.col_colors) + kws2 = kws1.copy() + kws2.update(col_colors=[self.col_colors, self.col_colors]) + + g1 = mat.clustermap(self.df_norm, **kws1) + g2 = mat.clustermap(self.df_norm, **kws2) + + assert (g2.ax_col_colors.get_position().height + > g1.ax_col_colors.get_position().height) + + kws1 = self.default_kws.copy() + kws1.update(dendrogram_ratio=(.2, .2)) + + kws2 = kws1.copy() + kws2.update(dendrogram_ratio=(.2, .3)) + + g1 = mat.clustermap(self.df_norm, **kws1) + g2 = mat.clustermap(self.df_norm, **kws2) + + # Fails on pinned matplotlib? + # assert (g2.ax_row_dendrogram.get_position().width + # == g1.ax_row_dendrogram.get_position().width) + assert g1.gs.get_width_ratios() == g2.gs.get_width_ratios() + + assert (g2.ax_col_dendrogram.get_position().height + > g1.ax_col_dendrogram.get_position().height) + + def test_cbar_pos(self): + + kws = self.default_kws.copy() + kws["cbar_pos"] = (.2, .1, .4, .3) + + g = mat.clustermap(self.df_norm, **kws) + pos = g.ax_cbar.get_position() + assert pytest.approx(tuple(pos.p0)) == kws["cbar_pos"][:2] + assert pytest.approx(pos.width) == kws["cbar_pos"][2] + assert pytest.approx(pos.height) == kws["cbar_pos"][3] + + kws["cbar_pos"] = None + g = mat.clustermap(self.df_norm, **kws) + assert g.ax_cbar is None + + def test_square_warning(self): + + kws = self.default_kws.copy() + g1 = mat.clustermap(self.df_norm, **kws) + + with pytest.warns(UserWarning): + kws["square"] = True + g2 = mat.clustermap(self.df_norm, **kws) + + g1_shape = g1.ax_heatmap.get_position().get_points() + g2_shape = g2.ax_heatmap.get_position().get_points() + assert np.array_equal(g1_shape, g2_shape) + + def test_clustermap_annotation(self): + + g = mat.clustermap(self.df_norm, annot=True, fmt=".1f") + for val, text in zip(np.asarray(g.data2d).flat, g.ax_heatmap.texts): + assert text.get_text() == "{:.1f}".format(val) + + g = mat.clustermap(self.df_norm, annot=self.df_norm, fmt=".1f") + for val, text in zip(np.asarray(g.data2d).flat, g.ax_heatmap.texts): + assert text.get_text() == "{:.1f}".format(val) + + def test_tree_kws(self): + + rgb = (1, .5, .2) + g = mat.clustermap(self.df_norm, tree_kws=dict(color=rgb)) + for ax in [g.ax_col_dendrogram, g.ax_row_dendrogram]: + tree, = ax.collections + assert tuple(tree.get_color().squeeze())[:3] == rgb + + +if _no_scipy: + + def test_required_scipy_errors(): + + x = np.random.normal(0, 1, (10, 10)) + + with pytest.raises(RuntimeError): + mat.clustermap(x) + + with pytest.raises(RuntimeError): + mat.ClusterGrid(x) + + with pytest.raises(RuntimeError): + 
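As the comment in test_size_ratios notes, GridSpec geometry makes exact axes sizes hard to predict, so that test sticks to ordering relationships; the knobs involved, in minimal form:

import matplotlib.pyplot as plt

fig = plt.figure()
# Larger ratios claim proportionally more of the figure, but the realized
# sizes also depend on the wspace/hspace padding between cells:
gs = fig.add_gridspec(2, 2, width_ratios=[1, 4], height_ratios=[1, 4])
ax_big = fig.add_subplot(gs[1, 1])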
mat.dendrogram(x) diff --git a/seaborn/tests/test_miscplot.py b/seaborn/tests/test_miscplot.py index 45f505d96e..323cdf7337 100644 --- a/seaborn/tests/test_miscplot.py +++ b/seaborn/tests/test_miscplot.py @@ -1,32 +1,31 @@ -import nose.tools as nt import matplotlib.pyplot as plt from .. import miscplot as misc from ..palettes import color_palette -from ..utils import _network +from .test_utils import _network -class TestPalPlot(object): +class TestPalPlot: """Test the function that visualizes a color palette.""" def test_palplot_size(self): pal4 = color_palette("husl", 4) misc.palplot(pal4) size4 = plt.gcf().get_size_inches() - nt.assert_equal(tuple(size4), (4, 1)) + assert tuple(size4) == (4, 1) pal5 = color_palette("husl", 5) misc.palplot(pal5) size5 = plt.gcf().get_size_inches() - nt.assert_equal(tuple(size5), (5, 1)) + assert tuple(size5) == (5, 1) palbig = color_palette("husl", 3) misc.palplot(palbig, 2) sizebig = plt.gcf().get_size_inches() - nt.assert_equal(tuple(sizebig), (6, 2)) + assert tuple(sizebig) == (6, 2) -class TestDogPlot(object): +class TestDogPlot: @_network(url="https://github.com/mwaskom/seaborn-data") def test_dogplot(self): diff --git a/seaborn/tests/test_palettes.py b/seaborn/tests/test_palettes.py index ead327740c..152a2448fb 100644 --- a/seaborn/tests/test_palettes.py +++ b/seaborn/tests/test_palettes.py @@ -3,19 +3,14 @@ import matplotlib as mpl import pytest -import nose.tools as nt import numpy.testing as npt -import matplotlib.pyplot as plt from .. import palettes, utils, rcmod from ..external import husl from ..colors import xkcd_rgb, crayons -from distutils.version import LooseVersion -mpl_ge_150 = LooseVersion(mpl.__version__) >= '1.5.0' - -class TestColorPalettes(object): +class TestColorPalettes: def test_current_palette(self): @@ -30,9 +25,9 @@ def test_palette_context(self): context_pal = palettes.color_palette("muted") with palettes.color_palette(context_pal): - nt.assert_equal(utils.get_color_cycle(), context_pal) + assert utils.get_color_cycle() == context_pal - nt.assert_equal(utils.get_color_cycle(), default_pal) + assert utils.get_color_cycle() == default_pal def test_big_palette_context(self): @@ -41,9 +36,9 @@ def test_big_palette_context(self): rcmod.set_palette(original_pal) with palettes.color_palette(context_pal, 10): - nt.assert_equal(utils.get_color_cycle(), context_pal) + assert utils.get_color_cycle() == context_pal - nt.assert_equal(utils.get_color_cycle(), original_pal) + assert utils.get_color_cycle() == original_pal # Reset default rcmod.set() @@ -76,21 +71,35 @@ def test_seaborn_palettes(self): def test_hls_palette(self): - hls_pal1 = palettes.hls_palette() - hls_pal2 = palettes.color_palette("hls") - npt.assert_array_equal(hls_pal1, hls_pal2) + pal1 = palettes.hls_palette() + pal2 = palettes.color_palette("hls") + npt.assert_array_equal(pal1, pal2) + + cmap1 = palettes.hls_palette(as_cmap=True) + cmap2 = palettes.color_palette("hls", as_cmap=True) + npt.assert_array_equal(cmap1([.2, .8]), cmap2([.2, .8])) def test_husl_palette(self): - husl_pal1 = palettes.husl_palette() - husl_pal2 = palettes.color_palette("husl") - npt.assert_array_equal(husl_pal1, husl_pal2) + pal1 = palettes.husl_palette() + pal2 = palettes.color_palette("husl") + npt.assert_array_equal(pal1, pal2) + + cmap1 = palettes.husl_palette(as_cmap=True) + cmap2 = palettes.color_palette("husl", as_cmap=True) + npt.assert_array_equal(cmap1([.2, .8]), cmap2([.2, .8])) def test_mpl_palette(self): - mpl_pal1 = palettes.mpl_palette("Reds") - mpl_pal2 = 
palettes.color_palette("Reds") - npt.assert_array_equal(mpl_pal1, mpl_pal2) + pal1 = palettes.mpl_palette("Reds") + pal2 = palettes.color_palette("Reds") + npt.assert_array_equal(pal1, pal2) + + cmap1 = mpl.cm.get_cmap("Reds") + cmap2 = palettes.mpl_palette("Reds", as_cmap=True) + cmap3 = palettes.color_palette("Reds", as_cmap=True) + npt.assert_array_equal(cmap1, cmap2) + npt.assert_array_equal(cmap1, cmap3) def test_mpl_dark_palette(self): @@ -98,20 +107,24 @@ def test_mpl_dark_palette(self): mpl_pal2 = palettes.color_palette("Blues_d") npt.assert_array_equal(mpl_pal1, mpl_pal2) + mpl_pal1 = palettes.mpl_palette("Blues_r_d") + mpl_pal2 = palettes.color_palette("Blues_r_d") + npt.assert_array_equal(mpl_pal1, mpl_pal2) + def test_bad_palette_name(self): - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): palettes.color_palette("IAmNotAPalette") def test_terrible_palette_name(self): - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): palettes.color_palette("jet") def test_bad_palette_colors(self): pal = ["red", "blue", "iamnotacolor"] - with nt.assert_raises(ValueError): + with pytest.raises(ValueError): palettes.color_palette(pal) def test_palette_desat(self): @@ -126,16 +139,16 @@ def test_palette_is_list_of_tuples(self): pal_in = np.array(["red", "blue", "green"]) pal_out = palettes.color_palette(pal_in, 3) - nt.assert_is_instance(pal_out, list) - nt.assert_is_instance(pal_out[0], tuple) - nt.assert_is_instance(pal_out[0][0], float) - nt.assert_equal(len(pal_out[0]), 3) + assert isinstance(pal_out, list) + assert isinstance(pal_out[0], tuple) + assert isinstance(pal_out[0][0], float) + assert len(pal_out[0]) == 3 def test_palette_cycles(self): deep = palettes.color_palette("deep6") double_deep = palettes.color_palette("deep6", 12) - nt.assert_equal(double_deep, deep + deep) + assert double_deep == deep + deep def test_hls_values(self): @@ -175,11 +188,11 @@ def test_cbrewer_qual(self): pal_short = palettes.mpl_palette("Set1", 4) pal_long = palettes.mpl_palette("Set1", 6) - nt.assert_equal(pal_short, pal_long[:4]) + assert pal_short == pal_long[:4] pal_full = palettes.mpl_palette("Set2", 8) pal_long = palettes.mpl_palette("Set2", 10) - nt.assert_equal(pal_full, pal_long[:8]) + assert pal_full == pal_long[:8] def test_mpl_reversal(self): @@ -192,51 +205,113 @@ def test_rgb_from_hls(self): color = .5, .8, .4 rgb_got = palettes._color_to_rgb(color, "hls") rgb_want = colorsys.hls_to_rgb(*color) - nt.assert_equal(rgb_got, rgb_want) + assert rgb_got == rgb_want def test_rgb_from_husl(self): color = 120, 50, 40 rgb_got = palettes._color_to_rgb(color, "husl") - rgb_want = husl.husl_to_rgb(*color) - nt.assert_equal(rgb_got, rgb_want) + rgb_want = tuple(husl.husl_to_rgb(*color)) + assert rgb_got == rgb_want + + for h in range(0, 360): + color = h, 100, 100 + rgb = palettes._color_to_rgb(color, "husl") + assert min(rgb) >= 0 + assert max(rgb) <= 1 def test_rgb_from_xkcd(self): color = "dull red" rgb_got = palettes._color_to_rgb(color, "xkcd") - rgb_want = xkcd_rgb[color] - nt.assert_equal(rgb_got, rgb_want) + rgb_want = mpl.colors.to_rgb(xkcd_rgb[color]) + assert rgb_got == rgb_want def test_light_palette(self): - pal_forward = palettes.light_palette("red") - pal_reverse = palettes.light_palette("red", reverse=True) - npt.assert_array_almost_equal(pal_forward, pal_reverse[::-1]) + n = 4 + pal_forward = palettes.light_palette("red", n) + pal_reverse = palettes.light_palette("red", n, reverse=True) + assert np.allclose(pal_forward, pal_reverse[::-1]) + + red = 
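Several of the updated palette tests pair a discrete palette with its as_cmap=True counterpart, and the light:/dark: tests below exercise the string shorthands for sequential ramps; the public API pattern, as a sketch (assumes seaborn is importable as sns):

import seaborn as sns

pal = sns.color_palette("Reds", 6)                 # list of RGB tuples
cmap = sns.color_palette("Reds", as_cmap=True)     # matplotlib colormap
ramp = sns.color_palette("light:seagreen", 6)      # light -> seagreen
ramp_r = sns.color_palette("light:seagreen_r", 6)  # the same ramp, reversed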
mpl.colors.colorConverter.to_rgb("red") + assert pal_forward[-1] == red - red = tuple(mpl.colors.colorConverter.to_rgba("red")) - nt.assert_equal(tuple(pal_forward[-1]), red) + pal_f_from_string = palettes.color_palette("light:red", n) + assert pal_forward[3] == pal_f_from_string[3] + + pal_r_from_string = palettes.color_palette("light:red_r", n) + assert pal_reverse[3] == pal_r_from_string[3] pal_cmap = palettes.light_palette("blue", as_cmap=True) - nt.assert_is_instance(pal_cmap, mpl.colors.LinearSegmentedColormap) + assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap) + + pal_cmap_from_string = palettes.color_palette("light:blue", as_cmap=True) + assert pal_cmap(.8) == pal_cmap_from_string(.8) + + pal_cmap = palettes.light_palette("blue", as_cmap=True, reverse=True) + pal_cmap_from_string = palettes.color_palette("light:blue_r", as_cmap=True) + assert pal_cmap(.8) == pal_cmap_from_string(.8) def test_dark_palette(self): - pal_forward = palettes.dark_palette("red") - pal_reverse = palettes.dark_palette("red", reverse=True) - npt.assert_array_almost_equal(pal_forward, pal_reverse[::-1]) + n = 4 + pal_forward = palettes.dark_palette("red", n) + pal_reverse = palettes.dark_palette("red", n, reverse=True) + assert np.allclose(pal_forward, pal_reverse[::-1]) - red = tuple(mpl.colors.colorConverter.to_rgba("red")) - nt.assert_equal(tuple(pal_forward[-1]), red) + red = mpl.colors.colorConverter.to_rgb("red") + assert pal_forward[-1] == red + + pal_f_from_string = palettes.color_palette("dark:red", n) + assert pal_forward[3] == pal_f_from_string[3] + + pal_r_from_string = palettes.color_palette("dark:red_r", n) + assert pal_reverse[3] == pal_r_from_string[3] pal_cmap = palettes.dark_palette("blue", as_cmap=True) - nt.assert_is_instance(pal_cmap, mpl.colors.LinearSegmentedColormap) + assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap) + + pal_cmap_from_string = palettes.color_palette("dark:blue", as_cmap=True) + assert pal_cmap(.8) == pal_cmap_from_string(.8) + + pal_cmap = palettes.dark_palette("blue", as_cmap=True, reverse=True) + pal_cmap_from_string = palettes.color_palette("dark:blue_r", as_cmap=True) + assert pal_cmap(.8) == pal_cmap_from_string(.8) + + def test_diverging_palette(self): + + h_neg, h_pos = 100, 200 + sat, lum = 70, 50 + args = h_neg, h_pos, sat, lum + + n = 12 + pal = palettes.diverging_palette(*args, n=n) + neg_pal = palettes.light_palette((h_neg, sat, lum), int(n // 2), + input="husl") + pos_pal = palettes.light_palette((h_pos, sat, lum), int(n // 2), + input="husl") + assert len(pal) == n + assert pal[0] == neg_pal[-1] + assert pal[-1] == pos_pal[-1] + + pal_dark = palettes.diverging_palette(*args, n=n, center="dark") + assert np.mean(pal[int(n / 2)]) > np.mean(pal_dark[int(n / 2)]) + + pal_cmap = palettes.diverging_palette(*args, as_cmap=True) + assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap) def test_blend_palette(self): colors = ["red", "yellow", "white"] pal_cmap = palettes.blend_palette(colors, as_cmap=True) - nt.assert_is_instance(pal_cmap, mpl.colors.LinearSegmentedColormap) + assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap) + + colors = ["red", "blue"] + pal = palettes.blend_palette(colors) + pal_str = "blend:" + ",".join(colors) + pal_from_str = palettes.color_palette(pal_str) + assert pal == pal_from_str def test_cubehelix_against_matplotlib(self): @@ -246,24 +321,24 @@ def test_cubehelix_against_matplotlib(self): sns_pal = palettes.cubehelix_palette(8, start=0.5, rot=-1.5, hue=1, dark=0, light=1, 
reverse=True) - nt.assert_list_equal(sns_pal, mpl_pal) + assert sns_pal == mpl_pal def test_cubehelix_n_colors(self): for n in [3, 5, 8]: pal = palettes.cubehelix_palette(n) - nt.assert_equal(len(pal), n) + assert len(pal) == n def test_cubehelix_reverse(self): pal_forward = palettes.cubehelix_palette() pal_reverse = palettes.cubehelix_palette(reverse=True) - nt.assert_list_equal(pal_forward, pal_reverse[::-1]) + assert pal_forward == pal_reverse[::-1] def test_cubehelix_cmap(self): cmap = palettes.cubehelix_palette(as_cmap=True) - nt.assert_is_instance(cmap, mpl.colors.ListedColormap) + assert isinstance(cmap, mpl.colors.ListedColormap) pal = palettes.cubehelix_palette() x = np.linspace(0, 1, 6) npt.assert_array_equal(cmap(x)[:, :3], pal) @@ -272,7 +347,7 @@ def test_cubehelix_cmap(self): x = np.linspace(0, 1, 6) pal_forward = cmap(x).tolist() pal_reverse = cmap_rev(x[::-1]).tolist() - nt.assert_list_equal(pal_forward, pal_reverse) + assert pal_forward == pal_reverse def test_cubehelix_code(self): @@ -295,13 +370,17 @@ def test_cubehelix_code(self): pal2 = color_palette(cubehelix_palette(6, reverse=True)) assert pal1 == pal2 + pal1 = color_palette("ch:_r", as_cmap=True) + pal2 = cubehelix_palette(6, reverse=True, as_cmap=True) + assert pal1(.5) == pal2(.5) + def test_xkcd_palette(self): names = list(xkcd_rgb.keys())[10:15] colors = palettes.xkcd_palette(names) for name, color in zip(names, colors): as_hex = mpl.colors.rgb2hex(color) - nt.assert_equal(as_hex, xkcd_rgb[name]) + assert as_hex == xkcd_rgb[name] def test_crayon_palette(self): @@ -309,7 +388,7 @@ def test_crayon_palette(self): colors = palettes.crayon_palette(names) for name, color in zip(names, colors): as_hex = mpl.colors.rgb2hex(color) - nt.assert_equal(as_hex, crayons[name].lower()) + assert as_hex == crayons[name].lower() def test_color_codes(self): @@ -318,7 +397,7 @@ def test_color_codes(self): for code, color in zip("bgrmyck", colors): rgb_want = mpl.colors.colorConverter.to_rgb(color) rgb_got = mpl.colors.colorConverter.to_rgb(code) - nt.assert_equal(rgb_want, rgb_got) + assert rgb_want == rgb_got palettes.set_color_codes("reset") with pytest.raises(ValueError): @@ -328,19 +407,17 @@ def test_as_hex(self): pal = palettes.color_palette("deep") for rgb, hex in zip(pal, pal.as_hex()): - nt.assert_equal(mpl.colors.rgb2hex(rgb), hex) + assert mpl.colors.rgb2hex(rgb) == hex def test_preserved_palette_length(self): pal_in = palettes.color_palette("Set1", 10) pal_out = palettes.color_palette(pal_in) - nt.assert_equal(pal_in, pal_out) + assert pal_in == pal_out - def test_get_color_cycle(self): + def test_html_rep(self): - if mpl_ge_150: - colors = [(1., 0., 0.), (0, 1., 0.)] - prop_cycle = plt.cycler(color=colors) - with plt.rc_context({"axes.prop_cycle": prop_cycle}): - result = utils.get_color_cycle() - assert result == colors + pal = palettes.color_palette() + html = pal._repr_html_() + for color in pal.as_hex(): + assert color in html diff --git a/seaborn/tests/test_rcmod.py b/seaborn/tests/test_rcmod.py index 290d171f80..e0590dd0b1 100644 --- a/seaborn/tests/test_rcmod.py +++ b/seaborn/tests/test_rcmod.py @@ -1,15 +1,14 @@ +import pytest import numpy as np import matplotlib as mpl -from distutils.version import LooseVersion -import nose import matplotlib.pyplot as plt -import nose.tools as nt import numpy.testing as npt from .. 
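[Editor's note] The assertions added above pin down the string shorthands that `color_palette()` understands in the seaborn version this diff targets: `"light:<color>"` and `"dark:<color>"` dispatch to `light_palette()`/`dark_palette()`, a trailing `_r` reverses the palette, `"blend:<c1>,<c2>"` dispatches to `blend_palette()`, and `"ch:..."` is the cubehelix shorthand. A minimal usage sketch under those assumptions:

    import seaborn as sns

    pal = sns.color_palette("light:red", 4)                 # like light_palette("red", 4)
    cmap = sns.color_palette("dark:blue_r", as_cmap=True)   # reversed, returned as a colormap
    blend = sns.color_palette("blend:red,blue", 6)          # like blend_palette(["red", "blue"], 6)
    ch = sns.color_palette("ch:_r", 6)                      # reversed default cubehelix palette
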
diff --git a/seaborn/tests/test_rcmod.py b/seaborn/tests/test_rcmod.py
index 290d171f80..e0590dd0b1 100644
--- a/seaborn/tests/test_rcmod.py
+++ b/seaborn/tests/test_rcmod.py
@@ -1,15 +1,14 @@
+import pytest
 import numpy as np
 import matplotlib as mpl
-from distutils.version import LooseVersion
-import nose
 import matplotlib.pyplot as plt
-import nose.tools as nt
 import numpy.testing as npt

 from .. import rcmod, palettes, utils
+from ..conftest import has_verdana


-class RCParamTester(object):
+class RCParamTester:

     def flatten_list(self, orig_list):
@@ -20,10 +19,28 @@ def flatten_list(self, orig_list):
     def assert_rc_params(self, params):

         for k, v in params.items():
+            # Various subtle issues in matplotlib lead to unexpected
+            # values for the backend rcParam, which isn't relevant here
+            if k == "backend":
+                continue
             if isinstance(v, np.ndarray):
                 npt.assert_array_equal(mpl.rcParams[k], v)
             else:
-                nt.assert_equal((k, mpl.rcParams[k]), (k, v))
+                assert mpl.rcParams[k] == v
+
+    def assert_rc_params_equal(self, params1, params2):
+
+        for key, v1 in params1.items():
+            # Various subtle issues in matplotlib lead to unexpected
+            # values for the backend rcParam, which isn't relevant here
+            if key == "backend":
+                continue
+
+            v2 = params2[key]
+            if isinstance(v1, np.ndarray):
+                npt.assert_array_equal(v1, v2)
+            else:
+                assert v1 == v2


 class TestAxesStyle(RCParamTester):
@@ -39,19 +56,19 @@ def test_key_usage(self):
         _style_keys = set(rcmod._style_keys)
         for style in self.styles:
-            nt.assert_true(not set(rcmod.axes_style(style)) ^ _style_keys)
+            assert not set(rcmod.axes_style(style)) ^ _style_keys

     def test_bad_style(self):

-        with nt.assert_raises(ValueError):
+        with pytest.raises(ValueError):
             rcmod.axes_style("i_am_not_a_style")

     def test_rc_override(self):

         rc = {"axes.facecolor": "blue", "foo.notaparam": "bar"}
         out = rcmod.axes_style("darkgrid", rc)
-        nt.assert_equal(out["axes.facecolor"], "blue")
-        nt.assert_not_in("foo.notaparam", out)
+        assert out["axes.facecolor"] == "blue"
+        assert "foo.notaparam" not in out

     def test_set_style(self):
@@ -79,58 +96,61 @@ def func():

     def test_style_context_independence(self):

-        nt.assert_true(set(rcmod._style_keys) ^ set(rcmod._context_keys))
+        assert set(rcmod._style_keys) ^ set(rcmod._context_keys)

     def test_set_rc(self):

-        rcmod.set(rc={"lines.linewidth": 4})
-        nt.assert_equal(mpl.rcParams["lines.linewidth"], 4)
-        rcmod.set()
+        rcmod.set_theme(rc={"lines.linewidth": 4})
+        assert mpl.rcParams["lines.linewidth"] == 4
+        rcmod.set_theme()

     def test_set_with_palette(self):

         rcmod.reset_orig()
-        rcmod.set(palette="deep")
+        rcmod.set_theme(palette="deep")
         assert utils.get_color_cycle() == palettes.color_palette("deep", 10)
         rcmod.reset_orig()

-        rcmod.set(palette="deep", color_codes=False)
+        rcmod.set_theme(palette="deep", color_codes=False)
         assert utils.get_color_cycle() == palettes.color_palette("deep", 10)
         rcmod.reset_orig()

         pal = palettes.color_palette("deep")
-        rcmod.set(palette=pal)
+        rcmod.set_theme(palette=pal)
         assert utils.get_color_cycle() == palettes.color_palette("deep", 10)
         rcmod.reset_orig()

-        rcmod.set(palette=pal, color_codes=False)
+        rcmod.set_theme(palette=pal, color_codes=False)
         assert utils.get_color_cycle() == palettes.color_palette("deep", 10)
         rcmod.reset_orig()

-        rcmod.set()
+        rcmod.set_theme()

     def test_reset_defaults(self):

-        # Changes to the rc parameters make this test hard to manage
-        # on older versions of matplotlib, so we'll skip it
-        if LooseVersion(mpl.__version__) < LooseVersion("1.3"):
-            raise nose.SkipTest
-
         rcmod.reset_defaults()
         self.assert_rc_params(mpl.rcParamsDefault)
-        rcmod.set()
+        rcmod.set_theme()

     def test_reset_orig(self):

-        # Changes to the rc parameters make this test hard to manage
-        # on older versions of matplotlib, so we'll skip it
-        if LooseVersion(mpl.__version__) < LooseVersion("1.3"):
-            raise nose.SkipTest
-
         rcmod.reset_orig()
         self.assert_rc_params(mpl.rcParamsOrig)
-        rcmod.set()
+        rcmod.set_theme()
+
+    def test_set_is_alias(self):
+
+        rcmod.set_theme(context="paper", style="white")
+        params1 = mpl.rcParams.copy()
+        rcmod.reset_orig()
+
+        rcmod.set_theme(context="paper", style="white")
+        params2 = mpl.rcParams.copy()
+
+        self.assert_rc_params_equal(params1, params2)
+
+        rcmod.set_theme()


 class TestPlottingContext(RCParamTester):
@@ -147,11 +167,11 @@ def test_key_usage(self):
         _context_keys = set(rcmod._context_keys)
         for context in self.contexts:
             missing = set(rcmod.plotting_context(context)) ^ _context_keys
-            nt.assert_true(not missing)
+            assert not missing

     def test_bad_context(self):

-        with nt.assert_raises(ValueError):
+        with pytest.raises(ValueError):
             rcmod.plotting_context("i_am_not_a_context")

     def test_font_scale(self):
@@ -159,19 +179,23 @@

         notebook_ref = rcmod.plotting_context("notebook")
         notebook_big = rcmod.plotting_context("notebook", 2)

-        font_keys = ["axes.labelsize", "axes.titlesize", "legend.fontsize",
-                     "xtick.labelsize", "ytick.labelsize", "font.size"]
+        font_keys = [
+            "font.size",
+            "axes.labelsize", "axes.titlesize",
+            "xtick.labelsize", "ytick.labelsize",
+            "legend.fontsize", "legend.title_fontsize",
+        ]

         for k in font_keys:
-            nt.assert_equal(notebook_ref[k] * 2, notebook_big[k])
+            assert notebook_ref[k] * 2 == notebook_big[k]

     def test_rc_override(self):

         key, val = "grid.linewidth", 5
         rc = {key: val, "foo": "bar"}
         out = rcmod.plotting_context("talk", rc=rc)
-        nt.assert_equal(out[key], val)
-        nt.assert_not_in("foo", out)
+        assert out[key] == val
+        assert "foo" not in out

     def test_set_context(self):
@@ -198,7 +222,7 @@ def func():
         self.assert_rc_params(orig_params)


-class TestPalette(object):
+class TestPalette:

     def test_set_palette(self):
@@ -215,79 +239,42 @@ def test_set_palette(self):
         assert utils.get_color_cycle() == palettes.color_palette("Set2", 8)


-class TestFonts(object):
+class TestFonts:
+
+    _no_verdana = not has_verdana()

+    @pytest.mark.skipif(_no_verdana, reason="Verdana font is not present")
     def test_set_font(self):

-        rcmod.set(font="Verdana")
+        rcmod.set_theme(font="Verdana")

         _, ax = plt.subplots()
         ax.set_xlabel("foo")

-        try:
-            nt.assert_equal(ax.xaxis.label.get_fontname(),
-                            "Verdana")
-        except AssertionError:
-            if has_verdana():
-                raise
-            else:
-                raise nose.SkipTest("Verdana font is not present")
-        finally:
-            rcmod.set()
+        assert ax.xaxis.label.get_fontname() == "Verdana"
+
+        rcmod.set_theme()

     def test_set_serif_font(self):

-        rcmod.set(font="serif")
+        rcmod.set_theme(font="serif")

         _, ax = plt.subplots()
         ax.set_xlabel("foo")

-        nt.assert_in(ax.xaxis.label.get_fontname(),
-                     mpl.rcParams["font.serif"])
+        assert ax.xaxis.label.get_fontname() in mpl.rcParams["font.serif"]

-        rcmod.set()
+        rcmod.set_theme()

+    @pytest.mark.skipif(_no_verdana, reason="Verdana font is not present")
     def test_different_sans_serif(self):

-        if LooseVersion(mpl.__version__) < LooseVersion("1.4"):
-            raise nose.SkipTest
-
-        rcmod.set()
+        rcmod.set_theme()
         rcmod.set_style(rc={"font.sans-serif": ["Verdana"]})

         _, ax = plt.subplots()
         ax.set_xlabel("foo")

-        try:
-            nt.assert_equal(ax.xaxis.label.get_fontname(),
-                            "Verdana")
-        except AssertionError:
-            if has_verdana():
-                raise
-            else:
-                raise nose.SkipTest("Verdana font is not present")
-        finally:
-            rcmod.set()
-
-
-def has_verdana():
-    """Helper to verify if Verdana font is present"""
-    # This import is relatively lengthy, so to prevent its import for
-    # testing other tests in this module not requiring this knowledge,
-    # import font_manager here
-    import matplotlib.font_manager as mplfm
-    try:
-        verdana_font = mplfm.findfont('Verdana', fallback_to_default=False)
-    except:  # noqa
-        # if https://github.com/matplotlib/matplotlib/pull/3435
-        # gets accepted
-        return False
-    # otherwise check if not matching the logic for a 'default' one
-    try:
-        unlikely_font = mplfm.findfont("very_unlikely_to_exist1234",
-                                       fallback_to_default=False)
-    except:  # noqa
-        # if matched verdana but not unlikely, Verdana must exist
-        return True
-    # otherwise -- if they match, must be the same default
-    return verdana_font != unlikely_font
+        assert ax.xaxis.label.get_fontname() == "Verdana"
+
+        rcmod.set_theme()
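[Editor's note] The rcmod changes above rename `set()` to `set_theme()`, and the new `test_set_is_alias` asserts that the renamed entry point configures the same rcParams. A minimal sketch of the behavior under test (assuming a seaborn release where `set_theme` exists):

    import matplotlib as mpl
    import seaborn as sns

    # set_theme() is the new name for what sns.set() used to do
    sns.set_theme(context="paper", style="white", rc={"lines.linewidth": 4})
    assert mpl.rcParams["lines.linewidth"] == 4

    sns.set_theme()  # restore the default seaborn theme
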
diff --git a/seaborn/tests/test_regression.py b/seaborn/tests/test_regression.py
index 0ac547dcd5..ee471cff68 100644
--- a/seaborn/tests/test_regression.py
+++ b/seaborn/tests/test_regression.py
@@ -4,13 +4,11 @@
 import pandas as pd
 import pytest

-import nose.tools as nt
 import numpy.testing as npt
 try:
     import pandas.testing as pdt
 except ImportError:
     import pandas.util.testing as pdt
-from nose import SkipTest

 try:
     import statsmodels.regression.linear_model as smlm
@@ -24,7 +22,7 @@
 rs = np.random.RandomState(0)


-class TestLinearPlotter(object):
+class TestLinearPlotter:

     rs = np.random.RandomState(77)
     df = pd.DataFrame(dict(x=rs.normal(size=60),
@@ -49,7 +47,7 @@ def test_establish_variables_from_series(self):
         p.establish_variables(None, x=self.df.x, y=self.df.y)
         pdt.assert_series_equal(p.x, self.df.x)
         pdt.assert_series_equal(p.y, self.df.y)
-        nt.assert_is(p.data, None)
+        assert p.data is None

     def test_establish_variables_from_array(self):
@@ -59,7 +57,7 @@ def test_establish_variables_from_array(self):
                               y=self.df.y.values)
         npt.assert_array_equal(p.x, self.df.x)
         npt.assert_array_equal(p.y, self.df.y)
-        nt.assert_is(p.data, None)
+        assert p.data is None

     def test_establish_variables_from_lists(self):
@@ -69,7 +67,7 @@ def test_establish_variables_from_lists(self):
                               y=self.df.y.values.tolist())
         npt.assert_array_equal(p.x, self.df.x)
         npt.assert_array_equal(p.y, self.df.y)
-        nt.assert_is(p.data, None)
+        assert p.data is None

     def test_establish_variables_from_mix(self):
@@ -82,7 +80,7 @@ def test_establish_variables_from_bad(self):

         p = lm._LinearPlotter()
-        with nt.assert_raises(ValueError):
+        with pytest.raises(ValueError):
             p.establish_variables(None, x="x", y=self.df.y)

     def test_dropna(self):
@@ -98,7 +96,7 @@ def test_dropna(self):
         pdt.assert_series_equal(p.y_na, self.df.y_na[mask])


-class TestRegressionPlotter(object):
+class TestRegressionPlotter:

     rs = np.random.RandomState(49)
@@ -137,7 +135,7 @@ def test_variables_from_series(self):
         npt.assert_array_equal(p.x, self.df.x)
         npt.assert_array_equal(p.y, self.df.y)
         npt.assert_array_equal(p.units, self.df.s)
-        nt.assert_is(p.data, None)
+        assert p.data is None

     def test_variables_from_mix(self):
@@ -147,27 +145,44 @@
         npt.assert_array_equal(p.x, self.df.x)
         npt.assert_array_equal(p.y, self.df.y + 1)
         pdt.assert_frame_equal(p.data, self.df)

+    def test_variables_must_be_1d(self):
+
+        array_2d = np.random.randn(20, 2)
+        array_1d = np.random.randn(20)
+        with pytest.raises(ValueError):
+            lm._RegressionPlotter(array_2d, array_1d)
+        with pytest.raises(ValueError):
+            lm._RegressionPlotter(array_1d, array_2d)
+
     def test_dropna(self):

         p = lm._RegressionPlotter("x", "y_na", data=self.df)
-        nt.assert_equal(len(p.x), pd.notnull(self.df.y_na).sum())
+        assert len(p.x) == pd.notnull(self.df.y_na).sum()

         p = lm._RegressionPlotter("x", "y_na", data=self.df, dropna=False)
-        nt.assert_equal(len(p.x), len(self.df.y_na))
+        assert len(p.x) == len(self.df.y_na)
+
+    @pytest.mark.parametrize("x,y",
+                             [([1.5], [2]),
+                              (np.array([1.5]), np.array([2])),
+                              (pd.Series(1.5), pd.Series(2))])
+    def test_singleton(self, x, y):
+        p = lm._RegressionPlotter(x, y)
+        assert not p.fit_reg

     def test_ci(self):

         p = lm._RegressionPlotter("x", "y", data=self.df, ci=95)
-        nt.assert_equal(p.ci, 95)
-        nt.assert_equal(p.x_ci, 95)
+        assert p.ci == 95
+        assert p.x_ci == 95

         p = lm._RegressionPlotter("x", "y", data=self.df, ci=95, x_ci=68)
-        nt.assert_equal(p.ci, 95)
-        nt.assert_equal(p.x_ci, 68)
+        assert p.ci == 95
+        assert p.x_ci == 68

         p = lm._RegressionPlotter("x", "y", data=self.df, ci=95, x_ci="sd")
-        nt.assert_equal(p.ci, 95)
-        nt.assert_equal(p.x_ci, "sd")
+        assert p.ci == 95
+        assert p.x_ci == "sd"

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_fast_regression(self):
@@ -207,9 +222,9 @@ def test_regress_logx(self):
         yhat_lin, _ = p.fit_fast(grid)
         yhat_log, _ = p.fit_logx(grid)

-        nt.assert_greater(yhat_lin[0], yhat_log[0])
-        nt.assert_greater(yhat_log[20], yhat_lin[20])
-        nt.assert_greater(yhat_lin[90], yhat_log[90])
+        assert yhat_lin[0] > yhat_log[0]
+        assert yhat_log[20] > yhat_lin[20]
+        assert yhat_lin[90] > yhat_log[90]

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_regress_n_boot(self):
@@ -236,15 +251,27 @@ def test_regress_without_bootstrap(self):

         # Fast (linear algebra) version
         _, boots_fast = p.fit_fast(self.grid)
-        nt.assert_is(boots_fast, None)
+        assert boots_fast is None

         # Slower (np.polyfit) version
         _, boots_poly = p.fit_poly(self.grid, 1)
-        nt.assert_is(boots_poly, None)
+        assert boots_poly is None

         # Slowest (statsmodels) version
         _, boots_smod = p.fit_statsmodels(self.grid, smlm.OLS)
-        nt.assert_is(boots_smod, None)
+        assert boots_smod is None
+
+    def test_regress_bootstrap_seed(self):
+
+        seed = 200
+        p1 = lm._RegressionPlotter("x", "y", data=self.df,
+                                   n_boot=self.n_boot, seed=seed)
+        p2 = lm._RegressionPlotter("x", "y", data=self.df,
+                                   n_boot=self.n_boot, seed=seed)
+
+        _, boots1 = p1.fit_fast(self.grid)
+        _, boots2 = p2.fit_fast(self.grid)
+        npt.assert_array_equal(boots1, boots2)

     def test_numeric_bins(self):
@@ -263,10 +290,8 @@ def test_bin_results(self):
         p = lm._RegressionPlotter(self.df.x, self.df.y)
         x_binned, bins = p.bin_predictor(self.bins_given)
-        nt.assert_greater(self.df.x[x_binned == 0].min(),
-                          self.df.x[x_binned == -1].max())
-        nt.assert_greater(self.df.x[x_binned == 1].min(),
-                          self.df.x[x_binned == 0].max())
+        assert self.df.x[x_binned == 0].min() > self.df.x[x_binned == -1].max()
+        assert self.df.x[x_binned == 1].min() > self.df.x[x_binned == 0].max()

     def test_scatter_data(self):
@@ -282,7 +307,7 @@ def test_scatter_data(self):
         p = lm._RegressionPlotter(self.df.d, self.df.y, x_jitter=.1)
         x, y = p.scatter_data
-        nt.assert_true((x != self.df.d).any())
+        assert (x != self.df.d).any()
         npt.assert_array_less(np.abs(self.df.d - x),
                               np.repeat(.1, len(x)))
         npt.assert_array_equal(y, self.df.y)
@@ -304,15 +329,14 @@ def test_estimate_data(self):

     def test_estimate_cis(self):

-        # set known good seed to avoid the test stochastically failing
-        np.random.seed(123)
+        seed = 123

         p = lm._RegressionPlotter(self.df.d, self.df.y,
-                                  x_estimator=np.mean, ci=95)
+                                  x_estimator=np.mean, ci=95, seed=seed)
         _, _, ci_big = p.estimate_data

         p = lm._RegressionPlotter(self.df.d, self.df.y,
-                                  x_estimator=np.mean, ci=50)
+                                  x_estimator=np.mean, ci=50, seed=seed)
         _, _, ci_wee = p.estimate_data

         npt.assert_array_less(np.diff(ci_wee), np.diff(ci_big))
@@ -324,14 +348,14 @@ def test_estimate_cis(self):

     def test_estimate_units(self):

         # Seed the RNG locally
-        np.random.seed(345)
+        seed = 345

         p = lm._RegressionPlotter("x", "y", data=self.df,
-                                  units="s", x_bins=3)
+                                  units="s", seed=seed, x_bins=3)
         _, _, ci_big = p.estimate_data
         ci_big = np.diff(ci_big, axis=1)

-        p = lm._RegressionPlotter("x", "y", data=self.df, x_bins=3)
+        p = lm._RegressionPlotter("x", "y", data=self.df, seed=seed, x_bins=3)
         _, _, ci_wee = p.estimate_data
         ci_wee = np.diff(ci_wee, axis=1)
@@ -348,11 +372,17 @@ def test_partial(self):

         p = lm._RegressionPlotter(y, z, y_partial=x)
         _, r_semipartial = np.corrcoef(p.x, p.y)[0]
-        nt.assert_less(r_semipartial, r_orig)
+        assert r_semipartial < r_orig

         p = lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
         _, r_partial = np.corrcoef(p.x, p.y)[0]
-        nt.assert_less(r_partial, r_orig)
+        assert r_partial < r_orig
+
+        x = pd.Series(x)
+        y = pd.Series(y)
+        p = lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
+        _, r_partial = np.corrcoef(p.x, p.y)[0]
+        assert r_partial < r_orig

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_logistic_regression(self):
@@ -371,7 +401,7 @@ def test_logistic_perfect_separation(self):
                                   logistic=True, n_boot=10)
         with np.errstate(all="ignore"):
             _, yhat, _ = p.fit_regression(x_range=(-3, 3))
-        nt.assert_true(np.isnan(yhat).all())
+        assert np.isnan(yhat).all()

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_robust_regression(self):
@@ -384,7 +414,7 @@ def test_robust_regression(self):
                                          robust=True, n_boot=self.n_boot)
         _, robust_yhat, _ = p_robust.fit_regression(x_range=(-3, 3))

-        nt.assert_equal(len(ols_yhat), len(robust_yhat))
+        assert len(ols_yhat) == len(robust_yhat)

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_lowess_regression(self):
@@ -392,16 +422,16 @@ def test_lowess_regression(self):
         p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True)
         grid, yhat, err_bands = p.fit_regression(x_range=(-3, 3))

-        nt.assert_equal(len(grid), len(yhat))
-        nt.assert_is(err_bands, None)
+        assert len(grid) == len(yhat)
+        assert err_bands is None

     def test_regression_options(self):

-        with nt.assert_raises(ValueError):
+        with pytest.raises(ValueError):
             lm._RegressionPlotter("x", "y", data=self.df,
                                   lowess=True, order=2)

-        with nt.assert_raises(ValueError):
+        with pytest.raises(ValueError):
             lm._RegressionPlotter("x", "y", data=self.df,
                                   lowess=True, logistic=True)
@@ -412,16 +442,16 @@ def test_regression_limits(self):
         p = lm._RegressionPlotter("x", "y", data=self.df)
         grid, _, _ = p.fit_regression(ax)
         xlim = ax.get_xlim()
-        nt.assert_equal(grid.min(), xlim[0])
-        nt.assert_equal(grid.max(), xlim[1])
+        assert grid.min() == xlim[0]
+        assert grid.max() == xlim[1]

         p = lm._RegressionPlotter("x", "y", data=self.df, truncate=True)
         grid, _, _ = p.fit_regression()
-        nt.assert_equal(grid.min(), self.df.x.min())
-        nt.assert_equal(grid.max(), self.df.x.max())
+        assert grid.min() == self.df.x.min()
+        assert grid.max() == self.df.x.max()


-class TestRegressionPlots(object):
+class TestRegressionPlots:

     rs = np.random.RandomState(56)
     df = pd.DataFrame(dict(x=rs.randn(90),
@@ -436,9 +466,9 @@ class TestRegressionPlots(object):
     def test_regplot_basic(self):

         f, ax = plt.subplots()
-        lm.regplot("x", "y", self.df)
-        nt.assert_equal(len(ax.lines), 1)
-        nt.assert_equal(len(ax.collections), 2)
+        lm.regplot(x="x", y="y", data=self.df)
+        assert len(ax.lines) == 1
+        assert len(ax.collections) == 2

         x, y = ax.collections[0].get_offsets().T
         npt.assert_array_equal(x, self.df.x)
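[Editor's note] The hunks above replace global `np.random.seed(...)` calls with a `seed=` value passed into `_RegressionPlotter`, so bootstrapped confidence intervals become deterministic per call rather than depending on global RNG state; `test_regress_bootstrap_seed` pins this down at the plotter level. The public `regplot` exposes the same parameter; a sketch of the reproducibility this buys (assuming the `seed=` keyword these tests exercise):

    import numpy as np
    import seaborn as sns

    x, y = np.random.randn(2, 50)
    # same seed -> same bootstrap draws -> identical CI bands across calls
    sns.regplot(x=x, y=y, n_boot=100, seed=200)
    sns.regplot(x=x, y=y, n_boot=100, seed=200)
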
@@ -447,59 +477,68 @@ def test_regplot_selective(self):

         f, ax = plt.subplots()
-        ax = lm.regplot("x", "y", self.df, scatter=False, ax=ax)
-        nt.assert_equal(len(ax.lines), 1)
-        nt.assert_equal(len(ax.collections), 1)
+        ax = lm.regplot(x="x", y="y", data=self.df, scatter=False, ax=ax)
+        assert len(ax.lines) == 1
+        assert len(ax.collections) == 1
         ax.clear()

         f, ax = plt.subplots()
-        ax = lm.regplot("x", "y", self.df, fit_reg=False)
-        nt.assert_equal(len(ax.lines), 0)
-        nt.assert_equal(len(ax.collections), 1)
+        ax = lm.regplot(x="x", y="y", data=self.df, fit_reg=False)
+        assert len(ax.lines) == 0
+        assert len(ax.collections) == 1
         ax.clear()

         f, ax = plt.subplots()
-        ax = lm.regplot("x", "y", self.df, ci=None)
-        nt.assert_equal(len(ax.lines), 1)
-        nt.assert_equal(len(ax.collections), 1)
+        ax = lm.regplot(x="x", y="y", data=self.df, ci=None)
+        assert len(ax.lines) == 1
+        assert len(ax.collections) == 1
         ax.clear()

     def test_regplot_scatter_kws_alpha(self):

         f, ax = plt.subplots()
         color = np.array([[0.3, 0.8, 0.5, 0.5]])
-        ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
-        nt.assert_is(ax.collections[0]._alpha, None)
-        nt.assert_equal(ax.collections[0]._facecolors[0, 3], 0.5)
+        ax = lm.regplot(x="x", y="y", data=self.df,
+                        scatter_kws={'color': color})
+        assert ax.collections[0]._alpha is None
+        assert ax.collections[0]._facecolors[0, 3] == 0.5

         f, ax = plt.subplots()
         color = np.array([[0.3, 0.8, 0.5]])
-        ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
-        nt.assert_equal(ax.collections[0]._alpha, 0.8)
+        ax = lm.regplot(x="x", y="y", data=self.df,
+                        scatter_kws={'color': color})
+        assert ax.collections[0]._alpha == 0.8

         f, ax = plt.subplots()
         color = np.array([[0.3, 0.8, 0.5]])
-        ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color,
-                                                        'alpha': 0.4})
-        nt.assert_equal(ax.collections[0]._alpha, 0.4)
+        ax = lm.regplot(x="x", y="y", data=self.df,
+                        scatter_kws={'color': color, 'alpha': 0.4})
+        assert ax.collections[0]._alpha == 0.4

         f, ax = plt.subplots()
         color = 'r'
-        ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
-        nt.assert_equal(ax.collections[0]._alpha, 0.8)
+        ax = lm.regplot(x="x", y="y", data=self.df,
+                        scatter_kws={'color': color})
+        assert ax.collections[0]._alpha == 0.8

     def test_regplot_binned(self):

-        ax = lm.regplot("x", "y", self.df, x_bins=5)
-        nt.assert_equal(len(ax.lines), 6)
-        nt.assert_equal(len(ax.collections), 2)
+        ax = lm.regplot(x="x", y="y", data=self.df, x_bins=5)
+        assert len(ax.lines) == 6
+        assert len(ax.collections) == 2
+
+    def test_lmplot_no_data(self):
+
+        with pytest.raises(TypeError):
+            # keyword argument `data` is required
+            lm.lmplot(x="x", y="y")

     def test_lmplot_basic(self):

-        g = lm.lmplot("x", "y", self.df)
+        g = lm.lmplot(x="x", y="y", data=self.df)
         ax = g.axes[0, 0]
-        nt.assert_equal(len(ax.lines), 1)
-        nt.assert_equal(len(ax.collections), 2)
+        assert len(ax.lines) == 1
+        assert len(ax.collections) == 2

         x, y = ax.collections[0].get_offsets().T
         npt.assert_array_equal(x, self.df.x)
@@ -507,53 +546,50 @@ def test_lmplot_basic(self):

     def test_lmplot_hue(self):

-        g = lm.lmplot("x", "y", data=self.df, hue="h")
+        g = lm.lmplot(x="x", y="y", data=self.df, hue="h")
         ax = g.axes[0, 0]
-        nt.assert_equal(len(ax.lines), 2)
-        nt.assert_equal(len(ax.collections), 4)
+        assert len(ax.lines) == 2
+        assert len(ax.collections) == 4

     def test_lmplot_markers(self):

-        g1 = lm.lmplot("x", "y", data=self.df, hue="h", markers="s")
-        nt.assert_equal(g1.hue_kws, {"marker": ["s", "s"]})
+        g1 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers="s")
+        assert g1.hue_kws == {"marker": ["s", "s"]}

-        g2 = lm.lmplot("x", "y", data=self.df, hue="h", markers=["o", "s"])
-        nt.assert_equal(g2.hue_kws, {"marker": ["o", "s"]})
+        g2 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers=["o", "s"])
+        assert g2.hue_kws == {"marker": ["o", "s"]}

-        with nt.assert_raises(ValueError):
-            lm.lmplot("x", "y", data=self.df, hue="h", markers=["o", "s", "d"])
+        with pytest.raises(ValueError):
+            lm.lmplot(x="x", y="y", data=self.df, hue="h",
+                      markers=["o", "s", "d"])

     def test_lmplot_marker_linewidths(self):

-        if mpl.__version__ == "1.4.2":
-            raise SkipTest
-
-        g = lm.lmplot("x", "y", data=self.df, hue="h",
+        g = lm.lmplot(x="x", y="y", data=self.df, hue="h",
                       fit_reg=False, markers=["o", "+"])
         c = g.axes[0, 0].collections
-        nt.assert_equal(c[1].get_linewidths()[0],
-                        mpl.rcParams["lines.linewidth"])
+        assert c[1].get_linewidths()[0] == mpl.rcParams["lines.linewidth"]

     def test_lmplot_facets(self):

-        g = lm.lmplot("x", "y", data=self.df, row="g", col="h")
-        nt.assert_equal(g.axes.shape, (3, 2))
+        g = lm.lmplot(x="x", y="y", data=self.df, row="g", col="h")
+        assert g.axes.shape == (3, 2)

-        g = lm.lmplot("x", "y", data=self.df, col="u", col_wrap=4)
-        nt.assert_equal(g.axes.shape, (6,))
+        g = lm.lmplot(x="x", y="y", data=self.df, col="u", col_wrap=4)
+        assert g.axes.shape == (6,)

-        g = lm.lmplot("x", "y", data=self.df, hue="h", col="u")
-        nt.assert_equal(g.axes.shape, (1, 6))
+        g = lm.lmplot(x="x", y="y", data=self.df, hue="h", col="u")
+        assert g.axes.shape == (1, 6)

     def test_lmplot_hue_col_nolegend(self):

-        g = lm.lmplot("x", "y", data=self.df, col="h", hue="h")
-        nt.assert_is(g._legend, None)
+        g = lm.lmplot(x="x", y="y", data=self.df, col="h", hue="h")
+        assert g._legend is None

     def test_lmplot_scatter_kws(self):

-        g = lm.lmplot("x", "y", hue="h", data=self.df, ci=None)
+        g = lm.lmplot(x="x", y="y", hue="h", data=self.df, ci=None)
         red_scatter, blue_scatter = g.axes[0, 0].collections

         red, blue = color_palette(n_colors=2)
@@ -563,7 +599,7 @@ def test_residplot(self):

         x, y = self.df.x, self.df.y
-        ax = lm.residplot(x, y)
+        ax = lm.residplot(x=x, y=y)

         resid = y - np.polyval(np.polyfit(x, y, 1), x)
         x_plot, y_plot = ax.collections[0].get_offsets().T
@@ -574,8 +610,8 @@

     @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
     def test_residplot_lowess(self):

-        ax = lm.residplot("x", "y", self.df, lowess=True)
-        nt.assert_equal(len(ax.lines), 2)
+        ax = lm.residplot(x="x", y="y", data=self.df, lowess=True)
+        assert len(ax.lines) == 2

         x, y = ax.lines[1].get_xydata().T
         npt.assert_array_equal(x, np.sort(self.df.x))
@@ -583,7 +619,16 @@ def test_three_point_colors(self):

         x, y = np.random.randn(2, 3)
-        ax = lm.regplot(x, y, color=(1, 0, 0))
+        ax = lm.regplot(x=x, y=y, color=(1, 0, 0))
         color = ax.collections[0].get_facecolors()
         npt.assert_almost_equal(color[0, :3],
                                 (1, 0, 0))
+
+    def test_regplot_xlim(self):
+
+        f, ax = plt.subplots()
+        x, y1, y2 = np.random.randn(3, 50)
+        lm.regplot(x=x, y=y1, truncate=False)
+        lm.regplot(x=x, y=y2, truncate=False)
+        line1, line2 = ax.lines
+        assert np.array_equal(line1.get_xdata(), line2.get_xdata())
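[Editor's note] Throughout `TestRegressionPlots`, positional calls like `lm.regplot("x", "y", self.df)` are rewritten with explicit keywords, and the new `test_lmplot_no_data` pins down that omitting `data=` now raises `TypeError`. The corresponding user-facing calls look like this (a sketch, assuming the keyword-argument style these tests enforce):

    import seaborn as sns

    tips = sns.load_dataset("tips")
    sns.regplot(x="total_bill", y="tip", data=tips)
    sns.lmplot(x="total_bill", y="tip", data=tips, hue="smoker")
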
diff --git a/seaborn/tests/test_relational.py b/seaborn/tests/test_relational.py
index 874c1d901d..5fb0829e35 100644
--- a/seaborn/tests/test_relational.py
+++ b/seaborn/tests/test_relational.py
@@ -1,16 +1,49 @@
-from __future__ import division
+from distutils.version import LooseVersion
 from itertools import product
 import numpy as np
-import pandas as pd
 import matplotlib as mpl
 import matplotlib.pyplot as plt
+from matplotlib.colors import same_color, to_rgba
+
 import pytest
-from .. import relational as rel
-from ..palettes import color_palette
-from ..utils import categorical_order, sort_df
+from numpy.testing import assert_array_equal

+from ..palettes import color_palette

-class TestRelationalPlotter(object):
+from ..relational import (
+    _RelationalPlotter,
+    _LinePlotter,
+    _ScatterPlotter,
+    relplot,
+    lineplot,
+    scatterplot
+)
+
+from ..utils import _draw_figure
+from .._testing import assert_plots_equal
+
+
+@pytest.fixture(params=[
+    dict(x="x", y="y"),
+    dict(x="t", y="y"),
+    dict(x="a", y="y"),
+    dict(x="x", y="y", hue="y"),
+    dict(x="x", y="y", hue="a"),
+    dict(x="x", y="y", size="a"),
+    dict(x="x", y="y", style="a"),
+    dict(x="x", y="y", hue="s"),
+    dict(x="x", y="y", size="s"),
+    dict(x="x", y="y", style="s"),
+    dict(x="x", y="y", hue="a", style="a"),
+    dict(x="x", y="y", hue="a", size="b", style="b"),
+])
+def long_semantics(request):
+    return request.param
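[Editor's note] The `long_semantics` fixture above is parametrized, so any test that accepts it runs once per semantic mapping. A hypothetical consumer (not part of this diff; `long_df` is the shared dataframe fixture assumed by these tests) would look like:

    def test_each_semantic_mapping(long_df, long_semantics):
        # runs 12 times, once per dict of x/y/hue/size/style assignments
        ax = scatterplot(data=long_df, **long_semantics)
        assert ax.collections
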
+
+
+class Helpers:
+
+    # TODO Better place for these?

     def scatter_rgbs(self, collections):
         rgbs = []
@@ -19,989 +52,724 @@ def scatter_rgbs(self, collections):
             rgbs.append(rgb)
         return rgbs

-    def colors_equal(self, *args):
-
-        equal = True
-        for c1, c2 in zip(*args):
-            c1 = mpl.colors.colorConverter.to_rgb(np.squeeze(c1))
-            c2 = mpl.colors.colorConverter.to_rgb(np.squeeze(c1))
-            equal &= c1 == c2
-        return equal
-
     def paths_equal(self, *args):

-        equal = True
+        equal = all([len(a) == len(args[0]) for a in args])
+
         for p1, p2 in zip(*args):
             equal &= np.array_equal(p1.vertices, p2.vertices)
             equal &= np.array_equal(p1.codes, p2.codes)
         return equal

-    @pytest.fixture
-    def wide_df(self):
-
-        columns = list("abc")
-        index = pd.Int64Index(np.arange(10, 50, 2), name="wide_index")
-        values = np.random.randn(len(index), len(columns))
-        return pd.DataFrame(values, index=index, columns=columns)

-    @pytest.fixture
-    def wide_array(self):
-        return np.random.randn(20, 3)
+class SharedAxesLevelTests:

-    @pytest.fixture
-    def flat_array(self):
+    def test_color(self, long_df):

-        return np.random.randn(20)
+        ax = plt.figure().subplots()
+        self.func(data=long_df, x="x", y="y", ax=ax)
+        assert self.get_last_color(ax) == to_rgba("C0")

-    @pytest.fixture
-    def flat_series(self):
+        ax = plt.figure().subplots()
+        self.func(data=long_df, x="x", y="y", ax=ax)
+        self.func(data=long_df, x="x", y="y", ax=ax)
+        assert self.get_last_color(ax) == to_rgba("C1")

-        index = pd.Int64Index(np.arange(10, 30), name="t")
-        return pd.Series(np.random.randn(20), index, name="s")
+        ax = plt.figure().subplots()
+        self.func(data=long_df, x="x", y="y", color="C2", ax=ax)
+        assert self.get_last_color(ax) == to_rgba("C2")

-    @pytest.fixture
-    def wide_list(self):
+        ax = plt.figure().subplots()
+        self.func(data=long_df, x="x", y="y", c="C2", ax=ax)
+        assert self.get_last_color(ax) == to_rgba("C2")

-        return [np.random.randn(20), np.random.randn(10)]

-    @pytest.fixture
-    def wide_list_of_series(self):
+class TestRelationalPlotter(Helpers):

-        return [pd.Series(np.random.randn(20), np.arange(20), name="a"),
-                pd.Series(np.random.randn(10), np.arange(5, 15), name="b")]
+    def test_wide_df_variables(self, wide_df):

-    @pytest.fixture
-    def long_df(self):
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_df)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]
+        assert len(p.plot_data) == np.product(wide_df.shape)

-        n = 100
-        rs = np.random.RandomState()
-        df = pd.DataFrame(dict(
-            x=rs.randint(0, 20, n),
-            y=rs.randn(n),
-            a=np.take(list("abc"), rs.randint(0, 3, n)),
-            b=np.take(list("mnop"), rs.randint(0, 4, n)),
-            c=np.take(list([0, 1]), rs.randint(0, 2, n)),
-            s=np.take([2, 4, 8], rs.randint(0, 3, n)),
-        ))
-        df["s_cat"] = df["s"].astype("category")
-        return df
+        x = p.plot_data["x"]
+        expected_x = np.tile(wide_df.index, wide_df.shape[1])
+        assert_array_equal(x, expected_x)

-    @pytest.fixture
-    def repeated_df(self):
+        y = p.plot_data["y"]
+        expected_y = wide_df.to_numpy().ravel(order="f")
+        assert_array_equal(y, expected_y)

-        n = 100
-        rs = np.random.RandomState()
-        return pd.DataFrame(dict(
-            x=np.tile(np.arange(n // 2), 2),
-            y=rs.randn(n),
-            a=np.take(list("abc"), rs.randint(0, 3, n)),
-            u=np.repeat(np.arange(2), n // 2),
-        ))
+        hue = p.plot_data["hue"]
+        expected_hue = np.repeat(wide_df.columns.to_numpy(), wide_df.shape[0])
+        assert_array_equal(hue, expected_hue)

-    @pytest.fixture
-    def missing_df(self):
+        style = p.plot_data["style"]
+        expected_style = expected_hue
+        assert_array_equal(style, expected_style)

-        n = 100
-        rs = np.random.RandomState()
-        df = pd.DataFrame(dict(
-            x=rs.randint(0, 20, n),
-            y=rs.randn(n),
-            a=np.take(list("abc"), rs.randint(0, 3, n)),
-            b=np.take(list("mnop"), rs.randint(0, 4, n)),
-            s=np.take([2, 4, 8], rs.randint(0, 3, n)),
-        ))
-        for col in df:
-            idx = rs.permutation(df.index)[:10]
-            df.loc[idx, col] = np.nan
-        return df
+        assert p.variables["x"] == wide_df.index.name
+        assert p.variables["y"] is None
+        assert p.variables["hue"] == wide_df.columns.name
+        assert p.variables["style"] == wide_df.columns.name

-    @pytest.fixture
-    def null_column(self):
+    def test_wide_df_with_nonnumeric_variables(self, long_df):

-        return pd.Series(index=np.arange(20))
+        p = _RelationalPlotter()
+        p.assign_variables(data=long_df)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]

-    def test_wide_df_variables(self, wide_df):
+        numeric_df = long_df.select_dtypes("number")

-        p = rel._RelationalPlotter()
-        p.establish_variables(data=wide_df)
-        assert p.input_format == "wide"
-        assert p.semantics == ["x", "y", "hue", "style"]
-        assert len(p.plot_data) == np.product(wide_df.shape)
+        assert len(p.plot_data) == np.product(numeric_df.shape)

         x = p.plot_data["x"]
-        expected_x = np.tile(wide_df.index, wide_df.shape[1])
-        assert np.array_equal(x, expected_x)
+        expected_x = np.tile(numeric_df.index, numeric_df.shape[1])
+        assert_array_equal(x, expected_x)

         y = p.plot_data["y"]
-        expected_y = wide_df.values.ravel(order="f")
-        assert np.array_equal(y, expected_y)
+        expected_y = numeric_df.to_numpy().ravel(order="f")
+        assert_array_equal(y, expected_y)

         hue = p.plot_data["hue"]
-        expected_hue = np.repeat(wide_df.columns.values, wide_df.shape[0])
-        assert np.array_equal(hue, expected_hue)
+        expected_hue = np.repeat(
+            numeric_df.columns.to_numpy(), numeric_df.shape[0]
+        )
+        assert_array_equal(hue, expected_hue)

         style = p.plot_data["style"]
         expected_style = expected_hue
-        assert np.array_equal(style, expected_style)
-
-        assert p.plot_data["size"].isnull().all()
-
-        assert p.x_label == wide_df.index.name
-        assert p.y_label is None
-        assert p.hue_label == wide_df.columns.name
-        assert p.size_label is None
-        assert p.style_label == wide_df.columns.name
+        assert_array_equal(style, expected_style)

-    def test_wide_df_variables_check(self, wide_df):
-
-        p = rel._RelationalPlotter()
-        wide_df = wide_df.copy()
-        wide_df.loc[:, "not_numeric"] = "a"
-        with pytest.raises(ValueError):
-            p.establish_variables(data=wide_df)
+        assert p.variables["x"] == numeric_df.index.name
+        assert p.variables["y"] is None
+        assert p.variables["hue"] == numeric_df.columns.name
+        assert p.variables["style"] == numeric_df.columns.name

     def test_wide_array_variables(self, wide_array):

-        p = rel._RelationalPlotter()
-        p.establish_variables(data=wide_array)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_array)
         assert p.input_format == "wide"
-        assert p.semantics == ["x", "y", "hue", "style"]
+        assert list(p.variables) == ["x", "y", "hue", "style"]
         assert len(p.plot_data) == np.product(wide_array.shape)
         nrow, ncol = wide_array.shape

         x = p.plot_data["x"]
         expected_x = np.tile(np.arange(nrow), ncol)
-        assert np.array_equal(x, expected_x)
+        assert_array_equal(x, expected_x)

         y = p.plot_data["y"]
         expected_y = wide_array.ravel(order="f")
-        assert np.array_equal(y, expected_y)
+        assert_array_equal(y, expected_y)

         hue = p.plot_data["hue"]
         expected_hue = np.repeat(np.arange(ncol), nrow)
-        assert np.array_equal(hue, expected_hue)
+        assert_array_equal(hue, expected_hue)

         style = p.plot_data["style"]
         expected_style = expected_hue
-        assert np.array_equal(style, expected_style)
-
-        assert p.plot_data["size"].isnull().all()
+        assert_array_equal(style, expected_style)

-        assert p.x_label is None
-        assert p.y_label is None
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

     def test_flat_array_variables(self, flat_array):

-        p = rel._RelationalPlotter()
-        p.establish_variables(data=flat_array)
+        p = _RelationalPlotter()
+        p.assign_variables(data=flat_array)
         assert p.input_format == "wide"
-        assert p.semantics == ["x", "y"]
+        assert list(p.variables) == ["x", "y"]
         assert len(p.plot_data) == np.product(flat_array.shape)

         x = p.plot_data["x"]
         expected_x = np.arange(flat_array.shape[0])
-        assert np.array_equal(x, expected_x)
+        assert_array_equal(x, expected_x)

         y = p.plot_data["y"]
         expected_y = flat_array
-        assert np.array_equal(y, expected_y)
+        assert_array_equal(y, expected_y)
+
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+
+    def test_flat_list_variables(self, flat_list):
+
+        p = _RelationalPlotter()
+        p.assign_variables(data=flat_list)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y"]
+        assert len(p.plot_data) == len(flat_list)
+
+        x = p.plot_data["x"]
+        expected_x = np.arange(len(flat_list))
+        assert_array_equal(x, expected_x)

-        assert p.plot_data["hue"].isnull().all()
-        assert p.plot_data["style"].isnull().all()
-        assert p.plot_data["size"].isnull().all()
+        y = p.plot_data["y"]
+        expected_y = flat_list
+        assert_array_equal(y, expected_y)

-        assert p.x_label is None
-        assert p.y_label is None
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None

     def test_flat_series_variables(self, flat_series):

-        p = rel._RelationalPlotter()
-        p.establish_variables(data=flat_series)
+        p = _RelationalPlotter()
+        p.assign_variables(data=flat_series)
         assert p.input_format == "wide"
-        assert p.semantics == ["x", "y"]
+        assert list(p.variables) == ["x", "y"]
         assert len(p.plot_data) == len(flat_series)

         x = p.plot_data["x"]
         expected_x = flat_series.index
-        assert np.array_equal(x, expected_x)
+        assert_array_equal(x, expected_x)

         y = p.plot_data["y"]
         expected_y = flat_series
-        assert np.array_equal(y, expected_y)
+        assert_array_equal(y, expected_y)

-        assert p.x_label is None
-        assert p.y_label is None
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
+        assert p.variables["x"] is flat_series.index.name
+        assert p.variables["y"] is flat_series.name

-    def test_wide_list_variables(self, wide_list):
+    def test_wide_list_of_series_variables(self, wide_list_of_series):

-        p = rel._RelationalPlotter()
-        p.establish_variables(data=wide_list)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_list_of_series)
         assert p.input_format == "wide"
-        assert p.semantics == ["x", "y", "hue", "style"]
-        assert len(p.plot_data) == sum(len(l) for l in wide_list)
+        assert list(p.variables) == ["x", "y", "hue", "style"]
+
+        chunks = len(wide_list_of_series)
+        chunk_size = max(len(l) for l in wide_list_of_series)
+
+        assert len(p.plot_data) == chunks * chunk_size
+
+        index_union = np.unique(
+            np.concatenate([s.index for s in wide_list_of_series])
+        )

         x = p.plot_data["x"]
-        expected_x = np.concatenate([np.arange(len(l)) for l in wide_list])
-        assert np.array_equal(x, expected_x)
+        expected_x = np.tile(index_union, chunks)
+        assert_array_equal(x, expected_x)

         y = p.plot_data["y"]
-        expected_y = np.concatenate(wide_list)
-        assert np.array_equal(y, expected_y)
+        expected_y = np.concatenate([
+            s.reindex(index_union) for s in wide_list_of_series
+        ])
+        assert_array_equal(y, expected_y)

         hue = p.plot_data["hue"]
-        expected_hue = np.concatenate([
-            np.ones_like(l) * i for i, l in enumerate(wide_list)
-        ])
-        assert np.array_equal(hue, expected_hue)
+        series_names = [s.name for s in wide_list_of_series]
+        expected_hue = np.repeat(series_names, chunk_size)
+        assert_array_equal(hue, expected_hue)

         style = p.plot_data["style"]
         expected_style = expected_hue
-        assert np.array_equal(style, expected_style)
+        assert_array_equal(style, expected_style)

-        assert p.plot_data["size"].isnull().all()
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-        assert p.x_label is None
-        assert p.y_label is None
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
+    def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):

-    def test_wide_list_of_series_variables(self, wide_list_of_series):
-
-        p = rel._RelationalPlotter()
-        p.establish_variables(data=wide_list_of_series)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_list_of_arrays)
         assert p.input_format == "wide"
-        assert p.semantics == ["x", "y", "hue", "style"]
-        assert len(p.plot_data) == sum(len(l) for l in wide_list_of_series)
+        assert list(p.variables) == ["x", "y", "hue", "style"]
+
+        chunks = len(wide_list_of_arrays)
+        chunk_size = max(len(l) for l in wide_list_of_arrays)
+
+        assert len(p.plot_data) == chunks * chunk_size

         x = p.plot_data["x"]
-        expected_x = np.concatenate([s.index for s in wide_list_of_series])
-        assert np.array_equal(x, expected_x)
+        expected_x = np.tile(np.arange(chunk_size), chunks)
+        assert_array_equal(x, expected_x)

-        y = p.plot_data["y"]
-        expected_y = np.concatenate(wide_list_of_series)
-        assert np.array_equal(y, expected_y)
+        y = p.plot_data["y"].dropna()
+        expected_y = np.concatenate(wide_list_of_arrays)
+        assert_array_equal(y, expected_y)

         hue = p.plot_data["hue"]
-        expected_hue = np.concatenate([
-            np.full(len(s), s.name, object) for s in wide_list_of_series
-        ])
-        assert np.array_equal(hue, expected_hue)
+        expected_hue = np.repeat(np.arange(chunks), chunk_size)
+        assert_array_equal(hue, expected_hue)

         style = p.plot_data["style"]
         expected_style = expected_hue
-        assert np.array_equal(style, expected_style)
-
-        assert p.plot_data["size"].isnull().all()
-
-        assert p.x_label is None
-        assert p.y_label is None
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
-
-    def test_long_df(self, long_df):
-
-        p = rel._RelationalPlotter()
-        p.establish_variables(x="x", y="y", data=long_df)
-        assert p.input_format == "long"
-        assert p.semantics == ["x", "y"]
-
-        assert np.array_equal(p.plot_data["x"], long_df["x"])
-        assert np.array_equal(p.plot_data["y"], long_df["y"])
-        for col in ["hue", "style", "size"]:
-            assert p.plot_data[col].isnull().all()
-        assert (p.x_label, p.y_label) == ("x", "y")
-        assert p.hue_label is None
-        assert p.size_label is None
-        assert p.style_label is None
-
-        p.establish_variables(x=long_df.x, y="y", data=long_df)
-        assert p.semantics == ["x", "y"]
-        assert np.array_equal(p.plot_data["x"], long_df["x"])
-        assert np.array_equal(p.plot_data["y"], long_df["y"])
-        assert (p.x_label, p.y_label) == ("x", "y")
-
-        p.establish_variables(x="x", y=long_df.y, data=long_df)
-        assert p.semantics == ["x", "y"]
-        assert np.array_equal(p.plot_data["x"], long_df["x"])
-        assert np.array_equal(p.plot_data["y"], long_df["y"])
-        assert (p.x_label, p.y_label) == ("x", "y")
-
-        p.establish_variables(x="x", y="y", hue="a", data=long_df)
-        assert p.semantics == ["x", "y", "hue"]
-        assert np.array_equal(p.plot_data["hue"], long_df["a"])
-        for col in ["style", "size"]:
-            assert p.plot_data[col].isnull().all()
-        assert p.hue_label == "a"
-        assert p.size_label is None and p.style_label is None
-
-        p.establish_variables(x="x", y="y", hue="a", style="a", data=long_df)
-        assert p.semantics == ["x", "y", "hue", "style"]
-        assert np.array_equal(p.plot_data["hue"], long_df["a"])
-        assert np.array_equal(p.plot_data["style"], long_df["a"])
-        assert p.plot_data["size"].isnull().all()
-        assert p.hue_label == p.style_label == "a"
-        assert p.size_label is None
-
-        p.establish_variables(x="x", y="y", hue="a", style="b", data=long_df)
-        assert p.semantics == ["x", "y", "hue", "style"]
-        assert np.array_equal(p.plot_data["hue"], long_df["a"])
-        assert np.array_equal(p.plot_data["style"], long_df["b"])
-        assert p.plot_data["size"].isnull().all()
-
-        p.establish_variables(x="x", y="y", size="y", data=long_df)
-        assert p.semantics == ["x", "y", "size"]
-        assert np.array_equal(p.plot_data["size"], long_df["y"])
-        assert p.size_label == "y"
-        assert p.hue_label is None and p.style_label is None
-
-    def test_bad_input(self, long_df):
-
-        p = rel._RelationalPlotter()
+        assert_array_equal(style, expected_style)

-        with pytest.raises(ValueError):
-            p.establish_variables(x=long_df.x)
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-        with pytest.raises(ValueError):
-            p.establish_variables(y=long_df.y)
+    def test_wide_list_of_list_variables(self, wide_list_of_lists):

-        with pytest.raises(ValueError):
-            p.establish_variables(x="not_in_df", data=long_df)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_list_of_lists)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]

-        with pytest.raises(ValueError):
-            p.establish_variables(x="x", y="not_in_df", data=long_df)
+        chunks = len(wide_list_of_lists)
+        chunk_size = max(len(l) for l in wide_list_of_lists)

-        with pytest.raises(ValueError):
-            p.establish_variables(x="x", y="not_in_df", data=long_df)
+        assert len(p.plot_data) == chunks * chunk_size

-    def test_empty_input(self):
+        x = p.plot_data["x"]
+        expected_x = np.tile(np.arange(chunk_size), chunks)
+        assert_array_equal(x, expected_x)

-        p = rel._RelationalPlotter()
+        y = p.plot_data["y"].dropna()
+        expected_y = np.concatenate(wide_list_of_lists)
+        assert_array_equal(y, expected_y)

-        p.establish_variables(data=[])
-        p.establish_variables(data=np.array([]))
-        p.establish_variables(data=pd.DataFrame())
-        p.establish_variables(x=[], y=[])
+        hue = p.plot_data["hue"]
+        expected_hue = np.repeat(np.arange(chunks), chunk_size)
+        assert_array_equal(hue, expected_hue)

-    def test_units(self, repeated_df):
+        style = p.plot_data["style"]
+        expected_style = expected_hue
+        assert_array_equal(style, expected_style)

-        p = rel._RelationalPlotter()
-        p.establish_variables(x="x", y="y", units="u", data=repeated_df)
-        assert np.array_equal(p.plot_data["units"], repeated_df["u"])
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-    def test_parse_hue_null(self, wide_df, null_column):
+    def test_wide_dict_of_series_variables(self, wide_dict_of_series):

-        p = rel._LinePlotter(data=wide_df)
-        p.parse_hue(null_column, "Blues", None, None)
-        assert p.hue_levels == [None]
-        assert p.palette == {}
-        assert p.hue_type is None
-        assert p.cmap is None
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_dict_of_series)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]

-    def test_parse_hue_categorical(self, wide_df, long_df):
+        chunks = len(wide_dict_of_series)
+        chunk_size = max(len(l) for l in wide_dict_of_series.values())

-        p = rel._LinePlotter(data=wide_df)
-        assert p.hue_levels == wide_df.columns.tolist()
-        assert p.hue_type is "categorical"
-        assert p.cmap is None
+        assert len(p.plot_data) == chunks * chunk_size

-        # Test named palette
-        palette = "Blues"
-        expected_colors = color_palette(palette, wide_df.shape[1])
-        expected_palette = dict(zip(wide_df.columns, expected_colors))
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.palette == expected_palette
+        x = p.plot_data["x"]
+        expected_x = np.tile(np.arange(chunk_size), chunks)
+        assert_array_equal(x, expected_x)

-        # Test list palette
-        palette = color_palette("Reds", wide_df.shape[1])
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        expected_palette = dict(zip(wide_df.columns, palette))
-        assert p.palette == expected_palette
+        y = p.plot_data["y"].dropna()
+        expected_y = np.concatenate(list(wide_dict_of_series.values()))
+        assert_array_equal(y, expected_y)

-        # Test dict palette
-        colors = color_palette("Set1", 8)
-        palette = dict(zip(wide_df.columns, colors))
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.palette == palette
+        hue = p.plot_data["hue"]
+        expected_hue = np.repeat(list(wide_dict_of_series), chunk_size)
+        assert_array_equal(hue, expected_hue)

-        # Test dict with missing keys
-        palette = dict(zip(wide_df.columns[:-1], colors))
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, palette, None, None)
+        style = p.plot_data["style"]
+        expected_style = expected_hue
+        assert_array_equal(style, expected_style)

-        # Test list with wrong number of colors
-        palette = colors[:-1]
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, palette, None, None)
-
-        # Test hue order
-        hue_order = ["a", "c", "d"]
-        p.parse_hue(p.plot_data.hue, None, hue_order, None)
-        assert p.hue_levels == hue_order
-
-        # Test long data
-        p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df)
-        assert p.hue_levels == categorical_order(long_df.a)
-        assert p.hue_type is "categorical"
-        assert p.cmap is None
-
-        # Test default palette
-        p.parse_hue(p.plot_data.hue, None, None, None)
-        hue_levels = categorical_order(long_df.a)
-        expected_colors = color_palette(n_colors=len(hue_levels))
-        expected_palette = dict(zip(hue_levels, expected_colors))
-        assert p.palette == expected_palette
-
-        # Test default palette with many levels
-        levels = pd.Series(list("abcdefghijklmnopqrstuvwxyz"))
-        p.parse_hue(levels, None, None, None)
-        expected_colors = color_palette("husl", n_colors=len(levels))
-        expected_palette = dict(zip(levels, expected_colors))
-        assert p.palette == expected_palette
-
-        # Test binary data
-        p = rel._LinePlotter(x="x", y="y", hue="c", data=long_df)
-        assert p.hue_levels == [0, 1]
-        assert p.hue_type is "categorical"
-
-        # Test numeric data with category type
-        p = rel._LinePlotter(x="x", y="y", hue="s_cat", data=long_df)
-        assert p.hue_levels == categorical_order(long_df.s_cat)
-        assert p.hue_type is "categorical"
-        assert p.cmap is None
-
-    def test_parse_hue_numeric(self, long_df):
-
-        p = rel._LinePlotter(x="x", y="y", hue="s", data=long_df)
-        hue_levels = list(np.sort(long_df.s.unique()))
-        assert p.hue_levels == hue_levels
-        assert p.hue_type is "numeric"
-        assert p.cmap.name == "seaborn_cubehelix"
-
-        # Test named colormap
-        palette = "Purples"
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.cmap is mpl.cm.get_cmap(palette)
-
-        # Test colormap object
-        palette = mpl.cm.get_cmap("Greens")
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.cmap is palette
-
-        # Test cubehelix shorthand
-        palette = "ch:2,0,light=.2"
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert isinstance(p.cmap, mpl.colors.ListedColormap)
-
-        # Test default hue limits
-        p.parse_hue(p.plot_data.hue, None, None, None)
-        assert p.hue_limits == (p.plot_data.hue.min(), p.plot_data.hue.max())
-
-        # Test specified hue limits
-        hue_norm = 1, 4
-        p.parse_hue(p.plot_data.hue, None, None, hue_norm)
-        assert p.hue_limits == hue_norm
-        assert isinstance(p.hue_norm, mpl.colors.Normalize)
-        assert p.hue_norm.vmin == hue_norm[0]
-        assert p.hue_norm.vmax == hue_norm[1]
-
-        # Test Normalize object
-        hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10)
-        p.parse_hue(p.plot_data.hue, None, None, hue_norm)
-        assert p.hue_limits == (hue_norm.vmin, hue_norm.vmax)
-        assert p.hue_norm is hue_norm
-
-        # Test default colormap values
-        hmin, hmax = p.plot_data.hue.min(), p.plot_data.hue.max()
-        p.parse_hue(p.plot_data.hue, None, None, None)
-        assert p.palette[hmin] == pytest.approx(p.cmap(0.0))
-        assert p.palette[hmax] == pytest.approx(p.cmap(1.0))
-
-        # Test specified colormap values
-        hue_norm = hmin - 1, hmax - 1
-        p.parse_hue(p.plot_data.hue, None, None, hue_norm)
-        norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0])
-        assert p.palette[hmin] == pytest.approx(p.cmap(norm_min))
-        assert p.palette[hmax] == pytest.approx(p.cmap(1.0))
-
-        # Test list of colors
-        hue_levels = list(np.sort(long_df.s.unique()))
-        palette = color_palette("Blues", len(hue_levels))
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.palette == dict(zip(hue_levels, palette))
-
-        palette = color_palette("Blues", len(hue_levels) + 1)
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, palette, None, None)
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-        # Test dictionary of colors
-        palette = dict(zip(hue_levels, color_palette("Reds")))
-        p.parse_hue(p.plot_data.hue, palette, None, None)
-        assert p.palette == palette
+    def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays):

-        palette.pop(hue_levels[0])
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, palette, None, None)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_dict_of_arrays)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]

-        # Test invalid palette
-        palette = "not_a_valid_palette"
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, palette, None, None)
+        chunks = len(wide_dict_of_arrays)
+        chunk_size = max(len(l) for l in wide_dict_of_arrays.values())

-        # Test bad norm argument
-        hue_norm = "not a norm"
-        with pytest.raises(ValueError):
-            p.parse_hue(p.plot_data.hue, None, None, hue_norm)
-
-    def test_parse_size(self, long_df):
-
-        p = rel._LinePlotter(x="x", y="y", size="s", data=long_df)
-
-        # Test default size limits and range
-        default_linewidth = mpl.rcParams["lines.linewidth"]
-        default_limits = p.plot_data["size"].min(), p.plot_data["size"].max()
-        default_range = .5 * default_linewidth, 2 * default_linewidth
-        p.parse_size(p.plot_data["size"], None, None, None)
-        assert p.size_limits == default_limits
-        size_range = min(p.sizes.values()), max(p.sizes.values())
-        assert size_range == default_range
-
-        # Test specified size limits
-        size_limits = (1, 5)
-        p.parse_size(p.plot_data["size"], None, None, size_limits)
-        assert p.size_limits == size_limits
-
-        # Test specified size range
-        sizes = (.1, .5)
-        p.parse_size(p.plot_data["size"], sizes, None, None)
-        assert p.size_limits == default_limits
-
-        # Test size values with normalization range
-        sizes = (1, 5)
-        size_norm = (1, 10)
-        p.parse_size(p.plot_data["size"], sizes, None, size_norm)
-        normalize = mpl.colors.Normalize(*size_norm, clip=True)
-        for level, width in p.sizes.items():
-            assert width == sizes[0] + (sizes[1] - sizes[0]) * normalize(level)
-
-        # Test size values with normalization object
-        sizes = (1, 5)
-        size_norm = mpl.colors.LogNorm(1, 10, clip=False)
-        p.parse_size(p.plot_data["size"], sizes, None, size_norm)
-        assert p.size_norm.clip
-        for level, width in p.sizes.items():
-            assert width == sizes[0] + (sizes[1] - sizes[0]) * size_norm(level)
-
-        # Test specified size order
-        var = "a"
-        levels = long_df[var].unique()
-        sizes = [1, 4, 6]
-        size_order = [levels[1], levels[2], levels[0]]
-        p = rel._LinePlotter(x="x", y="y", size=var, data=long_df)
-        p.parse_size(p.plot_data["size"], sizes, size_order, None)
-        assert p.sizes == dict(zip(size_order, sizes))
-
-        # Test list of sizes
-        var = "a"
-        levels = categorical_order(long_df[var])
-        sizes = list(np.random.rand(len(levels)))
-        p = rel._LinePlotter(x="x", y="y", size=var, data=long_df)
-        p.parse_size(p.plot_data["size"], sizes, None, None)
-        assert p.sizes == dict(zip(levels, sizes))
-
-        # Test dict of sizes
-        var = "a"
-        levels = categorical_order(long_df[var])
-        sizes = dict(zip(levels, np.random.rand(len(levels))))
-        p = rel._LinePlotter(x="x", y="y", size=var, data=long_df)
-        p.parse_size(p.plot_data["size"], sizes, None, None)
-        assert p.sizes == sizes
-
-        # Test sizes list with wrong length
-        sizes = list(np.random.rand(len(levels) + 1))
-        with pytest.raises(ValueError):
-            p.parse_size(p.plot_data["size"], sizes, None, None)
+        x = p.plot_data["x"]
+        expected_x = np.tile(np.arange(chunk_size), chunks)
+        assert_array_equal(x, expected_x)

-        # Test sizes dict with missing levels
-        sizes = dict(zip(levels, np.random.rand(len(levels) - 1)))
-        with pytest.raises(ValueError):
-            p.parse_size(p.plot_data["size"], sizes, None, None)
+        y = p.plot_data["y"].dropna()
+        expected_y = np.concatenate(list(wide_dict_of_arrays.values()))
+        assert_array_equal(y, expected_y)

-        # Test bad sizes argument
-        sizes = "bad_size"
-        with pytest.raises(ValueError):
-            p.parse_size(p.plot_data["size"], sizes, None, None)
+        hue = p.plot_data["hue"]
+        expected_hue = np.repeat(list(wide_dict_of_arrays), chunk_size)
+        assert_array_equal(hue, expected_hue)

-        # Test bad norm argument
-        size_norm = "not a norm"
-        p = rel._LinePlotter(x="x", y="y", size="s", data=long_df)
-        with pytest.raises(ValueError):
-            p.parse_size(p.plot_data["size"], None, None, size_norm)
-
-    def test_parse_style(self, long_df):
-
-        p = rel._LinePlotter(x="x", y="y", style="a", data=long_df)
-
-        # Test defaults
-        markers, dashes = True, True
-        p.parse_style(p.plot_data["style"], markers, dashes, None)
-        assert p.markers == dict(zip(p.style_levels, p.default_markers))
-        assert p.dashes == dict(zip(p.style_levels, p.default_dashes))
-
-        # Test lists
-        markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)]
-        p.parse_style(p.plot_data["style"], markers, dashes, None)
-        assert p.markers == dict(zip(p.style_levels, markers))
-        assert p.dashes == dict(zip(p.style_levels, dashes))
-
-        # Test dicts
-        markers = dict(zip(p.style_levels, markers))
-        dashes = dict(zip(p.style_levels, dashes))
-        p.parse_style(p.plot_data["style"], markers, dashes, None)
-        assert p.markers == markers
-        assert p.dashes == dashes
-
-        # Test style order with defaults
-        style_order = np.take(p.style_levels, [1, 2, 0])
-        markers = dashes = True
-        p.parse_style(p.plot_data["style"], markers, dashes, style_order)
-        assert p.markers == dict(zip(style_order, p.default_markers))
-        assert p.dashes == dict(zip(style_order, p.default_dashes))
-
-        # Test too many levels with style lists
-        markers, dashes = ["o", "s"], False
-        with pytest.raises(ValueError):
-            p.parse_style(p.plot_data["style"], markers, dashes, None)
+        style = p.plot_data["style"]
+        expected_style = expected_hue
+        assert_array_equal(style, expected_style)

-        markers, dashes = False, [(2, 1)]
-        with pytest.raises(ValueError):
-            p.parse_style(p.plot_data["style"], markers, dashes, None)
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-        # Test too many levels with style dicts
-        markers, dashes = {"a": "o", "b": "s"}, False
-        with pytest.raises(ValueError):
-            p.parse_style(p.plot_data["style"], markers, dashes, None)
+    def test_wide_dict_of_lists_variables(self, wide_dict_of_lists):

-        markers, dashes = False, {"a": (1, 0), "b": (2, 1)}
-        with pytest.raises(ValueError):
-            p.parse_style(p.plot_data["style"], markers, dashes, None)
+        p = _RelationalPlotter()
+        p.assign_variables(data=wide_dict_of_lists)
+        assert p.input_format == "wide"
+        assert list(p.variables) == ["x", "y", "hue", "style"]

-        # Test mixture of filled and unfilled markers
-        markers, dashes = ["o", "x", "s"], None
-        with pytest.raises(ValueError):
-            p.parse_style(p.plot_data["style"], markers, dashes, None)
+        chunks = len(wide_dict_of_lists)
+        chunk_size = max(len(l) for l in wide_dict_of_lists.values())

-    def test_subset_data_quantities(self, long_df):
+        assert len(p.plot_data) == chunks * chunk_size

-        p = rel._LinePlotter(x="x", y="y", data=long_df)
-        assert len(list(p.subset_data())) == 1
+        x = p.plot_data["x"]
+        expected_x = np.tile(np.arange(chunk_size), chunks)
+        assert_array_equal(x, expected_x)

-        # --
+        y = p.plot_data["y"].dropna()
+        expected_y = np.concatenate(list(wide_dict_of_lists.values()))
+        assert_array_equal(y, expected_y)

-        var = "a"
-        n_subsets = len(long_df[var].unique())
+        hue = p.plot_data["hue"]
+        expected_hue = np.repeat(list(wide_dict_of_lists), chunk_size)
+        assert_array_equal(hue, expected_hue)
+
+        style = p.plot_data["style"]
+        expected_style = expected_hue
+        assert_array_equal(style, expected_style)

-        p = rel._LinePlotter(x="x", y="y", hue=var, data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+        assert p.variables["x"] is None
+        assert p.variables["y"] is None
+        assert p.variables["hue"] is None
+        assert p.variables["style"] is None

-        p = rel._LinePlotter(x="x", y="y", style=var, data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+    def test_relplot_simple(self, long_df):

-        n_subsets = len(long_df[var].unique())
+        g = relplot(data=long_df, x="x", y="y", kind="scatter")
+        x, y = g.ax.collections[0].get_offsets().T
+        assert_array_equal(x, long_df["x"])
+        assert_array_equal(y, long_df["y"])

-        p = rel._LinePlotter(x="x", y="y", size=var, data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+        g = relplot(data=long_df, x="x", y="y", kind="line")
+        x, y = g.ax.lines[0].get_xydata().T
+        expected = long_df.groupby("x").y.mean()
+        assert_array_equal(x, expected.index)
+        assert y == pytest.approx(expected.values)

-        # --
+        with pytest.raises(ValueError):
+            g = relplot(data=long_df, x="x", y="y", kind="not_a_kind")

-        var = "a"
-        n_subsets = len(long_df[var].unique())
+    def test_relplot_complex(self, long_df):

-        p = rel._LinePlotter(x="x", y="y", hue=var, style=var, data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+        for sem in ["hue", "size", "style"]:
+            g = relplot(data=long_df, x="x", y="y", **{sem: "a"})
+            x, y = g.ax.collections[0].get_offsets().T
+            assert_array_equal(x, long_df["x"])
+            assert_array_equal(y, long_df["y"])

-        # --
+        for sem in ["hue", "size", "style"]:
+            g = relplot(
+                data=long_df, x="x", y="y", col="c", **{sem: "a"}
+            )
+            grouped = long_df.groupby("c")
+            for (_, grp_df), ax in zip(grouped, g.axes.flat):
+                x, y = ax.collections[0].get_offsets().T
+                assert_array_equal(x, grp_df["x"])
+                assert_array_equal(y, grp_df["y"])

-        var1, var2 = "a", "s"
-        n_subsets = len(set(list(map(tuple, long_df[[var1, var2]].values))))
+        for sem in ["size", "style"]:
+            g = relplot(
+                data=long_df, x="x", y="y", hue="b", col="c", **{sem: "a"}
+            )
+            grouped = long_df.groupby("c")
+            for (_, grp_df), ax in zip(grouped, g.axes.flat):
+                x, y = ax.collections[0].get_offsets().T
+                assert_array_equal(x, grp_df["x"])
+                assert_array_equal(y, grp_df["y"])

-        p = rel._LinePlotter(x="x", y="y", hue=var1, style=var2,
-                             data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+        for sem in ["hue", "size", "style"]:
+            g = relplot(
+                data=long_df.sort_values(["c", "b"]),
+                x="x", y="y", col="b", row="c", **{sem: "a"}
+            )
+            grouped = long_df.groupby(["c", "b"])
+            for (_, grp_df), ax in zip(grouped, g.axes.flat):
+                x, y = ax.collections[0].get_offsets().T
+                assert_array_equal(x, grp_df["x"])
+                assert_array_equal(y, grp_df["y"])

-        p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, style=var1,
-                             data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+    @pytest.mark.parametrize(
+        "vector_type",
+        ["series", "numpy", "list"],
+    )
+    def test_relplot_vectors(self, long_df, vector_type):
+
+        semantics = dict(x="x", y="y", hue="f", col="c")
+        kws = {key: long_df[val] for key, val in semantics.items()}
+        g = relplot(data=long_df, **kws)
+        grouped = long_df.groupby("c")
+        for (_, grp_df), ax in zip(grouped, g.axes.flat):
+            x, y = ax.collections[0].get_offsets().T
+            assert_array_equal(x, grp_df["x"])
+            assert_array_equal(y, grp_df["y"])

-        # --
+    def test_relplot_wide(self, wide_df):

-        var1, var2, var3 = "a", "s", "b"
-        cols = [var1, var2, var3]
-        n_subsets = len(set(list(map(tuple, long_df[cols].values))))
+        g = relplot(data=wide_df)
+        x, y = g.ax.collections[0].get_offsets().T
+        assert_array_equal(y, wide_df.to_numpy().T.ravel())

-        p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, style=var3,
-                             data=long_df)
-        assert len(list(p.subset_data())) == n_subsets
+    def test_relplot_hues(self, long_df):

-    def test_subset_data_keys(self, long_df):
+        palette = ["r", "b", "g"]
+        g = relplot(
+            x="x", y="y", hue="a", style="b", col="c",
+            palette=palette, data=long_df
+        )

-        p = rel._LinePlotter(x="x", y="y", data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue is None
-            assert size is None
-            assert style is None
+        palette = dict(zip(long_df["a"].unique(), palette))
+        grouped = long_df.groupby("c")
+        for (_, grp_df), ax in zip(grouped, g.axes.flat):
+            points = ax.collections[0]
+            expected_hues = [palette[val] for val in grp_df["a"]]
+            assert same_color(points.get_facecolors(), expected_hues)

-        # --
+    def test_relplot_sizes(self, long_df):

-        var = "a"
+        sizes = [5, 12, 7]
+        g = relplot(
+            data=long_df,
+            x="x", y="y", size="a", hue="b", col="c",
+            sizes=sizes,
+        )

-        p = rel._LinePlotter(x="x", y="y", hue=var, data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue in long_df[var].values
-            assert size is None
-            assert style is None
+        sizes = dict(zip(long_df["a"].unique(), sizes))
+        grouped = long_df.groupby("c")
+        for (_, grp_df), ax in zip(grouped, g.axes.flat):
+            points = ax.collections[0]
+            expected_sizes = [sizes[val] for val in grp_df["a"]]
+            assert_array_equal(points.get_sizes(), expected_sizes)

-        p = rel._LinePlotter(x="x", y="y", style=var, data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue is None
-            assert size is None
-            assert style in long_df[var].values
+    def test_relplot_styles(self, long_df):

-        p = rel._LinePlotter(x="x", y="y", hue=var, style=var, data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue in long_df[var].values
-            assert size is None
-            assert style in long_df[var].values
+        markers = ["o", "d", "s"]
+        g = relplot(
+            data=long_df,
+            x="x", y="y", style="a", hue="b", col="c",
+            markers=markers,
+        )

-        p = rel._LinePlotter(x="x", y="y", size=var, data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue is None
-            assert size in long_df[var].values
-            assert style is None
+        paths = []
+        for m in markers:
+            m = mpl.markers.MarkerStyle(m)
+            paths.append(m.get_path().transformed(m.get_transform()))
+        paths = dict(zip(long_df["a"].unique(), paths))

-        # --
+        grouped = long_df.groupby("c")
+        for (_, grp_df), ax in zip(grouped, g.axes.flat):
+            points = ax.collections[0]
+            expected_paths = [paths[val] for val in grp_df["a"]]
+            assert self.paths_equal(points.get_paths(), expected_paths)

-        var1, var2 = "a", "s"
+    def test_relplot_stringy_numerics(self, long_df):

-        p = rel._LinePlotter(x="x", y="y", hue=var1, size=var2, data=long_df)
-        for (hue, size, style), _ in p.subset_data():
-            assert hue in long_df[var1].values
-            assert size in long_df[var2].values
-
assert style is None + long_df["x_str"] = long_df["x"].astype(str) - def test_subset_data_values(self, long_df): + g = relplot(data=long_df, x="x", y="y", hue="x_str") + points = g.ax.collections[0] + xys = points.get_offsets() + mask = np.ma.getmask(xys) + assert not mask.any() + assert_array_equal(xys, long_df[["x", "y"]]) - p = rel._LinePlotter(x="x", y="y", data=long_df) - _, data = next(p.subset_data()) - expected = sort_df(p.plot_data.loc[:, ["x", "y"]], ["x", "y"]) - assert np.array_equal(data.values, expected) + g = relplot(data=long_df, x="x", y="y", size="x_str") + points = g.ax.collections[0] + xys = points.get_offsets() + mask = np.ma.getmask(xys) + assert not mask.any() + assert_array_equal(xys, long_df[["x", "y"]]) - p = rel._LinePlotter(x="x", y="y", data=long_df, sort=False) - _, data = next(p.subset_data()) - expected = p.plot_data.loc[:, ["x", "y"]] - assert np.array_equal(data.values, expected) + def test_relplot_legend(self, long_df): - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df) - for (hue, _, _), data in p.subset_data(): - rows = p.plot_data["hue"] == hue - cols = ["x", "y"] - expected = sort_df(p.plot_data.loc[rows, cols], cols) - assert np.array_equal(data.values, expected.values) + g = relplot(data=long_df, x="x", y="y") + assert g._legend is None - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, sort=False) - for (hue, _, _), data in p.subset_data(): - rows = p.plot_data["hue"] == hue - cols = ["x", "y"] - expected = p.plot_data.loc[rows, cols] - assert np.array_equal(data.values, expected.values) + g = relplot(data=long_df, x="x", y="y", hue="a") + texts = [t.get_text() for t in g._legend.texts] + expected_texts = long_df["a"].unique() + assert_array_equal(texts, expected_texts) - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", data=long_df) - for (hue, _, _), data in p.subset_data(): - rows = p.plot_data["hue"] == hue - cols = ["x", "y"] - expected = sort_df(p.plot_data.loc[rows, cols], cols) - assert np.array_equal(data.values, expected.values) + g = relplot(data=long_df, x="x", y="y", hue="s", size="s") + texts = [t.get_text() for t in g._legend.texts] + assert_array_equal(texts, np.sort(texts)) - p = rel._LinePlotter(x="x", y="y", hue="a", size="s", data=long_df) - for (hue, size, _), data in p.subset_data(): - rows = (p.plot_data["hue"] == hue) & (p.plot_data["size"] == size) - cols = ["x", "y"] - expected = sort_df(p.plot_data.loc[rows, cols], cols) - assert np.array_equal(data.values, expected.values) + g = relplot(data=long_df, x="x", y="y", hue="a", legend=False) + assert g._legend is None + palette = color_palette("deep", len(long_df["b"].unique())) + a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique())) + long_df["a_like_b"] = long_df["a"].map(a_like_b) + g = relplot( + data=long_df, + x="x", y="y", hue="b", style="a_like_b", + palette=palette, kind="line", estimator=None, + ) + lines = g._legend.get_lines()[1:] # Chop off title dummy + for line, color in zip(lines, palette): + assert line.get_color() == color -class TestLinePlotter(TestRelationalPlotter): + def test_ax_kwarg_removal(self, long_df): - def test_aggregate(self, long_df): + f, ax = plt.subplots() + with pytest.warns(UserWarning): + g = relplot(data=long_df, x="x", y="y", ax=ax) + assert len(ax.collections) == 0 + assert len(g.ax.collections) > 0 - p = rel._LinePlotter(x="x", y="y", data=long_df) - p.n_boot = 10000 - p.sort = False - x = pd.Series(np.tile([1, 2], 100)) - y = pd.Series(np.random.randn(200)) - y_mean = y.groupby(x).mean() +class 
TestLinePlotter(SharedAxesLevelTests, Helpers): - def sem(x): - return np.std(x) / np.sqrt(len(x)) + func = staticmethod(lineplot) - y_sem = y.groupby(x).apply(sem) - y_cis = pd.DataFrame(dict(low=y_mean - y_sem, - high=y_mean + y_sem), - columns=["low", "high"]) + def get_last_color(self, ax): - p.ci = 68 - p.estimator = "mean" - index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) - assert est.index.equals(index) - assert est.values == pytest.approx(y_mean.values) - assert cis.values == pytest.approx(y_cis.values, 4) - assert list(cis.columns) == ["low", "high"] - - p.estimator = np.mean - index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) - assert est.index.equals(index) - assert est.values == pytest.approx(y_mean.values) - assert cis.values == pytest.approx(y_cis.values, 4) - assert list(cis.columns) == ["low", "high"] - - y_std = y.groupby(x).std() - y_cis = pd.DataFrame(dict(low=y_mean - y_std, - high=y_mean + y_std), - columns=["low", "high"]) - - p.ci = "sd" - index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x.unique()) - assert est.index.equals(index) - assert est.values == pytest.approx(y_mean.values) - assert cis.values == pytest.approx(y_cis.values) - assert list(cis.columns) == ["low", "high"] - - p.ci = None - index, est, cis = p.aggregate(y, x) - assert cis is None - - p.ci = 68 - x, y = pd.Series([1, 2, 3]), pd.Series([4, 3, 2]) - index, est, cis = p.aggregate(y, x) - assert np.array_equal(index.values, x) - assert np.array_equal(est.values, y) - assert cis is None - - x, y = pd.Series([1, 1, 2]), pd.Series([2, 3, 4]) - index, est, cis = p.aggregate(y, x) - assert cis.loc[2].isnull().all() + return to_rgba(ax.lines[-1].get_color()) def test_legend_data(self, long_df): f, ax = plt.subplots() - p = rel._LinePlotter(x="x", y="y", data=long_df, legend="full") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + legend="full" + ) p.add_legend_data(ax) - handles, _ = ax.get_legend_handles_labels() + handles, labels = ax.get_legend_handles_labels() assert handles == [] # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - legend="full") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + legend="full", + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] - assert labels == ["a"] + p.hue_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] + assert labels == p._hue_map.levels + assert colors == p._hue_map(p._hue_map.levels) # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), + legend="full", + ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] markers = [h.get_marker() for h in handles] - assert labels == ["a"] + p.hue_levels == ["a"] + p.style_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] - assert markers == [""] + [p.markers[l] for l in p.style_levels] + assert labels == p._hue_map.levels + assert labels == p._style_map.levels + assert colors == p._hue_map(p._hue_map.levels) + assert markers == p._style_map(p._style_map.levels, "marker") # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df) + p = _LinePlotter( + 
data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + legend="full", + ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] markers = [h.get_marker() for h in handles] - expected_colors = (["w"] + [p.palette[l] for l in p.hue_levels] - + ["w"] + [".2" for _ in p.style_levels]) - expected_markers = ([""] + ["None" for _ in p.hue_levels] - + [""] + [p.markers[l] for l in p.style_levels]) - assert labels == ["a"] + p.hue_levels + ["b"] + p.style_levels + expected_labels = ( + ["a"] + + p._hue_map.levels + + ["b"] + p._style_map.levels + ) + expected_colors = ( + ["w"] + p._hue_map(p._hue_map.levels) + + ["w"] + [".2" for _ in p._style_map.levels] + ) + expected_markers = ( + [""] + ["None" for _ in p._hue_map.levels] + + [""] + p._style_map(p._style_map.levels, "marker") + ) + assert labels == expected_labels assert colors == expected_colors assert markers == expected_markers # -- ax.clear() - p = rel._LinePlotter(x="x", y="y", hue="a", size="a", data=long_df, - legend="full") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", size="a"), + legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] widths = [h.get_linewidth() for h in handles] - assert labels == ["a"] + p.hue_levels == ["a"] + p.size_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] - assert widths == [0] + [p.sizes[l] for l in p.size_levels] + assert labels == p._hue_map.levels + assert labels == p._size_map.levels + assert colors == p._hue_map(p._hue_map.levels) + assert widths == p._size_map(p._size_map.levels) # -- x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = rel._LinePlotter(x=x, y=y, hue=z) + p = _LinePlotter(variables=dict(x=x, y=y, hue=z)) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.hue_levels] + assert labels == [str(l) for l in p._hue_map.levels] ax.clear() p.legend = "brief" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert len(labels) == 4 + assert len(labels) < len(p._hue_map.levels) - p = rel._LinePlotter(x=x, y=y, size=z) + p = _LinePlotter(variables=dict(x=x, y=y, size=z)) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.size_levels] + assert labels == [str(l) for l in p._size_map.levels] ax.clear() p.legend = "brief" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert len(labels) == 4 + assert len(labels) < len(p._size_map.levels) + + ax.clear() + p.legend = "auto" + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() + assert len(labels) < len(p._size_map.levels) + + ax.clear() + p.legend = True + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() + assert len(labels) < len(p._size_map.levels) ax.clear() p.legend = "bad_value" @@ -1009,31 +777,71 @@ def test_legend_data(self, long_df): p.add_legend_data(ax) ax.clear() - p = rel._LinePlotter(x=x, y=y, hue=z, - hue_norm=mpl.colors.LogNorm(), - legend="brief") + p = _LinePlotter( + variables=dict(x=x, y=y, hue=z + 1), + legend="brief" + ) + p.map_hue(norm=mpl.colors.LogNorm()), + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() + assert float(labels[1]) / float(labels[0]) == 10 + + ax.clear() + p = _LinePlotter( + 
variables=dict(x=x, y=y, hue=z % 2), + legend="auto" + ) + p.map_hue(norm=mpl.colors.LogNorm()), + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() + assert labels == ["0", "1"] + + ax.clear() + p = _LinePlotter( + variables=dict(x=x, y=y, size=z + 1), + legend="brief" + ) + p.map_size(norm=mpl.colors.LogNorm()) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert float(labels[2]) / float(labels[1]) == 10 + assert float(labels[1]) / float(labels[0]) == 10 ax.clear() - p = rel._LinePlotter(x=x, y=y, size=z, - size_norm=mpl.colors.LogNorm(), - legend="brief") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="f"), + legend="brief", + ) p.add_legend_data(ax) + expected_labels = ['0.20', '0.22', '0.24', '0.26', '0.28'] handles, labels = ax.get_legend_handles_labels() - assert float(labels[2]) / float(labels[1]) == 10 + assert labels == expected_labels + + ax.clear() + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", size="f"), + legend="brief", + ) + p.add_legend_data(ax) + expected_levels = ['0.20', '0.22', '0.24', '0.26', '0.28'] + handles, labels = ax.get_legend_handles_labels() + assert labels == expected_levels def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() - p = rel._LinePlotter(x="x", y="y", data=long_df, - sort=False, estimator=None) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + sort=False, + estimator=None + ) p.plot(ax, {}) line, = ax.lines - assert np.array_equal(line.get_xdata(), long_df.x.values) - assert np.array_equal(line.get_ydata(), long_df.y.values) + assert_array_equal(line.get_xdata(), long_df.x.to_numpy()) + assert_array_equal(line.get_ydata(), long_df.y.to_numpy()) ax.clear() p.plot(ax, {"color": "k", "label": "test"}) @@ -1041,95 +849,143 @@ def test_plot(self, long_df, repeated_df): assert line.get_color() == "k" assert line.get_label() == "test" - p = rel._LinePlotter(x="x", y="y", data=long_df, - sort=True, estimator=None) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + sort=True, estimator=None + ) ax.clear() p.plot(ax, {}) line, = ax.lines - sorted_data = sort_df(long_df, ["x", "y"]) - assert np.array_equal(line.get_xdata(), sorted_data.x.values) - assert np.array_equal(line.get_ydata(), sorted_data.y.values) + sorted_data = long_df.sort_values(["x", "y"]) + assert_array_equal(line.get_xdata(), sorted_data.x.to_numpy()) + assert_array_equal(line.get_ydata(), sorted_data.y.to_numpy()) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.hue_levels) - for line, level in zip(ax.lines, p.hue_levels): - assert line.get_color() == p.palette[level] + assert len(ax.lines) == len(p._hue_map.levels) + for line, level in zip(ax.lines, p._hue_map.levels): + assert line.get_color() == p._hue_map(level) - p = rel._LinePlotter(x="x", y="y", size="a", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", size="a"), + ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.size_levels) - for line, level in zip(ax.lines, p.size_levels): - assert line.get_linewidth() == p.sizes[level] + assert len(ax.lines) == len(p._size_map.levels) + for line, level in zip(ax.lines, p._size_map.levels): + assert line.get_linewidth() == p._size_map(level) - p = rel._LinePlotter(x="x", y="y", hue="a", style="a", - markers=True, data=long_df) + p = _LinePlotter( + data=long_df, 
+ variables=dict(x="x", y="y", hue="a", style="a"), + ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.hue_levels) == len(p.style_levels) - for line, level in zip(ax.lines, p.hue_levels): - assert line.get_color() == p.palette[level] - assert line.get_marker() == p.markers[level] + assert len(ax.lines) == len(p._hue_map.levels) + assert len(ax.lines) == len(p._style_map.levels) + for line, level in zip(ax.lines, p._hue_map.levels): + assert line.get_color() == p._hue_map(level) + assert line.get_marker() == p._style_map(level, "marker") - p = rel._LinePlotter(x="x", y="y", hue="a", style="b", - markers=True, data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - levels = product(p.hue_levels, p.style_levels) - assert len(ax.lines) == (len(p.hue_levels) * len(p.style_levels)) + levels = product(p._hue_map.levels, p._style_map.levels) + expected_line_count = len(p._hue_map.levels) * len(p._style_map.levels) + assert len(ax.lines) == expected_line_count for line, (hue, style) in zip(ax.lines, levels): - assert line.get_color() == p.palette[hue] - assert line.get_marker() == p.markers[style] + assert line.get_color() == p._hue_map(hue) + assert line.get_marker() == p._style_map(style, "marker") - p = rel._LinePlotter(x="x", y="y", data=long_df, - estimator="mean", err_style="band", ci="sd", - sort=True) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + estimator="mean", err_style="band", errorbar="sd", sort=True + ) ax.clear() p.plot(ax, {}) line, = ax.lines expected_data = long_df.groupby("x").y.mean() - assert np.array_equal(line.get_xdata(), expected_data.index.values) - assert np.allclose(line.get_ydata(), expected_data.values) + assert_array_equal(line.get_xdata(), expected_data.index.to_numpy()) + assert np.allclose(line.get_ydata(), expected_data.to_numpy()) assert len(ax.collections) == 1 - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - estimator="mean", err_style="band", ci="sd") + # Test that nans do not propagate to means or CIs + p = _LinePlotter( + variables=dict( + x=[1, 1, 1, 2, 2, 2, 3, 3, 3], + y=[1, 2, 3, 3, np.nan, 5, 4, 5, 6], + ), + estimator="mean", err_style="band", errorbar="ci", n_boot=100, sort=True, + ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(ax.collections) == len(p.hue_levels) + line, = ax.lines + assert line.get_xdata().tolist() == [1, 2, 3] + err_band = ax.collections[0].get_paths() + assert len(err_band) == 1 + assert len(err_band[0].vertices) == 9 + + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + estimator="mean", err_style="band", errorbar="sd" + ) + + ax.clear() + p.plot(ax, {}) + assert len(ax.lines) == len(ax.collections) == len(p._hue_map.levels) for c in ax.collections: assert isinstance(c, mpl.collections.PolyCollection) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - estimator="mean", err_style="bars", ci="sd") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + estimator="mean", err_style="bars", errorbar="sd" + ) ax.clear() p.plot(ax, {}) - # assert len(ax.lines) / 2 == len(ax.collections) == len(p.hue_levels) - # The lines are different on mpl 1.4 but I can't install to debug - assert len(ax.collections) == len(p.hue_levels) + n_lines = len(ax.lines) + assert n_lines / 2 == len(ax.collections) == len(p._hue_map.levels) + assert len(ax.collections) == len(p._hue_map.levels) for c in 
ax.collections: assert isinstance(c, mpl.collections.LineCollection) - p = rel._LinePlotter(x="x", y="y", data=repeated_df, - units="u", estimator=None) + p = _LinePlotter( + data=repeated_df, + variables=dict(x="x", y="y", units="u"), + estimator=None + ) ax.clear() p.plot(ax, {}) n_units = len(repeated_df["u"].unique()) assert len(ax.lines) == n_units - p = rel._LinePlotter(x="x", y="y", hue="a", data=repeated_df, - units="u", estimator=None) + p = _LinePlotter( + data=repeated_df, + variables=dict(x="x", y="y", hue="a", units="u"), + estimator=None + ) ax.clear() p.plot(ax, {}) @@ -1140,16 +996,22 @@ def test_plot(self, long_df, repeated_df): with pytest.raises(ValueError): p.plot(ax, {}) - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - err_style="band", err_kws={"alpha": .5}) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + err_style="band", err_kws={"alpha": .5}, + ) ax.clear() p.plot(ax, {}) for band in ax.collections: assert band.get_alpha() == .5 - p = rel._LinePlotter(x="x", y="y", hue="a", data=long_df, - err_style="bars", err_kws={"elinewidth": 2}) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + err_style="bars", err_kws={"elinewidth": 2}, + ) ax.clear() p.plot(ax, {}) @@ -1160,11 +1022,57 @@ def test_plot(self, long_df, repeated_df): with pytest.raises(ValueError): p.plot(ax, {}) + x_str = long_df["x"].astype(str) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue=x_str), + ) + ax.clear() + p.plot(ax, {}) + + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", size=x_str), + ) + ax.clear() + p.plot(ax, {}) + + def test_log_scale(self): + + f, ax = plt.subplots() + ax.set_xscale("log") + + x = [1, 10, 100] + y = [1, 2, 3] + + lineplot(x=x, y=y) + line = ax.lines[0] + assert_array_equal(line.get_xdata(), x) + assert_array_equal(line.get_ydata(), y) + + f, ax = plt.subplots() + ax.set_xscale("log") + ax.set_yscale("log") + + x = [1, 1, 2, 2] + y = [1, 10, 1, 100] + + lineplot(x=x, y=y, err_style="bars", errorbar=("pi", 100)) + line = ax.lines[0] + assert line.get_ydata()[1] == 10 + + ebars = ax.collections[0].get_segments() + assert_array_equal(ebars[0][:, 1], y[:2]) + assert_array_equal(ebars[1][:, 1], y[2:]) + def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = rel._LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -1175,85 +1083,186 @@ def test_axis_labels(self, long_df): assert ax2.get_ylabel() == "y" assert not ax2.yaxis.label.get_visible() + def test_matplotlib_kwargs(self, long_df): + + kws = { + "linestyle": "--", + "linewidth": 3, + "color": (1, .5, .2), + "markeredgecolor": (.2, .5, .2), + "markeredgewidth": 1, + } + ax = lineplot(data=long_df, x="x", y="y", **kws) + + line, *_ = ax.lines + for key, val in kws.items(): + plot_val = getattr(line, f"get_{key}")() + assert plot_val == val + + def test_nonmapped_dashes(self): + + ax = lineplot(x=[1, 2], y=[1, 2], dashes=(2, 1)) + line = ax.lines[0] + # Not a great test, but lines don't expose the dash style publicly + assert line.get_linestyle() == "--" + def test_lineplot_axes(self, wide_df): f1, ax1 = plt.subplots() f2, ax2 = plt.subplots() - ax = rel.lineplot(data=wide_df) + ax = lineplot(data=wide_df) assert ax is ax2 - ax = rel.lineplot(data=wide_df, ax=ax1) + ax = lineplot(data=wide_df, ax=ax1) assert ax is ax1 - def test_lineplot_smoke(self, flat_array,
flat_series, - wide_array, wide_list, wide_list_of_series, - wide_df, long_df, missing_df): + def test_lineplot_vs_relplot(self, long_df, long_semantics): + + ax = lineplot(data=long_df, **long_semantics) + g = relplot(data=long_df, kind="line", **long_semantics) + + lin_lines = ax.lines + rel_lines = g.ax.lines + + for l1, l2 in zip(lin_lines, rel_lines): + assert_array_equal(l1.get_xydata(), l2.get_xydata()) + assert same_color(l1.get_color(), l2.get_color()) + assert l1.get_linewidth() == l2.get_linewidth() + assert l1.get_linestyle() == l2.get_linestyle() + + def test_lineplot_smoke( + self, + wide_df, wide_array, + wide_list_of_series, wide_list_of_arrays, wide_list_of_lists, + flat_array, flat_series, flat_list, + long_df, missing_df, object_df + ): f, ax = plt.subplots() - rel.lineplot([], []) + lineplot(x=[], y=[]) + ax.clear() + + lineplot(data=wide_df) + ax.clear() + + lineplot(data=wide_array) + ax.clear() + + lineplot(data=wide_list_of_series) + ax.clear() + + lineplot(data=wide_list_of_arrays) + ax.clear() + + lineplot(data=wide_list_of_lists) ax.clear() - rel.lineplot(data=flat_array) + lineplot(data=flat_series) ax.clear() - rel.lineplot(data=flat_series) + lineplot(data=flat_array) ax.clear() - rel.lineplot(data=wide_array) + lineplot(data=flat_list) ax.clear() - rel.lineplot(data=wide_list) + lineplot(x="x", y="y", data=long_df) ax.clear() - rel.lineplot(data=wide_list_of_series) + lineplot(x=long_df.x, y=long_df.y) ax.clear() - rel.lineplot(data=wide_df) + lineplot(x=long_df.x, y="y", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", data=long_df) + lineplot(x="x", y=long_df.y.to_numpy(), data=long_df) ax.clear() - rel.lineplot(x=long_df.x, y=long_df.y) + lineplot(x="x", y="t", data=long_df) ax.clear() - rel.lineplot(x=long_df.x, y="y", data=long_df) + lineplot(x="x", y="y", hue="a", data=long_df) ax.clear() - rel.lineplot(x="x", y=long_df.y.values, data=long_df) + lineplot(x="x", y="y", hue="a", style="a", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", data=long_df) + lineplot(x="x", y="y", hue="a", style="b", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="a", data=long_df) + lineplot(x="x", y="y", hue="a", style="a", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="b", data=long_df) + lineplot(x="x", y="y", hue="a", style="b", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="a", data=missing_df) + lineplot(x="x", y="y", hue="a", size="a", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", style="b", data=missing_df) + lineplot(x="x", y="y", hue="a", size="s", data=long_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="a", data=long_df) + lineplot(x="x", y="y", hue="a", size="a", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="s", data=long_df) + lineplot(x="x", y="y", hue="a", size="s", data=missing_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="a", data=missing_df) + lineplot(x="x", y="y", hue="f", data=object_df) ax.clear() - rel.lineplot(x="x", y="y", hue="a", size="s", data=missing_df) + lineplot(x="x", y="y", hue="c", size="f", data=object_df) ax.clear() + lineplot(x="x", y="y", hue="f", size="s", data=object_df) + ax.clear() + + def test_ci_deprecation(self, long_df): + + axs = plt.figure().subplots(2) + lineplot(data=long_df, x="x", y="y", errorbar=("ci", 95), seed=0, ax=axs[0]) + with pytest.warns(UserWarning, match="The `ci` parameter is deprecated"): + lineplot(data=long_df, x="x", y="y", ci=95, seed=0, ax=axs[1]) 
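+ # With the seed fixed, the deprecated ci=95 call should draw the same plot as errorbar=("ci", 95) above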
+ assert_plots_equal(*axs) + + axs = plt.figure().subplots(2) + lineplot(data=long_df, x="x", y="y", errorbar="sd", ax=axs[0]) + with pytest.warns(UserWarning, match="The `ci` parameter is deprecated"): + lineplot(data=long_df, x="x", y="y", ci="sd", ax=axs[1]) + assert_plots_equal(*axs) + + +class TestScatterPlotter(SharedAxesLevelTests, Helpers): + + func = staticmethod(scatterplot) + + def get_last_color(self, ax): -class TestScatterPlotter(TestRelationalPlotter): + colors = ax.collections[-1].get_facecolors() + unique_colors = np.unique(colors, axis=0) + assert len(unique_colors) == 1 + return to_rgba(unique_colors.squeeze()) + + def test_color(self, long_df): + + super().test_color(long_df) + + ax = plt.figure().subplots() + self.func(data=long_df, x="x", y="y", facecolor="C5", ax=ax) + assert self.get_last_color(ax) == to_rgba("C5") + + ax = plt.figure().subplots() + self.func(data=long_df, x="x", y="y", facecolors="C6", ax=ax) + assert self.get_last_color(ax) == to_rgba("C6") + + if LooseVersion(mpl.__version__) >= "3.1.0": + # https://github.com/matplotlib/matplotlib/pull/12851 + + ax = plt.figure().subplots() + self.func(data=long_df, x="x", y="y", fc="C4", ax=ax) + assert self.get_last_color(ax) == to_rgba("C4") def test_legend_data(self, long_df): @@ -1261,72 +1270,132 @@ def test_legend_data(self, long_df): default_mark = m.get_path().transformed(m.get_transform()) m = mpl.markers.MarkerStyle("") - null_mark = m.get_path().transformed(m.get_transform()) + null = m.get_path().transformed(m.get_transform()) f, ax = plt.subplots() - p = rel._ScatterPlotter(x="x", y="y", data=long_df, legend="full") + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y"), + legend="full", + ) p.add_legend_data(ax) - handles, _ = ax.get_legend_handles_labels() + handles, labels = ax.get_legend_handles_labels() assert handles == [] # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", data=long_df, - legend="full") + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + legend="full", + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] - assert labels == ["a"] + p.hue_levels - assert self.colors_equal(colors, expected_colors) + expected_colors = p._hue_map(p._hue_map.levels) + assert labels == p._hue_map.levels + assert same_color(colors, expected_colors) # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df) + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), + legend="full", + ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] + expected_colors = p._hue_map(p._hue_map.levels) paths = [h.get_paths()[0] for h in handles] - expected_paths = [null_mark] + [p.paths[l] for l in p.style_levels] - assert labels == ["a"] + p.hue_levels == ["a"] + p.style_levels - assert self.colors_equal(colors, expected_colors) + expected_paths = p._style_map(p._style_map.levels, "path") + assert labels == p._hue_map.levels + assert labels == p._style_map.levels + assert same_color(colors, expected_colors) assert self.paths_equal(paths, expected_paths) # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df) + 
p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + legend="full", + ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] paths = [h.get_paths()[0] for h in handles] - expected_colors = (["w"] + [p.palette[l] for l in p.hue_levels] - + ["w"] + [".2" for _ in p.style_levels]) - expected_paths = ([null_mark] + [default_mark for _ in p.hue_levels] - + [null_mark] + [p.paths[l] for l in p.style_levels]) - assert labels == ["a"] + p.hue_levels + ["b"] + p.style_levels - assert self.colors_equal(colors, expected_colors) + expected_colors = ( + ["w"] + p._hue_map(p._hue_map.levels) + + ["w"] + [".2" for _ in p._style_map.levels] + ) + expected_paths = ( + [null] + [default_mark for _ in p._hue_map.levels] + + [null] + p._style_map(p._style_map.levels, "path") + ) + assert labels == ( + ["a"] + p._hue_map.levels + ["b"] + p._style_map.levels + ) + assert same_color(colors, expected_colors) assert self.paths_equal(paths, expected_paths) # -- ax.clear() - p = rel._ScatterPlotter(x="x", y="y", hue="a", size="a", - data=long_df, legend="full") + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", size="a"), + legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] + expected_colors = p._hue_map(p._hue_map.levels) + sizes = [h.get_sizes()[0] for h in handles] + expected_sizes = p._size_map(p._size_map.levels) + assert labels == p._hue_map.levels + assert labels == p._size_map.levels + assert same_color(colors, expected_colors) + assert sizes == expected_sizes + + # -- + + ax.clear() + sizes_list = [10, 100, 200] + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", size="s"), + legend="full", + ) + p.map_size(sizes=sizes_list) + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() + sizes = [h.get_sizes()[0] for h in handles] + expected_sizes = p._size_map(p._size_map.levels) + assert labels == [str(l) for l in p._size_map.levels] + assert sizes == expected_sizes + + # -- + + ax.clear() + sizes_dict = {2: 10, 4: 100, 8: 200} + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", size="s"), + legend="full" + ) + p.map_size(sizes=sizes_dict) + p.add_legend_data(ax) + handles, labels = ax.get_legend_handles_labels() sizes = [h.get_sizes()[0] for h in handles] - expected_sizes = [0] + [p.sizes[l] for l in p.size_levels] - assert labels == ["a"] + p.hue_levels == ["a"] + p.size_levels - assert self.colors_equal(colors, expected_colors) + expected_sizes = p._size_map(p._size_map.levels) + assert labels == [str(l) for l in p._size_map.levels] assert sizes == expected_sizes # -- @@ -1334,33 +1403,37 @@ def test_legend_data(self, long_df): x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = rel._ScatterPlotter(x=x, y=y, hue=z) + p = _ScatterPlotter( + variables=dict(x=x, y=y, hue=z), + ) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.hue_levels] + assert labels == [str(l) for l in p._hue_map.levels] ax.clear() p.legend = "brief" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert len(labels) == 4 + assert len(labels) < len(p._hue_map.levels) - p = rel._ScatterPlotter(x=x, y=y, size=z) + p = _ScatterPlotter( + 
variables=dict(x=x, y=y, size=z), + ) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.size_levels] + assert labels == [str(l) for l in p._size_map.levels] ax.clear() p.legend = "brief" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert len(labels) == 4 + assert len(labels) < len(p._size_map.levels) ax.clear() p.legend = "bad_value" @@ -1371,68 +1444,96 @@ def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() - p = rel._ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y")) p.plot(ax, {}) points = ax.collections[0] - assert np.array_equal(points.get_offsets(), long_df[["x", "y"]].values) + assert_array_equal(points.get_offsets(), long_df[["x", "y"]].to_numpy()) ax.clear() p.plot(ax, {"color": "k", "label": "test"}) points = ax.collections[0] - assert self.colors_equal(points.get_facecolor(), "k") + assert same_color(points.get_facecolor(), "k") assert points.get_label() == "test" - p = rel._ScatterPlotter(x="x", y="y", hue="a", data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", hue="a") + ) ax.clear() p.plot(ax, {}) points = ax.collections[0] - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] - assert self.colors_equal(points.get_facecolors(), expected_colors) + expected_colors = p._hue_map(p.plot_data["hue"]) + assert same_color(points.get_facecolors(), expected_colors) - p = rel._ScatterPlotter(x="x", y="y", style="c", - markers=["+", "x"], data=long_df) + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", style="c"), + ) + p.map_style(markers=["+", "x"]) ax.clear() color = (1, .3, .8) p.plot(ax, {"color": color}) points = ax.collections[0] - assert self.colors_equal(points.get_edgecolors(), [color]) + assert same_color(points.get_edgecolors(), [color]) - p = rel._ScatterPlotter(x="x", y="y", size="a", data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", size="a"), + ) ax.clear() p.plot(ax, {}) points = ax.collections[0] - expected_sizes = [p.size_lookup(k) for k in p.plot_data["size"]] - assert np.array_equal(points.get_sizes(), expected_sizes) + expected_sizes = p._size_map(p.plot_data["size"]) + assert_array_equal(points.get_sizes(), expected_sizes) - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="a", - markers=True, data=long_df) + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), + ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] - expected_paths = [p.paths[k] for k in p.plot_data["style"]] - assert self.colors_equal(points.get_facecolors(), expected_colors) + points = ax.collections[0] + expected_colors = p._hue_map(p.plot_data["hue"]) + expected_paths = p._style_map(p.plot_data["style"], "path") + assert same_color(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) - p = rel._ScatterPlotter(x="x", y="y", hue="a", style="b", - markers=True, data=long_df) + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] - expected_paths = [p.paths[k] for k in p.plot_data["style"]] - assert self.colors_equal(points.get_facecolors(), expected_colors) + points = ax.collections[0] + expected_colors = 
p._hue_map(p.plot_data["hue"]) + expected_paths = p._style_map(p.plot_data["style"], "path") + assert same_color(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) + x_str = long_df["x"].astype(str) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", hue=x_str), + ) + ax.clear() + p.plot(ax, {}) + + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", size=x_str), + ) + ax.clear() + p.plot(ax, {}) + def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = rel._ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y")) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -1448,190 +1549,183 @@ def test_scatterplot_axes(self, wide_df): f1, ax1 = plt.subplots() f2, ax2 = plt.subplots() - ax = rel.scatterplot(data=wide_df) + ax = scatterplot(data=wide_df) assert ax is ax2 - ax = rel.scatterplot(data=wide_df, ax=ax1) + ax = scatterplot(data=wide_df, ax=ax1) assert ax is ax1 - def test_scatterplot_smoke(self, flat_array, flat_series, - wide_array, wide_list, wide_list_of_series, - wide_df, long_df, missing_df): + def test_literal_attribute_vectors(self): f, ax = plt.subplots() - rel.scatterplot([], []) - ax.clear() + x = y = [1, 2, 3] + s = [5, 10, 15] + c = [(1, 1, 0, 1), (1, 0, 1, .5), (.5, 1, 0, 1)] - rel.scatterplot(data=flat_array) - ax.clear() + scatterplot(x=x, y=y, c=c, s=s, ax=ax) - rel.scatterplot(data=flat_series) - ax.clear() + points, = ax.collections - rel.scatterplot(data=wide_array) - ax.clear() + assert_array_equal(points.get_sizes().squeeze(), s) + assert_array_equal(points.get_facecolors(), c) - rel.scatterplot(data=wide_list) - ax.clear() + def test_supplied_color_array(self, long_df): - rel.scatterplot(data=wide_list_of_series) - ax.clear() + cmap = mpl.cm.get_cmap("Blues") + norm = mpl.colors.Normalize() + colors = cmap(norm(long_df["y"].to_numpy())) - rel.scatterplot(data=wide_df) - ax.clear() + keys = ["c", "facecolor", "facecolors"] - rel.scatterplot(x="x", y="y", data=long_df) - ax.clear() + if LooseVersion(mpl.__version__) >= "3.1.0": + # https://github.com/matplotlib/matplotlib/pull/12851 + keys.append("fc") - rel.scatterplot(x=long_df.x, y=long_df.y) - ax.clear() + for key in keys: - rel.scatterplot(x=long_df.x, y="y", data=long_df) - ax.clear() + ax = plt.figure().subplots() + scatterplot(data=long_df, x="x", y="y", **{key: colors}) + _draw_figure(ax.figure) + assert_array_equal(ax.collections[0].get_facecolors(), colors) - rel.scatterplot(x="x", y=long_df.y.values, data=long_df) - ax.clear() + ax = plt.figure().subplots() + scatterplot(data=long_df, x="x", y="y", c=long_df["y"], cmap=cmap) + _draw_figure(ax.figure) + assert_array_equal(ax.collections[0].get_facecolors(), colors) - rel.scatterplot(x="x", y="y", hue="a", data=long_df) - ax.clear() + def test_linewidths(self, long_df): - rel.scatterplot(x="x", y="y", hue="a", style="a", data=long_df) - ax.clear() + f, ax = plt.subplots() - rel.scatterplot(x="x", y="y", hue="a", style="b", data=long_df) - ax.clear() + scatterplot(data=long_df, x="x", y="y", s=10) + scatterplot(data=long_df, x="x", y="y", s=20) + points1, points2 = ax.collections + assert ( + points1.get_linewidths().item() < points2.get_linewidths().item() + ) - rel.scatterplot(x="x", y="y", hue="a", style="a", data=missing_df) ax.clear() + scatterplot(data=long_df, x="x", y="y", s=long_df["x"]) + scatterplot(data=long_df, x="x", y="y", s=long_df["x"] * 2) + points1, points2 = ax.collections + 
assert ( + points1.get_linewidths().item() < points2.get_linewidths().item() + ) - rel.scatterplot(x="x", y="y", hue="a", style="b", data=missing_df) ax.clear() + scatterplot(data=long_df, x="x", y="y", size=long_df["x"]) + scatterplot(data=long_df, x="x", y="y", size=long_df["x"] * 2) + points1, points2, *_ = ax.collections + assert ( + points1.get_linewidths().item() < points2.get_linewidths().item() + ) - rel.scatterplot(x="x", y="y", hue="a", size="a", data=long_df) ax.clear() + lw = 2 + scatterplot(data=long_df, x="x", y="y", linewidth=lw) + assert ax.collections[0].get_linewidths().item() == lw - rel.scatterplot(x="x", y="y", hue="a", size="s", data=long_df) - ax.clear() + def test_datetime_scale(self, long_df): - rel.scatterplot(x="x", y="y", hue="a", size="a", data=missing_df) - ax.clear() + ax = scatterplot(data=long_df, x="t", y="y") + # Check that we avoid weird matplotlib default auto scaling + # https://github.com/matplotlib/matplotlib/issues/17586 + assert ax.get_xlim()[0] > ax.xaxis.convert_units(np.datetime64("2002-01-01")) - rel.scatterplot(x="x", y="y", hue="a", size="s", data=missing_df) - ax.clear() + def test_scatterplot_vs_relplot(self, long_df, long_semantics): + ax = scatterplot(data=long_df, **long_semantics) + g = relplot(data=long_df, kind="scatter", **long_semantics) -class TestRelPlotter(TestRelationalPlotter): + for s_pts, r_pts in zip(ax.collections, g.ax.collections): - def test_relplot_simple(self, long_df): + assert_array_equal(s_pts.get_offsets(), r_pts.get_offsets()) + assert_array_equal(s_pts.get_sizes(), r_pts.get_sizes()) + assert_array_equal(s_pts.get_facecolors(), r_pts.get_facecolors()) + assert self.paths_equal(s_pts.get_paths(), r_pts.get_paths()) - g = rel.relplot(x="x", y="y", kind="scatter", data=long_df) - x, y = g.ax.collections[0].get_offsets().T - assert np.array_equal(x, long_df["x"]) - assert np.array_equal(y, long_df["y"]) + def test_scatterplot_smoke( + self, + wide_df, wide_array, + flat_series, flat_array, flat_list, + wide_list_of_series, wide_list_of_arrays, wide_list_of_lists, + long_df, missing_df, object_df + ): - g = rel.relplot(x="x", y="y", kind="line", data=long_df) - x, y = g.ax.lines[0].get_xydata().T - expected = long_df.groupby("x").y.mean() - assert np.array_equal(x, expected.index) - assert y == pytest.approx(expected.values) + f, ax = plt.subplots() - with pytest.raises(ValueError): - g = rel.relplot(x="x", y="y", kind="not_a_kind", data=long_df) + scatterplot(x=[], y=[]) + ax.clear() - def test_relplot_complex(self, long_df): + scatterplot(data=wide_df) + ax.clear() - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", data=long_df, **{sem: "a"}) - x, y = g.ax.collections[0].get_offsets().T - assert np.array_equal(x, long_df["x"]) - assert np.array_equal(y, long_df["y"]) + scatterplot(data=wide_array) + ax.clear() - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", col="c", data=long_df, - **{sem: "a"}) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) + scatterplot(data=wide_list_of_series) + ax.clear() - for sem in ["size", "style"]: - g = rel.relplot(x="x", y="y", hue="b", col="c", data=long_df, - **{sem: "a"}) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) +
scatterplot(data=wide_list_of_arrays) + ax.clear() - for sem in ["hue", "size", "style"]: - g = rel.relplot(x="x", y="y", col="b", row="c", - data=sort_df(long_df, ["c", "b"]), - **{sem: "a"}) - grouped = long_df.groupby(["c", "b"]) - for (_, grp_df), ax in zip(grouped, g.axes.flat): - x, y = ax.collections[0].get_offsets().T - assert np.array_equal(x, grp_df["x"]) - assert np.array_equal(y, grp_df["y"]) + scatterplot(data=wide_list_of_lists) + ax.clear() - def test_relplot_hues(self, long_df): + scatterplot(data=flat_series) + ax.clear() - palette = ["r", "b", "g"] - g = rel.relplot(x="x", y="y", hue="a", style="b", col="c", - palette=palette, data=long_df) + scatterplot(data=flat_array) + ax.clear() - palette = dict(zip(long_df["a"].unique(), palette)) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_hues = [palette[val] for val in grp_df["a"]] - assert self.colors_equal(points.get_facecolors(), expected_hues) + scatterplot(data=flat_list) + ax.clear() - def test_relplot_sizes(self, long_df): + scatterplot(x="x", y="y", data=long_df) + ax.clear() - sizes = [5, 12, 7] - g = rel.relplot(x="x", y="y", size="a", hue="b", col="c", - sizes=sizes, data=long_df) + scatterplot(x=long_df.x, y=long_df.y) + ax.clear() - sizes = dict(zip(long_df["a"].unique(), sizes)) - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_sizes = [sizes[val] for val in grp_df["a"]] - assert np.array_equal(points.get_sizes(), expected_sizes) + scatterplot(x=long_df.x, y="y", data=long_df) + ax.clear() - def test_relplot_styles(self, long_df): + scatterplot(x="x", y=long_df.y.to_numpy(), data=long_df) + ax.clear() - markers = ["o", "d", "s"] - g = rel.relplot(x="x", y="y", style="a", hue="b", col="c", - markers=markers, data=long_df) + scatterplot(x="x", y="y", hue="a", data=long_df) + ax.clear() - paths = [] - for m in markers: - m = mpl.markers.MarkerStyle(m) - paths.append(m.get_path().transformed(m.get_transform())) - paths = dict(zip(long_df["a"].unique(), paths)) + scatterplot(x="x", y="y", hue="a", style="a", data=long_df) + ax.clear() - grouped = long_df.groupby("c") - for (_, grp_df), ax in zip(grouped, g.axes.flat): - points = ax.collections[0] - expected_paths = [paths[val] for val in grp_df["a"]] - assert self.paths_equal(points.get_paths(), expected_paths) + scatterplot(x="x", y="y", hue="a", style="b", data=long_df) + ax.clear() - def test_relplot_legend(self, long_df): + scatterplot(x="x", y="y", hue="a", style="a", data=missing_df) + ax.clear() - g = rel.relplot(x="x", y="y", data=long_df) - assert g._legend is None + scatterplot(x="x", y="y", hue="a", style="b", data=missing_df) + ax.clear() - g = rel.relplot(x="x", y="y", hue="a", data=long_df) - texts = [t.get_text() for t in g._legend.texts] - expected_texts = np.append(["a"], long_df["a"].unique()) - assert np.array_equal(texts, expected_texts) + scatterplot(x="x", y="y", hue="a", size="a", data=long_df) + ax.clear() - g = rel.relplot(x="x", y="y", hue="s", size="s", data=long_df) - texts = [t.get_text() for t in g._legend.texts] - assert np.array_equal(texts[1:], np.sort(texts[1:])) + scatterplot(x="x", y="y", hue="a", size="s", data=long_df) + ax.clear() - g = rel.relplot(x="x", y="y", hue="a", legend=False, data=long_df) - assert g._legend is None + scatterplot(x="x", y="y", hue="a", size="a", data=missing_df) + ax.clear() + + scatterplot(x="x", y="y", hue="a", size="s", data=missing_df) + ax.clear() + 
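+ # object_df carries object-typed columns, which seaborn should treat as categorical semantics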
+ scatterplot(x="x", y="y", hue="f", data=object_df) + ax.clear() + + scatterplot(x="x", y="y", hue="c", size="f", data=object_df) + ax.clear() + + scatterplot(x="x", y="y", hue="f", size="s", data=object_df) + ax.clear() diff --git a/seaborn/tests/test_statistics.py b/seaborn/tests/test_statistics.py new file mode 100644 index 0000000000..6274be4c42 --- /dev/null +++ b/seaborn/tests/test_statistics.py @@ -0,0 +1,593 @@ +import numpy as np +import pandas as pd + +try: + import statsmodels.distributions as smdist +except ImportError: + smdist = None + +import pytest +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from .._statistics import ( + KDE, + Histogram, + ECDF, + EstimateAggregator, + _validate_errorbar_arg, + _no_scipy, +) + + +class DistributionFixtures: + + @pytest.fixture + def x(self, rng): + return rng.normal(0, 1, 100) + + @pytest.fixture + def y(self, rng): + return rng.normal(0, 5, 100) + + @pytest.fixture + def weights(self, rng): + return rng.uniform(0, 5, 100) + + +class TestKDE: + + def integrate(self, y, x): + y = np.asarray(y) + x = np.asarray(x) + dx = np.diff(x) + return (dx * y[:-1] + dx * y[1:]).sum() / 2 + + def test_gridsize(self, rng): + + x = rng.normal(0, 3, 1000) + + n = 200 + kde = KDE(gridsize=n) + density, support = kde(x) + assert density.size == n + assert support.size == n + + def test_cut(self, rng): + + x = rng.normal(0, 3, 1000) + + kde = KDE(cut=0) + _, support = kde(x) + assert support.min() == x.min() + assert support.max() == x.max() + + cut = 2 + bw_scale = .5 + bw = x.std() * bw_scale + kde = KDE(cut=cut, bw_method=bw_scale, gridsize=1000) + _, support = kde(x) + assert support.min() == pytest.approx(x.min() - bw * cut, abs=1e-2) + assert support.max() == pytest.approx(x.max() + bw * cut, abs=1e-2) + + def test_clip(self, rng): + + x = rng.normal(0, 3, 100) + clip = -1, 1 + kde = KDE(clip=clip) + _, support = kde(x) + + assert support.min() >= clip[0] + assert support.max() <= clip[1] + + def test_density_normalization(self, rng): + + x = rng.normal(0, 3, 1000) + kde = KDE() + density, support = kde(x) + assert self.integrate(density, support) == pytest.approx(1, abs=1e-5) + + @pytest.mark.skipif(_no_scipy, reason="Test requires scipy") + def test_cumulative(self, rng): + + x = rng.normal(0, 3, 1000) + kde = KDE(cumulative=True) + density, _ = kde(x) + assert density[0] == pytest.approx(0, abs=1e-5) + assert density[-1] == pytest.approx(1, abs=1e-5) + + def test_cached_support(self, rng): + + x = rng.normal(0, 3, 100) + kde = KDE() + kde.define_support(x) + _, support = kde(x[(x > -1) & (x < 1)]) + assert_array_equal(support, kde.support) + + def test_bw_method(self, rng): + + x = rng.normal(0, 3, 100) + kde1 = KDE(bw_method=.2) + kde2 = KDE(bw_method=2) + + d1, _ = kde1(x) + d2, _ = kde2(x) + + assert np.abs(np.diff(d1)).mean() > np.abs(np.diff(d2)).mean() + + def test_bw_adjust(self, rng): + + x = rng.normal(0, 3, 100) + kde1 = KDE(bw_adjust=.2) + kde2 = KDE(bw_adjust=2) + + d1, _ = kde1(x) + d2, _ = kde2(x) + + assert np.abs(np.diff(d1)).mean() > np.abs(np.diff(d2)).mean() + + def test_bivariate_grid(self, rng): + + n = 100 + x, y = rng.normal(0, 3, (2, 50)) + kde = KDE(gridsize=n) + density, (xx, yy) = kde(x, y) + + assert density.shape == (n, n) + assert xx.size == n + assert yy.size == n + + def test_bivariate_normalization(self, rng): + + x, y = rng.normal(0, 3, (2, 50)) + kde = KDE(gridsize=100) + density, (xx, yy) = kde(x, y) + + dx = xx[1] - xx[0] + dy = yy[1] - yy[0] + + total = density.sum() * (dx * dy) + 
assert total == pytest.approx(1, abs=1e-2) + + @pytest.mark.skipif(_no_scipy, reason="Test requires scipy") + def test_bivariate_cumulative(self, rng): + + x, y = rng.normal(0, 3, (2, 50)) + kde = KDE(gridsize=100, cumulative=True) + density, _ = kde(x, y) + + assert density[0, 0] == pytest.approx(0, abs=1e-2) + assert density[-1, -1] == pytest.approx(1, abs=1e-2) + + +class TestHistogram(DistributionFixtures): + + def test_string_bins(self, x): + + h = Histogram(bins="sqrt") + edges = h.define_bin_edges(x) + assert_array_equal(edges, np.histogram_bin_edges(x, "sqrt")) + + def test_int_bins(self, x): + + n = 24 + h = Histogram(bins=n) + edges = h.define_bin_edges(x) + assert len(edges) == n + 1 + + def test_array_bins(self, x): + + bins = [-3, -2, 1, 2, 3] + h = Histogram(bins=bins) + edges = h.define_bin_edges(x) + assert_array_equal(edges, bins) + + def test_bivariate_string_bins(self, x, y): + + s1, s2 = "sqrt", "fd" + + h = Histogram(bins=s1) + e1, e2 = h.define_bin_edges(x, y) + assert_array_equal(e1, np.histogram_bin_edges(x, s1)) + assert_array_equal(e2, np.histogram_bin_edges(y, s1)) + + h = Histogram(bins=(s1, s2)) + e1, e2 = h.define_bin_edges(x, y) + assert_array_equal(e1, np.histogram_bin_edges(x, s1)) + assert_array_equal(e2, np.histogram_bin_edges(y, s2)) + + def test_bivariate_int_bins(self, x, y): + + b1, b2 = 5, 10 + + h = Histogram(bins=b1) + e1, e2 = h.define_bin_edges(x, y) + assert len(e1) == b1 + 1 + assert len(e2) == b1 + 1 + + h = Histogram(bins=(b1, b2)) + e1, e2 = h.define_bin_edges(x, y) + assert len(e1) == b1 + 1 + assert len(e2) == b2 + 1 + + def test_bivariate_array_bins(self, x, y): + + b1 = [-3, -2, 1, 2, 3] + b2 = [-5, -2, 3, 6] + + h = Histogram(bins=b1) + e1, e2 = h.define_bin_edges(x, y) + assert_array_equal(e1, b1) + assert_array_equal(e2, b1) + + h = Histogram(bins=(b1, b2)) + e1, e2 = h.define_bin_edges(x, y) + assert_array_equal(e1, b1) + assert_array_equal(e2, b2) + + def test_binwidth(self, x): + + binwidth = .5 + h = Histogram(binwidth=binwidth) + edges = h.define_bin_edges(x) + assert np.all(np.diff(edges) == binwidth) + + def test_bivariate_binwidth(self, x, y): + + w1, w2 = .5, 1 + + h = Histogram(binwidth=w1) + e1, e2 = h.define_bin_edges(x, y) + assert np.all(np.diff(e1) == w1) + assert np.all(np.diff(e2) == w1) + + h = Histogram(binwidth=(w1, w2)) + e1, e2 = h.define_bin_edges(x, y) + assert np.all(np.diff(e1) == w1) + assert np.all(np.diff(e2) == w2) + + def test_binrange(self, x): + + binrange = (-4, 4) + h = Histogram(binrange=binrange) + edges = h.define_bin_edges(x) + assert edges.min() == binrange[0] + assert edges.max() == binrange[1] + + def test_bivariate_binrange(self, x, y): + + r1, r2 = (-4, 4), (-10, 10) + + h = Histogram(binrange=r1) + e1, e2 = h.define_bin_edges(x, y) + assert e1.min() == r1[0] + assert e1.max() == r1[1] + assert e2.min() == r1[0] + assert e2.max() == r1[1] + + h = Histogram(binrange=(r1, r2)) + e1, e2 = h.define_bin_edges(x, y) + assert e1.min() == r1[0] + assert e1.max() == r1[1] + assert e2.min() == r2[0] + assert e2.max() == r2[1] + + def test_discrete_bins(self, rng): + + x = rng.binomial(20, .5, 100) + h = Histogram(discrete=True) + edges = h.define_bin_edges(x) + expected_edges = np.arange(x.min(), x.max() + 2) - .5 + assert_array_equal(edges, expected_edges) + + def test_histogram(self, x): + + h = Histogram() + heights, edges = h(x) + heights_mpl, edges_mpl = np.histogram(x, bins="auto") + + assert_array_equal(heights, heights_mpl) + assert_array_equal(edges, edges_mpl) + + def test_count_stat(self, 
x): + + h = Histogram(stat="count") + heights, _ = h(x) + assert heights.sum() == len(x) + + def test_density_stat(self, x): + + h = Histogram(stat="density") + heights, edges = h(x) + assert (heights * np.diff(edges)).sum() == 1 + + def test_probability_stat(self, x): + + h = Histogram(stat="probability") + heights, _ = h(x) + assert heights.sum() == 1 + + def test_frequency_stat(self, x): + + h = Histogram(stat="frequency") + heights, edges = h(x) + assert (heights * np.diff(edges)).sum() == len(x) + + def test_cumulative_count(self, x): + + h = Histogram(stat="count", cumulative=True) + heights, _ = h(x) + assert heights[-1] == len(x) + + def test_cumulative_density(self, x): + + h = Histogram(stat="density", cumulative=True) + heights, _ = h(x) + assert heights[-1] == 1 + + def test_cumulative_probability(self, x): + + h = Histogram(stat="probability", cumulative=True) + heights, _ = h(x) + assert heights[-1] == 1 + + def test_cumulative_frequency(self, x): + + h = Histogram(stat="frequency", cumulative=True) + heights, _ = h(x) + assert heights[-1] == len(x) + + def test_bivariate_histogram(self, x, y): + + h = Histogram() + heights, edges = h(x, y) + bins_mpl = ( + np.histogram_bin_edges(x, "auto"), + np.histogram_bin_edges(y, "auto"), + ) + heights_mpl, *edges_mpl = np.histogram2d(x, y, bins_mpl) + assert_array_equal(heights, heights_mpl) + assert_array_equal(edges[0], edges_mpl[0]) + assert_array_equal(edges[1], edges_mpl[1]) + + def test_bivariate_count_stat(self, x, y): + + h = Histogram(stat="count") + heights, _ = h(x, y) + assert heights.sum() == len(x) + + def test_bivariate_density_stat(self, x, y): + + h = Histogram(stat="density") + heights, (edges_x, edges_y) = h(x, y) + areas = np.outer(np.diff(edges_x), np.diff(edges_y)) + assert (heights * areas).sum() == pytest.approx(1) + + def test_bivariate_probability_stat(self, x, y): + + h = Histogram(stat="probability") + heights, _ = h(x, y) + assert heights.sum() == 1 + + def test_bivariate_frequency_stat(self, x, y): + + h = Histogram(stat="frequency") + heights, (x_edges, y_edges) = h(x, y) + area = np.outer(np.diff(x_edges), np.diff(y_edges)) + assert (heights * area).sum() == len(x) + + def test_bivariate_cumulative_count(self, x, y): + + h = Histogram(stat="count", cumulative=True) + heights, _ = h(x, y) + assert heights[-1, -1] == len(x) + + def test_bivariate_cumulative_density(self, x, y): + + h = Histogram(stat="density", cumulative=True) + heights, _ = h(x, y) + assert heights[-1, -1] == pytest.approx(1) + + def test_bivariate_cumulative_frequency(self, x, y): + + h = Histogram(stat="frequency", cumulative=True) + heights, _ = h(x, y) + assert heights[-1, -1] == len(x) + + def test_bivariate_cumulative_probability(self, x, y): + + h = Histogram(stat="probability", cumulative=True) + heights, _ = h(x, y) + assert heights[-1, -1] == pytest.approx(1) + + def test_bad_stat(self): + + with pytest.raises(ValueError): + Histogram(stat="invalid") + + +class TestECDF(DistributionFixtures): + + def test_univariate_proportion(self, x): + + ecdf = ECDF() + stat, vals = ecdf(x) + assert_array_equal(vals[1:], np.sort(x)) + assert_array_almost_equal(stat[1:], np.linspace(0, 1, len(x) + 1)[1:]) + assert stat[0] == 0 + + def test_univariate_count(self, x): + + ecdf = ECDF(stat="count") + stat, vals = ecdf(x) + + assert_array_equal(vals[1:], np.sort(x)) + assert_array_almost_equal(stat[1:], np.arange(len(x)) + 1) + assert stat[0] == 0 + + def test_univariate_proportion_weights(self, x, weights): + + ecdf = ECDF() + stat, vals = 
ecdf(x, weights=weights) + assert_array_equal(vals[1:], np.sort(x)) + expected_stats = weights[x.argsort()].cumsum() / weights.sum() + assert_array_almost_equal(stat[1:], expected_stats) + assert stat[0] == 0 + + def test_univariate_count_weights(self, x, weights): + + ecdf = ECDF(stat="count") + stat, vals = ecdf(x, weights=weights) + assert_array_equal(vals[1:], np.sort(x)) + assert_array_almost_equal(stat[1:], weights[x.argsort()].cumsum()) + assert stat[0] == 0 + + @pytest.mark.skipif(smdist is None, reason="Requires statsmodels") + def test_against_statsmodels(self, x): + + sm_ecdf = smdist.empirical_distribution.ECDF(x) + + ecdf = ECDF() + stat, vals = ecdf(x) + assert_array_equal(vals, sm_ecdf.x) + assert_array_almost_equal(stat, sm_ecdf.y) + + ecdf = ECDF(complementary=True) + stat, vals = ecdf(x) + assert_array_equal(vals, sm_ecdf.x) + assert_array_almost_equal(stat, sm_ecdf.y[::-1]) + + def test_invalid_stat(self, x): + + with pytest.raises(ValueError, match="`stat` must be one of"): + ECDF(stat="density") + + def test_bivariate_error(self, x, y): + + with pytest.raises(NotImplementedError, match="Bivariate ECDF"): + ecdf = ECDF() + ecdf(x, y) + + +class TestEstimateAggregator: + + def test_func_estimator(self, long_df): + + func = np.mean + agg = EstimateAggregator(func) + out = agg(long_df, "x") + assert out["x"] == func(long_df["x"]) + + def test_name_estimator(self, long_df): + + agg = EstimateAggregator("mean") + out = agg(long_df, "x") + assert out["x"] == long_df["x"].mean() + + def test_se_errorbars(self, long_df): + + agg = EstimateAggregator("mean", "se") + out = agg(long_df, "x") + assert out["x"] == long_df["x"].mean() + assert out["xmin"] == (long_df["x"].mean() - long_df["x"].sem()) + assert out["xmax"] == (long_df["x"].mean() + long_df["x"].sem()) + + agg = EstimateAggregator("mean", ("se", 2)) + out = agg(long_df, "x") + assert out["x"] == long_df["x"].mean() + assert out["xmin"] == (long_df["x"].mean() - 2 * long_df["x"].sem()) + assert out["xmax"] == (long_df["x"].mean() + 2 * long_df["x"].sem()) + + def test_sd_errorbars(self, long_df): + + agg = EstimateAggregator("mean", "sd") + out = agg(long_df, "x") + assert out["x"] == long_df["x"].mean() + assert out["xmin"] == (long_df["x"].mean() - long_df["x"].std()) + assert out["xmax"] == (long_df["x"].mean() + long_df["x"].std()) + + agg = EstimateAggregator("mean", ("sd", 2)) + out = agg(long_df, "x") + assert out["x"] == long_df["x"].mean() + assert out["xmin"] == (long_df["x"].mean() - 2 * long_df["x"].std()) + assert out["xmax"] == (long_df["x"].mean() + 2 * long_df["x"].std()) + + def test_pi_errorbars(self, long_df): + + agg = EstimateAggregator("mean", "pi") + out = agg(long_df, "y") + assert out["ymin"] == np.percentile(long_df["y"], 2.5) + assert out["ymax"] == np.percentile(long_df["y"], 97.5) + + agg = EstimateAggregator("mean", ("pi", 50)) + out = agg(long_df, "y") + assert out["ymin"] == np.percentile(long_df["y"], 25) + assert out["ymax"] == np.percentile(long_df["y"], 75) + + def test_ci_errorbars(self, long_df): + + agg = EstimateAggregator("mean", "ci", n_boot=100000, seed=0) + out = agg(long_df, "y") + + agg_ref = EstimateAggregator("mean", ("se", 1.96)) + out_ref = agg_ref(long_df, "y") + + assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2) + assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2) + + agg = EstimateAggregator("mean", ("ci", 68), n_boot=100000, seed=0) + out = agg(long_df, "y") + + agg_ref = EstimateAggregator("mean", ("se", 1)) + out_ref = agg_ref(long_df, "y") 
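+        # With a large n_boot, the 68% bootstrap interval for the mean
+        # should closely match the parametric +/- 1 SE reference above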
+        assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2)
+        assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2)
+
+        # Providing a seed should make the bootstrap resampling deterministic
+        agg = EstimateAggregator("mean", "ci", seed=0)
+        out_orig = agg(long_df, "y")
+        out_test = agg(long_df, "y")
+        assert_array_equal(out_orig, out_test)
+
+    def test_custom_errorbars(self, long_df):
+
+        f = lambda x: (x.min(), x.max())  # noqa: E731
+        agg = EstimateAggregator("mean", f)
+        out = agg(long_df, "y")
+        assert out["ymin"] == long_df["y"].min()
+        assert out["ymax"] == long_df["y"].max()
+
+    def test_singleton_errorbars(self):
+
+        agg = EstimateAggregator("mean", "ci")
+        val = 7
+        out = agg(pd.DataFrame(dict(y=[val])), "y")
+        assert out["y"] == val
+        assert pd.isna(out["ymin"])
+        assert pd.isna(out["ymax"])
+
+    def test_errorbar_validation(self):
+
+        method, level = _validate_errorbar_arg(("ci", 99))
+        assert method == "ci"
+        assert level == 99
+
+        method, level = _validate_errorbar_arg("sd")
+        assert method == "sd"
+        assert level == 1
+
+        f = lambda x: (x.min(), x.max())  # noqa: E731
+        method, level = _validate_errorbar_arg(f)
+        assert method is f
+        assert level is None
+
+        bad_args = [
+            ("sem", ValueError),
+            (("std", 2), ValueError),
+            (("pi", 5, 95), ValueError),
+            (95, TypeError),
+            (("ci", "large"), TypeError),
+        ]
+
+        for arg, exception in bad_args:
+            with pytest.raises(exception, match="`errorbar` must be"):
+                _validate_errorbar_arg(arg)
diff --git a/seaborn/tests/test_utils.py b/seaborn/tests/test_utils.py
index f3ec847ca2..cab8a357f2 100644
--- a/seaborn/tests/test_utils.py
+++ b/seaborn/tests/test_utils.py
@@ -1,66 +1,63 @@
-"""Tests for plotting utilities."""
+"""Tests for seaborn utility functions."""
 import tempfile
-import shutil
+from urllib.request import urlopen
+from http.client import HTTPException
 
 import numpy as np
 import pandas as pd
 import matplotlib as mpl
 import matplotlib.pyplot as plt
-import nose
-import nose.tools as nt
-from nose.tools import assert_equal, raises
-import numpy.testing as npt
-try:
-    import pandas.testing as pdt
-except ImportError:
-    import pandas.util.testing as pdt
-
-from distutils.version import LooseVersion
-pandas_has_categoricals = LooseVersion(pd.__version__) >= "0.15"
+from cycler import cycler
+import pytest
+from numpy.testing import (
+    assert_array_equal,
+)
+from pandas.testing import (
+    assert_series_equal,
+    assert_frame_equal,
+)
 
-try:
-    from bs4 import BeautifulSoup
-except ImportError:
-    BeautifulSoup = None
+from distutils.version import LooseVersion
 
 from .. import utils, rcmod
-from ..utils import get_dataset_names, load_dataset, _network
+from ..utils import (
+    get_dataset_names,
+    get_color_cycle,
+    remove_na,
+    load_dataset,
+    _assign_default_kwargs,
+    _draw_figure,
+    _deprecate_ci,
+)
 
 a_norm = np.random.randn(100)
 
 
-def test_pmf_hist_basics():
-    """Test the function to return barplot args for pmf hist."""
-    out = utils.pmf_hist(a_norm)
-    assert_equal(len(out), 3)
-    x, h, w = out
-    assert_equal(len(x), len(h))
-
-    # Test simple case
-    a = np.arange(10)
-    x, h, w = utils.pmf_hist(a, 10)
-    nose.tools.assert_true(np.all(h == h[0]))
-
-
-def test_pmf_hist_widths():
-    """Test histogram width is correct."""
-    x, h, w = utils.pmf_hist(a_norm)
-    assert_equal(x[1] - x[0], w)
+def _network(t=None, url="https://github.com"):
+    """
+    Decorator that will skip a test if `url` is unreachable.
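+
+    Tests wrapped with this decorator are skipped, rather than failed,
+    when no network connection is available.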
+
+    Parameters
+    ----------
+    t : function, optional
+    url : str, optional
+
-def test_pmf_hist_normalization():
-    """Test that output data behaves like a PMF."""
-    x, h, w = utils.pmf_hist(a_norm)
-    nose.tools.assert_almost_equal(sum(h), 1)
-    nose.tools.assert_less_equal(h.max(), 1)
+    """
+    if t is None:
+        return lambda x: _network(x, url=url)
-
-
-def test_pmf_hist_bins():
-    """Test bin specification."""
-    x, h, w = utils.pmf_hist(a_norm, 20)
-    assert_equal(len(x), 20)
+
+    def wrapper(*args, **kwargs):
+        # attempt to connect
+        try:
+            f = urlopen(url)
+        except (IOError, HTTPException):
+            pytest.skip("No internet connection")
+        else:
+            f.close()
+            return t(*args, **kwargs)
+    return wrapper
 
 
 def test_ci_to_errsize():
@@ -74,49 +71,55 @@ def test_ci_to_errsize():
                            [.25, 0]])
 
     test_errsize = utils.ci_to_errsize(cis, heights)
-    npt.assert_array_equal(actual_errsize, test_errsize)
+    assert_array_equal(actual_errsize, test_errsize)
 
 
 def test_desaturate():
     """Test color desaturation."""
     out1 = utils.desaturate("red", .5)
-    assert_equal(out1, (.75, .25, .25))
+    assert out1 == (.75, .25, .25)
 
     out2 = utils.desaturate("#00FF00", .5)
-    assert_equal(out2, (.25, .75, .25))
+    assert out2 == (.25, .75, .25)
 
     out3 = utils.desaturate((0, 0, 1), .5)
-    assert_equal(out3, (.25, .25, .75))
+    assert out3 == (.25, .25, .75)
 
     out4 = utils.desaturate("red", .5)
-    assert_equal(out4, (.75, .25, .25))
+    assert out4 == (.75, .25, .25)
 
 
-@raises(ValueError)
 def test_desaturation_prop():
    """Test that pct outside of [0, 1] raises exception."""
-    utils.desaturate("blue", 50)
+    with pytest.raises(ValueError):
+        utils.desaturate("blue", 50)
 
 
 def test_saturate():
     """Test performance of saturation function."""
     out = utils.saturate((.75, .25, .25))
-    assert_equal(out, (1, 0, 0))
-
-
-def test_iqr():
-    """Test the IQR function."""
-    a = np.arange(5)
-    iqr = utils.iqr(a)
-    assert_equal(iqr, 2)
-
-
-def test_str_to_utf8():
-    """Test the to_utf8 function: string to Unicode"""
-    s = "\u01ff\u02ff"
+    assert out == (1, 0, 0)
+
+
+@pytest.mark.parametrize(
+    "s,exp",
+    [
+        ("a", "a"),
+        ("abc", "abc"),
+        (b"a", "a"),
+        (b"abc", "abc"),
+        (bytearray("abc", "utf-8"), "abc"),
+        (bytearray(), ""),
+        (1, "1"),
+        (0, "0"),
+        ([], str([])),
+    ],
+)
+def test_to_utf8(s, exp):
+    """Test the to_utf8 function: object to string"""
     u = utils.to_utf8(s)
-    assert_equal(type(s), type(str()))
-    assert_equal(type(u), type(u"\u01ff\u02ff"))
+    assert type(u) == str
+    assert u == exp
 
 
 class TestSpineUtils(object):
@@ -132,17 +135,17 @@ class TestSpineUtils(object):
     def test_despine(self):
         f, ax = plt.subplots()
         for side in self.sides:
-            nt.assert_true(ax.spines[side].get_visible())
+            assert ax.spines[side].get_visible()
 
         utils.despine()
         for side in self.outer_sides:
-            nt.assert_true(~ax.spines[side].get_visible())
+            assert not ax.spines[side].get_visible()
         for side in self.inner_sides:
-            nt.assert_true(ax.spines[side].get_visible())
+            assert ax.spines[side].get_visible()
 
         utils.despine(**dict(zip(self.sides, [True] * 4)))
         for side in self.sides:
-            nt.assert_true(~ax.spines[side].get_visible())
+            assert not ax.spines[side].get_visible()
 
     def test_despine_specific_axes(self):
         f, (ax1, ax2) = plt.subplots(2, 1)
@@ -150,19 +153,19 @@ def test_despine_specific_axes(self):
         utils.despine(ax=ax2)
 
         for side in self.sides:
-            nt.assert_true(ax1.spines[side].get_visible())
+            assert ax1.spines[side].get_visible()
 
         for side in self.outer_sides:
-            nt.assert_true(~ax2.spines[side].get_visible())
+            assert not ax2.spines[side].get_visible()
 
         for side in self.inner_sides:
-            nt.assert_true(ax2.spines[side].get_visible())
+            assert ax2.spines[side].get_visible()
 
     def test_despine_with_offset(self):
         f, ax = plt.subplots()
 
         for side in self.sides:
-            nt.assert_equal(ax.spines[side].get_position(),
-                            self.original_position)
+            pos = ax.spines[side].get_position()
+            assert pos == self.original_position
 
         utils.despine(ax=ax, offset=self.offset)
 
@@ -170,9 +173,9 @@ def test_despine_with_offset(self):
             is_visible = ax.spines[side].get_visible()
             new_position = ax.spines[side].get_position()
             if is_visible:
-                nt.assert_equal(new_position, self.offset_position)
+                assert new_position == self.offset_position
             else:
-                nt.assert_equal(new_position, self.original_position)
+                assert new_position == self.original_position
 
     def test_despine_side_specific_offset(self):
 
@@ -183,9 +186,9 @@ def test_despine_side_specific_offset(self):
             is_visible = ax.spines[side].get_visible()
             new_position = ax.spines[side].get_position()
             if is_visible and side == "left":
-                nt.assert_equal(new_position, self.offset_position)
+                assert new_position == self.offset_position
             else:
-                nt.assert_equal(new_position, self.original_position)
+                assert new_position == self.original_position
 
     def test_despine_with_offset_specific_axes(self):
         f, (ax1, ax2) = plt.subplots(2, 1)
@@ -193,16 +196,16 @@ def test_despine_with_offset_specific_axes(self):
         utils.despine(offset=self.offset, ax=ax2)
 
         for side in self.sides:
-            nt.assert_equal(ax1.spines[side].get_position(),
-                            self.original_position)
+            pos1 = ax1.spines[side].get_position()
+            pos2 = ax2.spines[side].get_position()
+            assert pos1 == self.original_position
             if ax2.spines[side].get_visible():
-                nt.assert_equal(ax2.spines[side].get_position(),
-                                self.offset_position)
+                assert pos2 == self.offset_position
             else:
-                nt.assert_equal(ax2.spines[side].get_position(),
-                                self.original_position)
+                assert pos2 == self.original_position
 
     def test_despine_trim_spines(self):
+
         f, ax = plt.subplots()
         ax.plot([1, 2, 3], [1, 2, 3])
         ax.set_xlim(.75, 3.25)
@@ -210,7 +213,7 @@ def test_despine_trim_spines(self):
         utils.despine(trim=True)
         for side in self.inner_sides:
             bounds = ax.spines[side].get_bounds()
-            nt.assert_equal(bounds, (1, 3))
+            assert bounds == (1, 3)
 
     def test_despine_trim_inverted(self):
 
@@ -222,7 +225,7 @@ def test_despine_trim_inverted(self):
         utils.despine(trim=True)
         for side in self.inner_sides:
             bounds = ax.spines[side].get_bounds()
-            nt.assert_equal(bounds, (1, 3))
+            assert bounds == (1, 3)
 
     def test_despine_trim_noticks(self):
 
@@ -230,7 +233,20 @@ def test_despine_trim_noticks(self):
         ax.plot([1, 2, 3], [1, 2, 3])
         ax.set_yticks([])
         utils.despine(trim=True)
-        nt.assert_equal(ax.get_yticks().size, 0)
+        assert ax.get_yticks().size == 0
+
+    def test_despine_trim_categorical(self):
+
+        f, ax = plt.subplots()
+        ax.plot(["a", "b", "c"], [1, 2, 3])
+
+        utils.despine(trim=True)
+
+        bounds = ax.spines["left"].get_bounds()
+        assert bounds == (1, 3)
+
+        bounds = ax.spines["bottom"].get_bounds()
+        assert bounds == (0, 2)
 
     def test_despine_moved_ticks(self):
 
@@ -238,7 +254,7 @@ def test_despine_moved_ticks(self):
         for t in ax.yaxis.majorTicks:
             t.tick1line.set_visible(True)
         utils.despine(ax=ax, left=True, right=False)
-        for y in ax.yaxis.majorTicks:
+        for t in ax.yaxis.majorTicks:
            assert t.tick2line.get_visible()
         plt.close(f)
 
@@ -246,7 +262,7 @@ def test_despine_moved_ticks(self):
         f, ax = plt.subplots()
         for t in ax.yaxis.majorTicks:
             t.tick1line.set_visible(False)
         utils.despine(ax=ax, left=True, right=False)
-        for y in ax.yaxis.majorTicks:
+        for t in ax.yaxis.majorTicks:
            assert not t.tick2line.get_visible()
         plt.close(f)
 
@@ -254,7 +270,7 @@ def test_despine_moved_ticks(self):
         for t in ax.xaxis.majorTicks:
             t.tick1line.set_visible(True)
         utils.despine(ax=ax, bottom=True, top=False)
-        for y in ax.xaxis.majorTicks:
+        for t in ax.xaxis.majorTicks:
             assert t.tick2line.get_visible()
         plt.close(f)
 
@@ -262,7 +278,7 @@ def test_despine_moved_ticks(self):
         for t in ax.xaxis.majorTicks:
             t.tick1line.set_visible(False)
         utils.despine(ax=ax, bottom=True, top=False)
-        for y in ax.xaxis.majorTicks:
+        for t in ax.xaxis.majorTicks:
             assert not t.tick2line.get_visible()
         plt.close(f)
 
@@ -287,133 +303,176 @@ def test_ticklabels_overlap():
     assert not y
 
 
-def test_categorical_order():
+def test_locator_to_legend_entries():
 
-    x = ["a", "c", "c", "b", "a", "d"]
-    y = [3, 2, 5, 1, 4]
-    order = ["a", "b", "c", "d"]
+    locator = mpl.ticker.MaxNLocator(nbins=3)
+    limits = (0.09, 0.4)
+    levels, str_levels = utils.locator_to_legend_entries(
+        locator, limits, float
+    )
+    assert str_levels == ["0.15", "0.30"]
 
-    out = utils.categorical_order(x)
-    nt.assert_equal(out, ["a", "c", "b", "d"])
+    limits = (0.8, 0.9)
+    levels, str_levels = utils.locator_to_legend_entries(
+        locator, limits, float
+    )
+    assert str_levels == ["0.80", "0.84", "0.88"]
 
-    out = utils.categorical_order(x, order)
-    nt.assert_equal(out, order)
+    limits = (1, 6)
+    levels, str_levels = utils.locator_to_legend_entries(locator, limits, int)
+    assert str_levels == ["2", "4", "6"]
 
-    out = utils.categorical_order(x, ["b", "a"])
-    nt.assert_equal(out, ["b", "a"])
+    locator = mpl.ticker.LogLocator(numticks=5)
+    limits = (5, 1425)
+    levels, str_levels = utils.locator_to_legend_entries(locator, limits, int)
+    if LooseVersion(mpl.__version__) >= "3.1":
+        assert str_levels == ['10', '100', '1000']
 
-    out = utils.categorical_order(np.array(x))
-    nt.assert_equal(out, ["a", "c", "b", "d"])
+    limits = (0.00003, 0.02)
+    levels, str_levels = utils.locator_to_legend_entries(
+        locator, limits, float
+    )
+    if LooseVersion(mpl.__version__) >= "3.1":
+        assert str_levels == ['1e-04', '1e-03', '1e-02']
 
-    out = utils.categorical_order(pd.Series(x))
-    nt.assert_equal(out, ["a", "c", "b", "d"])
 
-    out = utils.categorical_order(y)
-    nt.assert_equal(out, [1, 2, 3, 4, 5])
+def check_load_dataset(name):
+    ds = load_dataset(name, cache=False)
+    assert isinstance(ds, pd.DataFrame)
 
-    out = utils.categorical_order(np.array(y))
-    nt.assert_equal(out, [1, 2, 3, 4, 5])
 
-    out = utils.categorical_order(pd.Series(y))
-    nt.assert_equal(out, [1, 2, 3, 4, 5])
+def check_load_cached_dataset(name):
+    # Test the caching using a temporary directory.
+ with tempfile.TemporaryDirectory() as tmpdir: + # download and cache + ds = load_dataset(name, cache=True, data_home=tmpdir) - if pandas_has_categoricals: - x = pd.Categorical(x, order) - out = utils.categorical_order(x) - nt.assert_equal(out, list(x.categories)) + # use cached version + ds2 = load_dataset(name, cache=True, data_home=tmpdir) + assert_frame_equal(ds, ds2) - x = pd.Series(x) - out = utils.categorical_order(x) - nt.assert_equal(out, list(x.cat.categories)) - out = utils.categorical_order(x, ["b", "a"]) - nt.assert_equal(out, ["b", "a"]) +@_network(url="https://github.com/mwaskom/seaborn-data") +def test_get_dataset_names(): + names = get_dataset_names() + assert names + assert "tips" in names - x = ["a", np.nan, "c", "c", "b", "a", "d"] - out = utils.categorical_order(x) - nt.assert_equal(out, ["a", "c", "b", "d"]) +@_network(url="https://github.com/mwaskom/seaborn-data") +def test_load_datasets(): -if LooseVersion(pd.__version__) >= "0.15": + # Heavy test to verify that we can load all available datasets + for name in get_dataset_names(): + # unfortunately @network somehow obscures this generator so it + # does not get in effect, so we need to call explicitly + # yield check_load_dataset, name + check_load_dataset(name) - def check_load_dataset(name): - ds = load_dataset(name, cache=False) - assert(isinstance(ds, pd.DataFrame)) - def check_load_cached_dataset(name): - # Test the cacheing using a temporary file. - # With Python 3.2+, we could use the tempfile.TemporaryDirectory() - # context manager instead of this try...finally statement - tmpdir = tempfile.mkdtemp() - try: - # download and cache - ds = load_dataset(name, cache=True, data_home=tmpdir) - - # use cached version - ds2 = load_dataset(name, cache=True, data_home=tmpdir) - pdt.assert_frame_equal(ds, ds2) - - finally: - shutil.rmtree(tmpdir) - - @_network(url="https://github.com/mwaskom/seaborn-data") - def test_get_dataset_names(): - if not BeautifulSoup: - raise nose.SkipTest("No BeautifulSoup available for parsing html") - names = get_dataset_names() - assert(len(names) > 0) - assert(u"titanic" in names) - - @_network(url="https://github.com/mwaskom/seaborn-data") - def test_load_datasets(): - if not BeautifulSoup: - raise nose.SkipTest("No BeautifulSoup available for parsing html") - - # Heavy test to verify that we can load all available datasets - for name in get_dataset_names(): - # unfortunately @network somehow obscures this generator so it - # does not get in effect, so we need to call explicitly - # yield check_load_dataset, name - check_load_dataset(name) - - @_network(url="https://github.com/mwaskom/seaborn-data") - def test_load_cached_datasets(): - if not BeautifulSoup: - raise nose.SkipTest("No BeautifulSoup available for parsing html") - - # Heavy test to verify that we can load all available datasets - for name in get_dataset_names(): - # unfortunately @network somehow obscures this generator so it - # does not get in effect, so we need to call explicitly - # yield check_load_dataset, name - check_load_cached_dataset(name) +@_network(url="https://github.com/mwaskom/seaborn-data") +def test_load_dataset_error(): + + name = "bad_name" + err = f"'{name}' is not one of the example datasets." 
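+    # An unknown name should raise an informative ValueError rather than
+    # failing with an obscure download error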
+ with pytest.raises(ValueError, match=err): + load_dataset(name) + + +@_network(url="https://github.com/mwaskom/seaborn-data") +def test_load_cached_datasets(): + + # Heavy test to verify that we can load all available datasets + for name in get_dataset_names(): + # unfortunately @network somehow obscures this generator so it + # does not get in effect, so we need to call explicitly + # yield check_load_dataset, name + check_load_cached_dataset(name) def test_relative_luminance(): """Test relative luminance.""" out1 = utils.relative_luminance("white") - assert_equal(out1, 1) + assert out1 == 1 out2 = utils.relative_luminance("#000000") - assert_equal(out2, 0) + assert out2 == 0 out3 = utils.relative_luminance((.25, .5, .75)) - nose.tools.assert_almost_equal(out3, 0.201624536) + assert out3 == pytest.approx(0.201624536) rgbs = mpl.cm.RdBu(np.linspace(0, 1, 10)) lums1 = [utils.relative_luminance(rgb) for rgb in rgbs] lums2 = utils.relative_luminance(rgbs) for lum1, lum2 in zip(lums1, lums2): - nose.tools.assert_almost_equal(lum1, lum2) + assert lum1 == pytest.approx(lum2) + + +@pytest.mark.parametrize( + "cycler,result", + [ + (cycler(color=["y"]), ["y"]), + (cycler(color=["k"]), ["k"]), + (cycler(color=["k", "y"]), ["k", "y"]), + (cycler(color=["y", "k"]), ["y", "k"]), + (cycler(color=["b", "r"]), ["b", "r"]), + (cycler(color=["r", "b"]), ["r", "b"]), + (cycler(lw=[1, 2]), [".15"]), # no color in cycle + ], +) +def test_get_color_cycle(cycler, result): + with mpl.rc_context(rc={"axes.prop_cycle": cycler}): + assert get_color_cycle() == result def test_remove_na(): a_array = np.array([1, 2, np.nan, 3]) - a_array_rm = utils.remove_na(a_array) - npt.assert_array_equal(a_array_rm, np.array([1, 2, 3])) + a_array_rm = remove_na(a_array) + assert_array_equal(a_array_rm, np.array([1, 2, 3])) a_series = pd.Series([1, 2, np.nan, 3]) - a_series_rm = utils.remove_na(a_series) - pdt.assert_series_equal(a_series_rm, pd.Series([1., 2, 3], [0, 1, 3])) + a_series_rm = remove_na(a_series) + assert_series_equal(a_series_rm, pd.Series([1., 2, 3], [0, 1, 3])) + + +def test_assign_default_kwargs(): + + def f(a, b, c, d): + pass + + def g(c=1, d=2): + pass + + kws = {"c": 3} + + kws = _assign_default_kwargs(kws, f, g) + assert kws == {"c": 3, "d": 2} + + +def test_draw_figure(): + + f, ax = plt.subplots() + ax.plot(["a", "b", "c"], [1, 2, 3]) + _draw_figure(f) + assert not f.stale + # ticklabels are not populated until a draw, but this may change + assert ax.get_xticklabels()[0].get_text() == "a" + + +def test_deprecate_ci(): + + msg = "The `ci` parameter is deprecated; use `errorbar=" + + with pytest.warns(UserWarning, match=msg + "None"): + out = _deprecate_ci(None, None) + assert out is None + + with pytest.warns(UserWarning, match=msg + "'sd'"): + out = _deprecate_ci(None, "sd") + assert out == "sd" + + with pytest.warns(UserWarning, match=msg + r"\('ci', 68\)"): + out = _deprecate_ci(None, 68) + assert out == ("ci", 68) diff --git a/seaborn/timeseries.py b/seaborn/timeseries.py deleted file mode 100644 index a3b25bf457..0000000000 --- a/seaborn/timeseries.py +++ /dev/null @@ -1,454 +0,0 @@ -"""Timeseries plotting functions.""" -from __future__ import division -import numpy as np -import pandas as pd -from scipy import stats, interpolate -import matplotlib as mpl -import matplotlib.pyplot as plt - -import warnings - -from .external.six import string_types - -from . import utils -from . 
import algorithms as algo -from .palettes import color_palette - - -__all__ = ["tsplot"] - - -def tsplot(data, time=None, unit=None, condition=None, value=None, - err_style="ci_band", ci=68, interpolate=True, color=None, - estimator=np.mean, n_boot=5000, err_palette=None, err_kws=None, - legend=True, ax=None, **kwargs): - """Plot one or more timeseries with flexible representation of uncertainty. - - This function is intended to be used with data where observations are - nested within sampling units that were measured at multiple timepoints. - - It can take data specified either as a long-form (tidy) DataFrame or as an - ndarray with dimensions (unit, time) The interpretation of some of the - other parameters changes depending on the type of object passed as data. - - Parameters - ---------- - data : DataFrame or ndarray - Data for the plot. Should either be a "long form" dataframe or an - array with dimensions (unit, time, condition). In both cases, the - condition field/dimension is optional. The type of this argument - determines the interpretation of the next few parameters. When - using a DataFrame, the index has to be sequential. - time : string or series-like - Either the name of the field corresponding to time in the data - DataFrame or x values for a plot when data is an array. If a Series, - the name will be used to label the x axis. - unit : string - Field in the data DataFrame identifying the sampling unit (e.g. - subject, neuron, etc.). The error representation will collapse over - units at each time/condition observation. This has no role when data - is an array. - value : string - Either the name of the field corresponding to the data values in - the data DataFrame (i.e. the y coordinate) or a string that forms - the y axis label when data is an array. - condition : string or Series-like - Either the name of the field identifying the condition an observation - falls under in the data DataFrame, or a sequence of names with a length - equal to the size of the third dimension of data. There will be a - separate trace plotted for each condition. If condition is a Series - with a name attribute, the name will form the title for the plot - legend (unless legend is set to False). - err_style : string or list of strings or None - Names of ways to plot uncertainty across units from set of - {ci_band, ci_bars, boot_traces, boot_kde, unit_traces, unit_points}. - Can use one or more than one method. - ci : float or list of floats in [0, 100] or "sd" or None - Confidence interval size(s). If a list, it will stack the error plots - for each confidence interval. If ``"sd"``, show standard deviation of - the observations instead of boostrapped confidence intervals. Only - relevant for error styles with "ci" in the name. - interpolate : boolean - Whether to do a linear interpolation between each timepoint when - plotting. The value of this parameter also determines the marker - used for the main plot traces, unless marker is specified as a keyword - argument. - color : seaborn palette or matplotlib color name or dictionary - Palette or color for the main plots and error representation (unless - plotting by unit, which can be separately controlled with err_palette). - If a dictionary, should map condition name to color spec. - estimator : callable - Function to determine central tendency and to pass to bootstrap - must take an ``axis`` argument. - n_boot : int - Number of bootstrap iterations. 
- err_palette : seaborn palette - Palette name or list of colors used when plotting data for each unit. - err_kws : dict, optional - Keyword argument dictionary passed through to matplotlib function - generating the error plot, - legend : bool, optional - If ``True`` and there is a ``condition`` variable, add a legend to - the plot. - ax : axis object, optional - Plot in given axis; if None creates a new figure - kwargs : - Other keyword arguments are passed to main plot() call - - Returns - ------- - ax : matplotlib axis - axis with plot data - - Examples - -------- - - Plot a trace with translucent confidence bands: - - .. plot:: - :context: close-figs - - >>> import numpy as np; np.random.seed(22) - >>> import seaborn as sns; sns.set(color_codes=True) - >>> x = np.linspace(0, 15, 31) - >>> data = np.sin(x) + np.random.rand(10, 31) + np.random.randn(10, 1) - >>> ax = sns.tsplot(data=data) - - Plot a long-form dataframe with several conditions: - - .. plot:: - :context: close-figs - - >>> gammas = sns.load_dataset("gammas") - >>> ax = sns.tsplot(time="timepoint", value="BOLD signal", - ... unit="subject", condition="ROI", - ... data=gammas) - - Use error bars at the positions of the observations: - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, err_style="ci_bars", color="g") - - Don't interpolate between the observations: - - .. plot:: - :context: close-figs - - >>> import matplotlib.pyplot as plt - >>> ax = sns.tsplot(data=data, err_style="ci_bars", interpolate=False) - - Show multiple confidence bands: - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, ci=[68, 95], color="m") - - Show the standard deviation of the observations: - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, ci="sd") - - Use a different estimator: - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, estimator=np.median) - - Show each bootstrap resample: - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, err_style="boot_traces", n_boot=500) - - Show the trace from each sampling unit: - - - .. plot:: - :context: close-figs - - >>> ax = sns.tsplot(data=data, err_style="unit_traces") - - """ - msg = ( - "The `tsplot` function is deprecated and will be removed in a future " - "release. Please update your code to use the new `lineplot` function." - ) - warnings.warn(msg, UserWarning) - - # Sort out default values for the parameters - if ax is None: - ax = plt.gca() - - if err_kws is None: - err_kws = {} - - # Handle different types of input data - if isinstance(data, pd.DataFrame): - - xlabel = time - ylabel = value - - # Condition is optional - if condition is None: - condition = pd.Series(1, index=data.index) - legend = False - legend_name = None - n_cond = 1 - else: - legend = True and legend - legend_name = condition - n_cond = len(data[condition].unique()) - - else: - data = np.asarray(data) - - # Data can be a timecourse from a single unit or - # several observations in one condition - if data.ndim == 1: - data = data[np.newaxis, :, np.newaxis] - elif data.ndim == 2: - data = data[:, :, np.newaxis] - n_unit, n_time, n_cond = data.shape - - # Units are experimental observations. 
Maybe subjects, or neurons - if unit is None: - units = np.arange(n_unit) - unit = "unit" - units = np.repeat(units, n_time * n_cond) - ylabel = None - - # Time forms the xaxis of the plot - if time is None: - times = np.arange(n_time) - else: - times = np.asarray(time) - xlabel = None - if hasattr(time, "name"): - xlabel = time.name - time = "time" - times = np.tile(np.repeat(times, n_cond), n_unit) - - # Conditions split the timeseries plots - if condition is None: - conds = range(n_cond) - legend = False - if isinstance(color, dict): - err = "Must have condition names if using color dict." - raise ValueError(err) - else: - conds = np.asarray(condition) - legend = True and legend - if hasattr(condition, "name"): - legend_name = condition.name - else: - legend_name = None - condition = "cond" - conds = np.tile(conds, n_unit * n_time) - - # Value forms the y value in the plot - if value is None: - ylabel = None - else: - ylabel = value - value = "value" - - # Convert to long-form DataFrame - data = pd.DataFrame(dict(value=data.ravel(), - time=times, - unit=units, - cond=conds)) - - # Set up the err_style and ci arguments for the loop below - if isinstance(err_style, string_types): - err_style = [err_style] - elif err_style is None: - err_style = [] - if not hasattr(ci, "__iter__"): - ci = [ci] - - # Set up the color palette - if color is None: - current_palette = utils.get_color_cycle() - if len(current_palette) < n_cond: - colors = color_palette("husl", n_cond) - else: - colors = color_palette(n_colors=n_cond) - elif isinstance(color, dict): - colors = [color[c] for c in data[condition].unique()] - else: - try: - colors = color_palette(color, n_cond) - except ValueError: - color = mpl.colors.colorConverter.to_rgb(color) - colors = [color] * n_cond - - # Do a groupby with condition and plot each trace - c = None - for c, (cond, df_c) in enumerate(data.groupby(condition, sort=False)): - - df_c = df_c.pivot(unit, time, value) - x = df_c.columns.values.astype(np.float) - - # Bootstrap the data for confidence intervals - if "sd" in ci: - est = estimator(df_c.values, axis=0) - sd = np.std(df_c.values, axis=0) - cis = [(est - sd, est + sd)] - boot_data = df_c.values - else: - boot_data = algo.bootstrap(df_c.values, n_boot=n_boot, - axis=0, func=estimator) - cis = [utils.ci(boot_data, v, axis=0) for v in ci] - central_data = estimator(df_c.values, axis=0) - - # Get the color for this condition - color = colors[c] - - # Use subroutines to plot the uncertainty - for style in err_style: - - # Allow for null style (only plot central tendency) - if style is None: - continue - - # Grab the function from the global environment - try: - plot_func = globals()["_plot_%s" % style] - except KeyError: - raise ValueError("%s is not a valid err_style" % style) - - # Possibly set up to plot each observation in a different color - if err_palette is not None and "unit" in style: - orig_color = color - color = color_palette(err_palette, len(df_c.values)) - - # Pass all parameters to the error plotter as keyword args - plot_kwargs = dict(ax=ax, x=x, data=df_c.values, - boot_data=boot_data, - central_data=central_data, - color=color, err_kws=err_kws) - - # Plot the error representation, possibly for multiple cis - for ci_i in cis: - plot_kwargs["ci"] = ci_i - plot_func(**plot_kwargs) - - if err_palette is not None and "unit" in style: - color = orig_color - - # Plot the central trace - kwargs.setdefault("marker", "" if interpolate else "o") - ls = kwargs.pop("ls", "-" if interpolate else "") - 
kwargs.setdefault("linestyle", ls) - label = cond if legend else "_nolegend_" - ax.plot(x, central_data, color=color, label=label, **kwargs) - - if c is None: - raise RuntimeError("Invalid input data for tsplot.") - - # Pad the sides of the plot only when not interpolating - ax.set_xlim(x.min(), x.max()) - x_diff = x[1] - x[0] - if not interpolate: - ax.set_xlim(x.min() - x_diff, x.max() + x_diff) - - # Add the plot labels - if xlabel is not None: - ax.set_xlabel(xlabel) - if ylabel is not None: - ax.set_ylabel(ylabel) - if legend: - ax.legend(loc=0, title=legend_name) - - return ax - -# Subroutines for tsplot errorbar plotting -# ---------------------------------------- - - -def _plot_ci_band(ax, x, ci, color, err_kws, **kwargs): - """Plot translucent error bands around the central tendancy.""" - low, high = ci - if "alpha" not in err_kws: - err_kws["alpha"] = 0.2 - ax.fill_between(x, low, high, facecolor=color, **err_kws) - - -def _plot_ci_bars(ax, x, central_data, ci, color, err_kws, **kwargs): - """Plot error bars at each data point.""" - for x_i, y_i, (low, high) in zip(x, central_data, ci.T): - ax.plot([x_i, x_i], [low, high], color=color, - solid_capstyle="round", **err_kws) - - -def _plot_boot_traces(ax, x, boot_data, color, err_kws, **kwargs): - """Plot 250 traces from bootstrap.""" - err_kws.setdefault("alpha", 0.25) - err_kws.setdefault("linewidth", 0.25) - if "lw" in err_kws: - err_kws["linewidth"] = err_kws.pop("lw") - ax.plot(x, boot_data.T, color=color, label="_nolegend_", **err_kws) - - -def _plot_unit_traces(ax, x, data, ci, color, err_kws, **kwargs): - """Plot a trace for each observation in the original data.""" - if isinstance(color, list): - if "alpha" not in err_kws: - err_kws["alpha"] = .5 - for i, obs in enumerate(data): - ax.plot(x, obs, color=color[i], label="_nolegend_", **err_kws) - else: - if "alpha" not in err_kws: - err_kws["alpha"] = .2 - ax.plot(x, data.T, color=color, label="_nolegend_", **err_kws) - - -def _plot_unit_points(ax, x, data, color, err_kws, **kwargs): - """Plot each original data point discretely.""" - if isinstance(color, list): - for i, obs in enumerate(data): - ax.plot(x, obs, "o", color=color[i], alpha=0.8, markersize=4, - label="_nolegend_", **err_kws) - else: - ax.plot(x, data.T, "o", color=color, alpha=0.5, markersize=4, - label="_nolegend_", **err_kws) - - -def _plot_boot_kde(ax, x, boot_data, color, **kwargs): - """Plot the kernal density estimate of the bootstrap distribution.""" - kwargs.pop("data") - _ts_kde(ax, x, boot_data, color, **kwargs) - - -def _plot_unit_kde(ax, x, data, color, **kwargs): - """Plot the kernal density estimate over the sample.""" - _ts_kde(ax, x, data, color, **kwargs) - - -def _ts_kde(ax, x, data, color, **kwargs): - """Upsample over time and plot a KDE of the bootstrap distribution.""" - kde_data = [] - y_min, y_max = data.min(), data.max() - y_vals = np.linspace(y_min, y_max, 100) - upsampler = interpolate.interp1d(x, data) - data_upsample = upsampler(np.linspace(x.min(), x.max(), 100)) - for pt_data in data_upsample.T: - pt_kde = stats.kde.gaussian_kde(pt_data) - kde_data.append(pt_kde(y_vals)) - kde_data = np.transpose(kde_data) - rgb = mpl.colors.ColorConverter().to_rgb(color) - img = np.zeros((kde_data.shape[0], kde_data.shape[1], 4)) - img[:, :, :3] = rgb - kde_data /= kde_data.max(axis=0) - kde_data[kde_data > 1] = 1 - img[:, :, 3] = kde_data - ax.imshow(img, interpolation="spline16", zorder=2, - extent=(x.min(), x.max(), y_min, y_max), - aspect="auto", origin="lower") diff --git a/seaborn/utils.py 
b/seaborn/utils.py index 8ddf82fc25..4547e61d4b 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -1,46 +1,22 @@ -"""Small plotting-related utility functions.""" -from __future__ import print_function, division -import colorsys +"""Utility functions, mostly for internal use.""" import os +import re +import inspect +import warnings +import colorsys +from urllib.request import urlopen, urlretrieve +from distutils.version import LooseVersion import numpy as np -from scipy import stats import pandas as pd import matplotlib as mpl -import matplotlib.colors as mplcol +from matplotlib.colors import to_rgb import matplotlib.pyplot as plt - -from .external.six.moves.urllib.request import urlopen, urlretrieve -from .external.six.moves.http_client import HTTPException +from matplotlib.cbook import normalize_kwargs __all__ = ["desaturate", "saturate", "set_hls_values", - "despine", "get_dataset_names", "load_dataset"] - - -def remove_na(arr): - """Helper method for removing NA values from array-like. - - Parameters - ---------- - arr : array-like - The array-like from which to remove NA values. - - Returns - ------- - clean_arr : array-like - The original array with NA values removed. - - """ - return arr[pd.notnull(arr)] - - -def sort_df(df, *args, **kwargs): - """Wrapper to handle different pandas sorting API pre/post 0.17.""" - try: - return df.sort_values(*args, **kwargs) - except AttributeError: - return df.sort(*args, **kwargs) + "despine", "get_dataset_names", "get_data_home", "load_dataset"] def ci_to_errsize(cis, heights): @@ -48,7 +24,7 @@ def ci_to_errsize(cis, heights): Parameters ---------- - cis: 2 x n sequence + cis : 2 x n sequence sequence of confidence interval limits heights : n sequence sequence of plot heights @@ -73,30 +49,117 @@ def ci_to_errsize(cis, heights): return errsize -def pmf_hist(a, bins=10): - """Return arguments to plt.bar for pmf-like histogram of an array. - - Parameters - ---------- - a: array-like - array to make histogram of - bins: int - number of bins +def _normal_quantile_func(q): + """ + Compute the quantile function of the standard normal distribution. - Returns - ------- - x: array - left x position of bars - h: array - height of bars - w: float - width of bars + This wrapper exists because we are dropping scipy as a mandatory dependency + but statistics.NormalDist was added to the standard library in 3.8. """ - n, x = np.histogram(a, bins) - h = n / n.sum() - w = x[1] - x[0] - return x[:-1], h, w + try: + from statistics import NormalDist + qf = np.vectorize(NormalDist().inv_cdf) + except ImportError: + try: + from scipy.stats import norm + qf = norm.ppf + except ImportError: + msg = ( + "Standard normal quantile functions require either Python>=3.8 or scipy" + ) + raise RuntimeError(msg) + return qf(q) + + +def _draw_figure(fig): + """Force draw of a matplotlib figure, accounting for back-compat.""" + # See https://github.com/matplotlib/matplotlib/issues/19197 for context + fig.canvas.draw() + if fig.stale: + try: + fig.draw(fig.canvas.get_renderer()) + except AttributeError: + pass + + +def _default_color(method, hue, color, kws): + """If needed, get a default color by using the matplotlib property cycle.""" + if hue is not None: + # This warning is probably user-friendly, but it's currently triggered + # in a FacetGrid context and I don't want to mess with that logic right now + # if color is not None: + # msg = "`color` is ignored when `hue` is assigned." 
+ # warnings.warn(msg) + return None + + if color is not None: + return color + + elif method.__name__ == "plot": + + scout, = method([], [], **kws) + color = scout.get_color() + scout.remove() + + elif method.__name__ == "scatter": + + # Matplotlib will raise if the size of x/y don't match s/c, + # and the latter might be in the kws dict + scout_size = max( + np.atleast_1d(kws.get(key, [])).shape[0] + for key in ["s", "c", "fc", "facecolor", "facecolors"] + ) + scout_x = scout_y = np.full(scout_size, np.nan) + + scout = method(scout_x, scout_y, **kws) + facecolors = scout.get_facecolors() + + if not len(facecolors): + # Handle bug in matplotlib <= 3.2 (I think) + # This will limit the ability to use non color= kwargs to specify + # a color in versions of matplotlib with the bug, but trying to + # work out what the user wanted by re-implementing the broken logic + # of inspecting the kwargs is probably too brittle. + single_color = False + else: + single_color = np.unique(facecolors, axis=0).shape[0] == 1 + + # Allow the user to specify an array of colors through various kwargs + if "c" not in kws and single_color: + color = to_rgb(facecolors[0]) + + scout.remove() + + elif method.__name__ == "bar": + + # bar() needs masked, not empty data, to generate a patch + scout, = method([np.nan], [np.nan], **kws) + color = to_rgb(scout.get_facecolor()) + scout.remove() + + elif method.__name__ == "fill_between": + + # There is a bug on matplotlib < 3.3 where fill_between with + # datetime units and empty data will set incorrect autoscale limits + # To workaround it, we'll always return the first color in the cycle. + # https://github.com/matplotlib/matplotlib/issues/17586 + ax = method.__self__ + datetime_axis = any([ + isinstance(ax.xaxis.converter, mpl.dates.DateConverter), + isinstance(ax.yaxis.converter, mpl.dates.DateConverter), + ]) + if LooseVersion(mpl.__version__) < "3.3" and datetime_axis: + return "C0" + + kws = _normalize_kwargs(kws, mpl.collections.PolyCollection) + + scout = method([], [], **kws) + facecolor = scout.get_facecolor() + color = to_rgb(facecolor[0]) + scout.remove() + + return color def desaturate(color, prop): @@ -120,7 +183,7 @@ def desaturate(color, prop): raise ValueError("prop must be between 0 and 1") # Get rgb tuple rep - rgb = mplcol.colorConverter.to_rgb(color) + rgb = to_rgb(color) # Convert to hls h, l, s = colorsys.rgb_to_hls(*rgb) @@ -139,7 +202,7 @@ def saturate(color): Parameters ---------- - color : matplotlib color + color : matplotlib color hex, rgb-tuple, or html color name Returns @@ -167,8 +230,8 @@ def set_hls_values(color, h=None, l=None, s=None): # noqa new color code in RGB tuple representation """ - # Get rgb tuple representation - rgb = mplcol.colorConverter.to_rgb(color) + # Get an RGB tuple representation + rgb = to_rgb(color) vals = list(colorsys.rgb_to_hls(*rgb)) for i, val in enumerate([h, l, s]): if val is not None: @@ -179,20 +242,60 @@ def set_hls_values(color, h=None, l=None, s=None): # noqa def axlabel(xlabel, ylabel, **kwargs): - """Grab current axis and label it.""" + """Grab current axis and label it. + + DEPRECATED: will be removed in a future version. + + """ + msg = "This function is deprecated and will be removed in a future version" + warnings.warn(msg, FutureWarning) ax = plt.gca() ax.set_xlabel(xlabel, **kwargs) ax.set_ylabel(ylabel, **kwargs) +def remove_na(vector): + """Helper method for removing null values from data vectors. 
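+
+    Boolean masking preserves the input type, so a Series keeps its
+    (filtered) index and an ndarray comes back as an ndarray.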
+
+    Parameters
+    ----------
+    vector : vector object
+        Must implement boolean masking with [] subscript syntax.
+
+    Returns
+    -------
+    clean_vector : same type as ``vector``
+        Vector of data with null values removed. May be a copy or a view.
+
+    """
+    return vector[pd.notnull(vector)]
+
+
+def get_color_cycle():
+    """Return the list of colors in the current matplotlib color cycle.
+
+    Parameters
+    ----------
+    None
+
+    Returns
+    -------
+    colors : list
+        List of matplotlib colors in the current cycle, or dark gray if
+        the current color cycle is empty.
+    """
+    cycler = mpl.rcParams['axes.prop_cycle']
+    return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"]
+
+
 def despine(fig=None, ax=None, top=True, right=True, left=False,
             bottom=False, offset=None, trim=False):
     """Remove the top and right spines from plot(s).
 
     fig : matplotlib figure, optional
-        Figure to despine all axes of, default uses current figure.
+        Figure to despine all axes of, defaults to the current figure.
     ax : matplotlib axes, optional
-        Specific axes object to despine.
+        Specific axes object to despine. Ignored if fig is provided.
     top, right, left, bottom : boolean, optional
         If True, remove that spine.
     offset : int or dict, optional
@@ -227,7 +330,7 @@ def despine(fig=None, ax=None, top=True, right=True, left=False,
                     val = offset.get(side, 0)
                 except AttributeError:
                     val = offset
-                _set_spine_position(ax_i.spines[side], ('outward', val))
+                ax_i.spines[side].set_position(('outward', val))
 
         # Potentially move the ticks
         if left and not right:
@@ -262,7 +365,7 @@ def despine(fig=None, ax=None, top=True, right=True, left=False,
 
     if trim:
         # clip off the parts of the spines that extend past major ticks
-        xticks = ax_i.get_xticks()
+        xticks = np.asarray(ax_i.get_xticks())
         if xticks.size:
             firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
                                     xticks)[0]
@@ -274,7 +377,7 @@ def despine(fig=None, ax=None, top=True, right=True, left=False,
             newticks = newticks.compress(newticks >= firsttick)
             ax_i.set_xticks(newticks)
 
-        yticks = ax_i.get_yticks()
+        yticks = np.asarray(ax_i.get_yticks())
         if yticks.size:
             firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
                                     yticks)[0]
@@ -287,117 +390,44 @@ def despine(fig=None, ax=None, top=True, right=True, left=False,
             ax_i.set_yticks(newticks)
 
 
-def _set_spine_position(spine, position):
-    """
-    Set the spine's position without resetting an associated axis.
-
-    As of matplotlib v. 1.0.0, if a spine has an associated axis, then
-    spine.set_position() calls axis.cla(), which resets locators, formatters,
-    etc. We temporarily replace that call with axis.reset_ticks(), which is
-    sufficient for our purposes.
-    """
-    axis = spine.axis
-    if axis is not None:
-        cla = axis.cla
-        axis.cla = axis.reset_ticks
-        spine.set_position(position)
-    if axis is not None:
-        axis.cla = cla
-
-
 def _kde_support(data, bw, gridsize, cut, clip):
     """Establish support for a kernel density estimate."""
     support_min = max(data.min() - bw * cut, clip[0])
     support_max = min(data.max() + bw * cut, clip[1])
-    return np.linspace(support_min, support_max, gridsize)
-
-
-def percentiles(a, pcts, axis=None):
-    """Like scoreatpercentile but can take and return array of percentiles.
- - Parameters - ---------- - a : array - data - pcts : sequence of percentile values - percentile or percentiles to find score at - axis : int or None - if not None, computes scores over this axis - - Returns - ------- - scores: array - array of scores at requested percentiles - first dimension is length of object passed to ``pcts`` + support = np.linspace(support_min, support_max, gridsize) - """ - scores = [] - try: - n = len(pcts) - except TypeError: - pcts = [pcts] - n = 0 - for i, p in enumerate(pcts): - if axis is None: - score = stats.scoreatpercentile(a.ravel(), p) - else: - score = np.apply_along_axis(stats.scoreatpercentile, axis, a, p) - scores.append(score) - scores = np.asarray(scores) - if not n: - scores = scores.squeeze() - return scores + return support def ci(a, which=95, axis=None): """Return a percentile range from an array of values.""" p = 50 - which / 2, 50 + which / 2 - return percentiles(a, p, axis) + return np.nanpercentile(a, p, axis) -def sig_stars(p): - """Return a R-style significance string corresponding to p values.""" - if p < 0.001: - return "***" - elif p < 0.01: - return "**" - elif p < 0.05: - return "*" - elif p < 0.1: - return "." - return "" - - -def iqr(a): - """Calculate the IQR for an array of numbers.""" - a = np.asarray(a) - q1 = stats.scoreatpercentile(a, 25) - q3 = stats.scoreatpercentile(a, 75) - return q3 - q1 +def get_dataset_names(): + """Report available example datasets, useful for reporting issues. + Requires an internet connection. -def get_dataset_names(): - """Report available example datasets, useful for reporting issues.""" - # delayed import to not demand bs4 unless this function is actually used - from bs4 import BeautifulSoup - http = urlopen('https://github.com/mwaskom/seaborn-data/') - gh_list = BeautifulSoup(http) + """ + url = "https://github.com/mwaskom/seaborn-data" + with urlopen(url) as resp: + html = resp.read() - return [l.text.replace('.csv', '') - for l in gh_list.find_all("a", {"class": "js-navigation-open"}) - if l.text.endswith('.csv')] + pat = r"/mwaskom/seaborn-data/blob/master/(\w*).csv" + datasets = re.findall(pat, html.decode()) + return datasets def get_data_home(data_home=None): - """Return the path of the seaborn data directory. + """Return a path to the cache directory for example datasets. - This is used by the ``load_dataset`` function. + This directory is then used by :func:`load_dataset`. - If the ``data_home`` argument is not specified, the default location - is ``~/seaborn-data``. + If the ``data_home`` argument is not specified, it tries to read from the + ``SEABORN_DATA`` environment variable and defaults to ``~/seaborn-data``. - Alternatively, a different default location can be specified using the - environment variable ``SEABORN_DATA``. """ if data_home is None: data_home = os.environ.get('SEABORN_DATA', @@ -409,20 +439,35 @@ def get_data_home(data_home=None): def load_dataset(name, cache=True, data_home=None, **kws): - """Load a dataset from the online repository (requires internet). + """Load an example dataset from the online repository (requires internet). + + This function provides quick access to a small number of example datasets + that are useful for documenting seaborn or generating reproducible examples + for bug reports. It is not necessary for normal usage. + + Note that some of the datasets have a small amount of preprocessing applied + to define a proper ordering for categorical variables. + + Use :func:`get_dataset_names` to see a list of available datasets. 
     Parameters
     ----------
     name : str
-        Name of the dataset (`name`.csv on
-        https://github.com/mwaskom/seaborn-data).  You can obtain list of
-        available datasets using :func:`get_dataset_names`
+        Name of the dataset (``{name}.csv`` on
+        https://github.com/mwaskom/seaborn-data).
     cache : boolean, optional
-        If True, then cache data locally and use the cache on subsequent calls
+        If True, try to load from the local cache first, and save to the cache
+        if a download is required.
     data_home : string, optional
-        The directory in which to cache data. By default, uses ~/seaborn-data/
-    kws : dict, optional
-        Passed to pandas.read_csv
+        The directory in which to cache data; see :func:`get_data_home`.
+    kws : keys and values, optional
+        Additional keyword arguments are passed through to
+        :func:`pandas.read_csv`.
+
+    Returns
+    -------
+    df : :class:`pandas.DataFrame`
+        Tabular data, possibly with some preprocessing applied.
 
     """
     path = ("https://raw.githubusercontent.com/"
@@ -433,10 +478,13 @@ def load_dataset(name, cache=True, data_home=None, **kws):
     cache_path = os.path.join(get_data_home(data_home),
                               os.path.basename(full_path))
     if not os.path.exists(cache_path):
+        if name not in get_dataset_names():
+            raise ValueError(f"'{name}' is not one of the example datasets.")
         urlretrieve(full_path, cache_path)
     full_path = cache_path
 
     df = pd.read_csv(full_path, **kws)
+
     if df.iloc[-1].isnull().all():
         df = df.iloc[:-1]
 
@@ -449,7 +497,8 @@ def load_dataset(name, cache=True, data_home=None, **kws):
         df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])
 
     if name == "flights":
-        df["month"] = pd.Categorical(df["month"], df.month.unique())
+        months = df["month"].str[:3]
+        df["month"] = pd.Categorical(months, months.unique())
 
     if name == "exercise":
         df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
@@ -460,6 +509,20 @@
         df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
         df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))
 
+    if name == "penguins":
+        df["sex"] = df["sex"].str.title()
+
+    if name == "diamonds":
+        df["color"] = pd.Categorical(
+            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
+        )
+        df["clarity"] = pd.Categorical(
+            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
+        )
+        df["cut"] = pd.Categorical(
+            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
+        )
+
     return df
 
 
@@ -468,7 +531,7 @@ def axis_ticklabels_overlap(labels):
 
     Parameters
     ----------
-    labels : list of ticklabels
+    labels : list of matplotlib ticklabels
 
     Returns
    -------
@@ -483,7 +546,7 @@
         overlaps = [b.count_overlaps(bboxes) for b in bboxes]
         return max(overlaps) > 1
     except RuntimeError:
-        # Issue on macosx backend rasies an error in the above code
+        # Issue on macos backend raises an error in the above code
         return False
 
 
@@ -504,59 +567,30 @@ def axes_ticklabels_overlap(ax):
             axis_ticklabels_overlap(ax.get_yticklabels()))
 
 
-def categorical_order(values, order=None):
-    """Return a list of unique data values.
+def locator_to_legend_entries(locator, limits, dtype):
+    """Return levels and formatted levels for brief numeric legends."""
+    raw_levels = locator.tick_values(*limits).astype(dtype)
 
-    Determine an ordered list of levels in ``values``.
@@ -504,59 +567,30 @@ def axes_ticklabels_overlap(ax):
             axis_ticklabels_overlap(ax.get_yticklabels()))


-def categorical_order(values, order=None):
-    """Return a list of unique data values.
+def locator_to_legend_entries(locator, limits, dtype):
+    """Return levels and formatted levels for brief numeric legends."""
+    raw_levels = locator.tick_values(*limits).astype(dtype)
-
-    Determine an ordered list of levels in ``values``.
-
-    Parameters
-    ----------
-    values : list, array, Categorical, or Series
-        Vector of "categorical" values
-    order : list-like, optional
-        Desired order of category levels to override the order determined
-        from the ``values`` object.
+    # The locator can return ticks outside the limits, clip them here
+    raw_levels = [l for l in raw_levels if l >= limits[0] and l <= limits[1]]

-    Returns
-    -------
-    order : list
-        Ordered list of category levels not including null values.
+    class dummy_axis:
+        def get_view_interval(self):
+            return limits

-    """
-    if order is None:
-        if hasattr(values, "categories"):
-            order = values.categories
-        else:
-            try:
-                order = values.cat.categories
-            except (TypeError, AttributeError):
-                try:
-                    order = values.unique()
-                except AttributeError:
-                    order = pd.unique(values)
-                try:
-                    np.asarray(values).astype(np.float)
-                    order = np.sort(order)
-                except (ValueError, TypeError):
-                    order = order
-        order = filter(pd.notnull, order)
-    return list(order)
+    if isinstance(locator, mpl.ticker.LogLocator):
+        formatter = mpl.ticker.LogFormatter()
+    else:
+        formatter = mpl.ticker.ScalarFormatter()
+    formatter.axis = dummy_axis()

+    # TODO: The following two lines should be replaced
+    # once pinned matplotlib>=3.1.0 with:
+    # formatted_levels = formatter.format_ticks(raw_levels)
+    formatter.set_locs(raw_levels)
+    formatted_levels = [formatter(x) for x in raw_levels]

-def get_color_cycle():
-    """Return the list of colors in the current matplotlib color cycle."""
-    try:
-        cyl = mpl.rcParams['axes.prop_cycle']
-        try:
-            # matplotlib 1.5 verifies that axes.prop_cycle *is* a cycler
-            # but no garuantee that there's a `color` key.
-            # so users could have a custom rcParmas w/ no color...
-            return [x['color'] for x in cyl]
-        except KeyError:
-            pass
-    except KeyError:
-        pass
-    return mpl.rcParams['axes.color_cycle']
+    return raw_levels, formatted_levels


 def relative_luminance(color):
@@ -582,16 +616,14 @@

 def to_utf8(obj):
-    """Return a Unicode string representing a Python object.
+    """Return a string representing a Python object.

-    Unicode strings (i.e. type ``unicode`` in Python 2.7 and type ``str`` in
-    Python 3.x) are returned unchanged.
+    Strings (i.e. type ``str``) are returned unchanged.

-    Byte strings (i.e. type ``str`` in Python 2.7 and type ``bytes`` in
-    Python 3.x) are returned as UTF-8-encoded strings.
+    Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.

     For other objects, the method ``__str__()`` is called, and the result is
-    returned as a UTF-8-encoded string.
+    returned as a string.

     Parameters
     ----------
@@ -600,58 +632,100 @@

     Returns
     -------
-    s : unicode (Python 2.7) / str (Python 3.x)
-        UTF-8-encoded string representation of ``obj``
+    s : str
+        UTF-8-decoded string representation of ``obj``
+
     """
     if isinstance(obj, str):
-        try:
-            # If obj is a string, try to return it as a Unicode-encoded
-            # string:
-            return obj.decode("utf-8")
-        except AttributeError:
-            # Python 3.x strings are already Unicode, and do not have a
-            # decode() method, so the unchanged string is returned
-            return obj
-
+        return obj
     try:
-        if isinstance(obj, unicode):
-            # do not attemt a conversion if string is already a Unicode
-            # string:
-            return obj
-        else:
-            # call __str__() for non-string object, and return the
-            # result to Unicode:
-            return obj.__str__().decode("utf-8")
-    except NameError:
-        # NameError is raised in Python 3.x as type 'unicode' is not
-        # defined.
-        if isinstance(obj, bytes):
-            return obj.decode("utf-8")
-        else:
-            return obj.__str__()
-
-
-def _network(t=None, url='https://google.com'):
-    """
-    Decorator that will skip a test if `url` is unreachable.
-
-    Parameters
-    ----------
-    t : function, optional
-    url : str, optional
+        return obj.decode(encoding="utf-8")
+    except AttributeError:  # obj is not bytes-like
+        return str(obj)
+
+
+def _normalize_kwargs(kws, artist):
+    """Wrapper for mpl.cbook.normalize_kwargs that supports <= 3.2.1."""
+    _alias_map = {
+        'color': ['c'],
+        'linewidth': ['lw'],
+        'linestyle': ['ls'],
+        'facecolor': ['fc'],
+        'edgecolor': ['ec'],
+        'markerfacecolor': ['mfc'],
+        'markeredgecolor': ['mec'],
+        'markeredgewidth': ['mew'],
+        'markersize': ['ms']
+    }
+    try:
+        kws = normalize_kwargs(kws, artist)
+    except AttributeError:
+        kws = normalize_kwargs(kws, _alias_map)
+    return kws
+
+
+def _check_argument(param, options, value):
+    """Raise if value for param is not in options."""
+    if value not in options:
+        raise ValueError(
+            f"`{param}` must be one of {options}, but {repr(value)} was passed."
+        )
+
+
+def _assign_default_kwargs(kws, call_func, source_func):
+    """Assign default kwargs for call_func using values from source_func."""
+    # This exists so that axes-level functions and figure-level functions can
+    # both call a Plotter method while having the default kwargs be defined in
+    # the signature of the axes-level function.
+    # An alternative would be to have a decorator on the method that sets its
+    # defaults based on those defined in the axes-level function.
+    # Then the figure-level function would not need to worry about defaults.
+    # I am not sure which is better.
+    needed = inspect.signature(call_func).parameters
+    defaults = inspect.signature(source_func).parameters
+
+    for param in needed:
+        if param in defaults and param not in kws:
+            kws[param] = defaults[param].default
+
+    return kws
+
+
+def adjust_legend_subtitles(legend):
+    """Make invisible-handle "subtitles" entries look more like titles."""
+    # Legend title not in rcParams until 3.0
+    font_size = plt.rcParams.get("legend.title_fontsize", None)
+    hpackers = legend.findobj(mpl.offsetbox.VPacker)[0].get_children()
+    for hpack in hpackers:
+        draw_area, text_area = hpack.get_children()
+        handles = draw_area.get_children()
+        if not all(artist.get_visible() for artist in handles):
+            draw_area.set_width(0)
+            for text in text_area.get_children():
+                if font_size is not None:
+                    text.set_size(font_size)
+
+
+def _deprecate_ci(errorbar, ci):
     """
-    import nose
+    Warn on usage of ci= and convert to appropriate errorbar= arg.

-    if t is None:
-        return lambda x: _network(x, url=url)
+    ci was deprecated when errorbar was added in 0.12. It should not be removed
+    completely for some time, but it can be moved out of function definitions
+    (and extracted from kwargs) after one cycle.

-    def wrapper(*args, **kwargs):
-        # attempt to connect
-        try:
-            f = urlopen(url)
-        except (IOError, HTTPException):
-            raise nose.SkipTest()
+    """
+    if ci != "deprecated":
+        if ci is None:
+            errorbar = None
+        elif ci == "sd":
+            errorbar = "sd"
         else:
-            f.close()
-            return t(*args, **kwargs)
-    return wrapper
+            errorbar = ("ci", ci)
+        msg = (
+            "The `ci` parameter is deprecated; "
+            f"use `errorbar={repr(errorbar)}` for the same effect."
+        )
+        warnings.warn(msg, UserWarning)
+
+    return errorbar
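For reviewers, the conversion rules implemented by `_deprecate_ci` can be summarized with a few calls. A sketch, assuming the function lands in `seaborn.utils` as shown above; the first argument stands in for a caller's default `errorbar` value:

    from seaborn.utils import _deprecate_ci

    _deprecate_ci(("ci", 95), "deprecated")  # -> ("ci", 95), no warning (new-style call)
    _deprecate_ci(("ci", 95), 68)            # -> ("ci", 68), with a UserWarning
    _deprecate_ci(("ci", 95), "sd")          # -> "sd", with a UserWarning
    _deprecate_ci(("ci", 95), None)          # -> None, with a UserWarning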
diff --git a/seaborn/widgets.py b/seaborn/widgets.py
index 6976f61bf4..c75cc66c48 100644
--- a/seaborn/widgets.py
+++ b/seaborn/widgets.py
@@ -1,4 +1,3 @@
-from __future__ import division
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
@@ -208,7 +207,7 @@ def choose_dark_palette_rgb(r=(0., 1.),
     elif input == "hls":
         @interact
         def choose_dark_palette_hls(h=(0., 1.),
-                                    l=(0., 1.),
+                                    l=(0., 1.),  # noqa: E741
                                     s=(0., 1.),
                                     n=(3, 17)):
             color = h, l, s
@@ -224,7 +223,7 @@ def choose_dark_palette_hls(h=(0., 1.),
         @interact
         def choose_dark_palette_husl(h=(0, 359),
                                      s=(0, 99),
-                                     l=(0, 99),
+                                     l=(0, 99),  # noqa: E741
                                      n=(3, 17)):
             color = h, s, l
             if as_cmap:
@@ -293,7 +292,7 @@ def choose_light_palette_rgb(r=(0., 1.),
     elif input == "hls":
         @interact
         def choose_light_palette_hls(h=(0., 1.),
-                                     l=(0., 1.),
+                                     l=(0., 1.),  # noqa: E741
                                      s=(0., 1.),
                                      n=(3, 17)):
             color = h, l, s
@@ -309,7 +308,7 @@ def choose_light_palette_hls(h=(0., 1.),
         @interact
         def choose_light_palette_husl(h=(0, 359),
                                       s=(0, 99),
-                                      l=(0, 99),
+                                      l=(0, 99),  # noqa: E741
                                       n=(3, 17)):
             color = h, s, l
             if as_cmap:
@@ -358,17 +357,19 @@ def choose_diverging_palette(as_cmap=False):
     cmap = _init_mutable_colormap()

     @interact
-    def choose_diverging_palette(h_neg=IntSlider(min=0,
-                                                 max=359,
-                                                 value=220),
-                                 h_pos=IntSlider(min=0,
-                                                 max=359,
-                                                 value=10),
-                                 s=IntSlider(min=0, max=99, value=74),
-                                 l=IntSlider(min=0, max=99, value=50),
-                                 sep=IntSlider(min=1, max=50, value=10),
-                                 n=(2, 16),
-                                 center=["light", "dark"]):
+    def choose_diverging_palette(
+        h_neg=IntSlider(min=0,
+                        max=359,
+                        value=220),
+        h_pos=IntSlider(min=0,
+                        max=359,
+                        value=10),
+        s=IntSlider(min=0, max=99, value=74),
+        l=IntSlider(min=0, max=99, value=50),  # noqa: E741
+        sep=IntSlider(min=1, max=50, value=10),
+        n=(2, 16),
+        center=["light", "dark"]
+    ):
         if as_cmap:
             colors = diverging_palette(h_neg, h_pos, s, l, sep, 256, center)
             _update_lut(cmap, colors)
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000..5fe3a51f96
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,7 @@
+[metadata]
+license_file = LICENSE
+
+[flake8]
+max-line-length = 88
+exclude = seaborn/cm.py,seaborn/external
+ignore = E741,F522,W503
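A note on the repeated `# noqa: E741` markers above and the `ignore = E741` entry in the new setup.cfg: flake8's E741 check flags single-letter names that are easily confused with digits (`l`, `O`, `I`), but `l` is the conventional name for the hls/husl lightness parameter, so the code keeps the name and suppresses the check. The inline markers are strictly redundant given the global ignore, but they make the intent explicit at each site. A hypothetical minimal example of what the check would otherwise reject:

    def choose_lightness(l=(0., 1.)):  # noqa: E741 -- 'l' is the conventional hls name
        return l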
diff --git a/setup.py b/setup.py
index e2b7f601b1..abba8cdfdc 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 #! /usr/bin/env python
 #
-# Copyright (C) 2012-2019 Michael Waskom
+# Copyright (C) 2012-2020 Michael Waskom

 DESCRIPTION = "seaborn: statistical data visualization"
 LONG_DESCRIPTION = """\
@@ -26,15 +26,23 @@
 URL = 'https://seaborn.pydata.org'
 LICENSE = 'BSD (3-clause)'
 DOWNLOAD_URL = 'https://github.com/mwaskom/seaborn/'
-VERSION = '0.9.1.dev'
+VERSION = '0.12.0.dev0'
+PYTHON_REQUIRES = ">=3.7"

 INSTALL_REQUIRES = [
-    'numpy>=1.9.3',
-    'scipy>=0.14.0',
-    'pandas>=0.15.2',
-    'matplotlib>=1.4.3',
+    'numpy>=1.16',
+    'pandas>=0.24',
+    'matplotlib>=3.0',
 ]

+EXTRAS_REQUIRE = {
+    'all': [
+        'scipy>=1.2',
+        'statsmodels>=0.9',
+    ]
+}
+
+
 PACKAGES = [
     'seaborn',
     'seaborn.colors',
@@ -44,26 +52,25 @@
 CLASSIFIERS = [
     'Intended Audience :: Science/Research',
-    'Programming Language :: Python :: 2.7',
-    'Programming Language :: Python :: 3.5',
-    'Programming Language :: Python :: 3.6',
     'Programming Language :: Python :: 3.7',
+    'Programming Language :: Python :: 3.8',
+    'Programming Language :: Python :: 3.9',
     'License :: OSI Approved :: BSD License',
     'Topic :: Scientific/Engineering :: Visualization',
     'Topic :: Multimedia :: Graphics',
-    'Operating System :: POSIX',
-    'Operating System :: Unix',
-    'Operating System :: MacOS'
+    'Operating System :: OS Independent',
+    'Framework :: Matplotlib',
 ]

-try:
-    from setuptools import setup
-    _has_setuptools = True
-except ImportError:
-    from distutils.core import setup

 if __name__ == "__main__":
+    from setuptools import setup
+
+    import sys
+    if sys.version_info[:2] < (3, 7):
+        raise RuntimeError("seaborn requires python >= 3.7.")
+
     setup(
         name=DISTNAME,
         author=MAINTAINER,
@@ -76,7 +83,9 @@
         url=URL,
         version=VERSION,
         download_url=DOWNLOAD_URL,
+        python_requires=PYTHON_REQUIRES,
         install_requires=INSTALL_REQUIRES,
+        extras_require=EXTRAS_REQUIRE,
         packages=PACKAGES,
         classifiers=CLASSIFIERS
     )
diff --git a/testing/deps_latest.txt b/testing/deps_latest.txt
deleted file mode 100644
index a8d32f9934..0000000000
--- a/testing/deps_latest.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-numpy
-scipy
-matplotlib
-pandas
-statsmodels
diff --git a/testing/deps_minimal.txt b/testing/deps_minimal.txt
deleted file mode 100644
index 8fed041814..0000000000
--- a/testing/deps_minimal.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-numpy
-scipy
-matplotlib
-pandas
diff --git a/testing/deps_pinned.txt b/testing/deps_pinned.txt
deleted file mode 100644
index cd5035940e..0000000000
--- a/testing/deps_pinned.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-numpy=1.9.3
-scipy=0.14.0
-matplotlib=1.4.3
-pandas=0.15.2
-statsmodels=0.5.0
-
-# Needed due to incomplete scipy recipe
-libgfortran=1.0
diff --git a/testing/matplotlibrc_agg b/testing/matplotlibrc_agg
deleted file mode 100644
index 88a8365733..0000000000
--- a/testing/matplotlibrc_agg
+++ /dev/null
@@ -1 +0,0 @@
-backend: Agg
diff --git a/testing/matplotlibrc_qtagg b/testing/matplotlibrc_qtagg
deleted file mode 100644
index 957624030e..0000000000
--- a/testing/matplotlibrc_qtagg
+++ /dev/null
@@ -1 +0,0 @@
-backend: Qt5Agg
diff --git a/testing/utils.txt b/testing/utils.txt
deleted file mode 100644
index e07e015a77..0000000000
--- a/testing/utils.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-pytest
-pytest-cov
-flake8
-nose
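Because scipy and statsmodels move from hard requirements to the `all` extra here, a bare `pip install seaborn` no longer guarantees their presence, and code that relies on them should treat the imports as optional. A sketch of the general guard pattern this implies (illustrative names, not code from this diff):

    try:
        import statsmodels.api  # available only with `pip install seaborn[all]`
        _HAS_STATSMODELS = True
    except ImportError:
        _HAS_STATSMODELS = False

    def fit_trend(x, y):
        """Hypothetical helper that degrades gracefully without the extra."""
        if not _HAS_STATSMODELS:
            raise RuntimeError("This feature requires statsmodels; install seaborn[all].")
        return statsmodels.api.OLS(y, statsmodels.api.add_constant(x)).fit()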