Skip to content

Commit

Permalink
Add tests and new validate_column logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisonbg committed Oct 17, 2017
1 parent e4c94ea commit 952654d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 54 deletions.
121 changes: 69 additions & 52 deletions altair/v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FieldError(Exception):
# This is added to TopLevelMixin as a method if MIME rendering is enabled
def _repr_mimebundle_(self, include, exclude, **kwargs):
"""Return a MIME-bundle for rich display in the Jupyter Notebook."""
spec = self.to_dict()
spec = self.to_dict(validate_columns=True)
bundle = create_vegalite_mime_bundle(spec)
return bundle

Expand Down Expand Up @@ -107,6 +107,7 @@ def disable_mime_rendering():
#*************************************************************************
# Channel Aliases
#*************************************************************************

from .schema import X, Y, Row, Column, Color, Size, Shape, Text, Label, Detail, Opacity, Order, Path
from .schema import Encoding, Facet

Expand All @@ -130,6 +131,7 @@ def decorate(f):
# - makes field a required first argument of initialization
# - allows expr trait to be an Expression and processes it properly
#*************************************************************************

class Formula(schema.Formula):
expr = jst.JSONUnion([jst.JSONString(),
jst.JSONInstance(expr.Expression)],
Expand All @@ -149,6 +151,7 @@ def _finalize(self, **kwargs):
# Transform wrapper
# - allows filter trait to be an Expression and processes it properly
#*************************************************************************

class Transform(schema.Transform):
filter = jst.JSONUnion([jst.JSONString(),
jst.JSONInstance(expr.Expression),
Expand All @@ -175,6 +178,7 @@ def _finalize(self, **kwargs):
#*************************************************************************
# Top-level Objects
#*************************************************************************

class TopLevelMixin(object):

@staticmethod
Expand Down Expand Up @@ -263,22 +267,26 @@ def to_html(self, template=None, title=None, **kwargs):
including HTML
"""
from ..utils.html import to_html
return to_html(self.to_dict(), template=template, title=title, **kwargs)
return to_html(self.to_dict(validate_columns=True), template=template, title=title, **kwargs)

def to_dict(self, data=True):
def to_dict(self, data=True, validate_columns=False):
        """Emit the JSON representation for this object as a dict.
Parameters
----------
data : bool
If True (default) then include data in the representation.
validate_columns : bool
If True (default is False) raise FieldError if there are missing or misspelled
column names. This only actually raises if self.validate_columns is also set
(it defaults to True).
Returns
-------
spec : dict
The JSON specification of the chart object.
"""
dct = super(TopLevelMixin, self).to_dict(data=data)
dct = super(TopLevelMixin, self).to_dict(data=data, validate_columns=validate_columns)
dct['$schema'] = schema.vegalite_schema_url
return dct

Expand Down Expand Up @@ -434,7 +442,7 @@ def _ipython_display_(self):
"""Use the vega package to display in the classic Jupyter Notebook."""
from IPython.display import display
from vega import VegaLite
display(VegaLite(self.to_dict()))
display(VegaLite(self.to_dict(validate_columns=True)))

def display(self):
"""Display the Chart using the Jupyter Notebook's rich output.
Expand Down Expand Up @@ -481,6 +489,21 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None,
files=files, jupyter_warning=jupyter_warning,
open_browser=open_browser, http_server=http_server)

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(TopLevelMixin, self)._finalize(**kwargs)

# Validate columns after the rest of _finalize() has run. This is last as
# field names are not yet filled in from shortcuts until now.
validate_columns = kwargs.get('validate_columns')
        # Only do validation if it is requested as a keyword arg to `_finalize`
# and the Chart allows it.
if validate_columns and self.validate_columns:
self._validate_columns()

def _finalize_data(self):
"""
This function is called by _finalize() below.
Expand All @@ -491,19 +514,10 @@ def _finalize_data(self):
* Whether the data attribute contains expressions, and if so it extracts
the appropriate data object and generates the appropriate transforms.
"""
# Check to see if data has too many rows.
if isinstance(self.data, pd.DataFrame):
if len(self.data) > self.max_rows:
raise MaxRowsExceeded(
"Your dataset has too many rows and could take a long "
"time to send to the frontend or to render. To override the "
"default maximum rows (%s), set the max_rows property of "
"your Chart to an integer larger than the number of rows "
"in your dataset. Alternatively you could perform aggregations "
"or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
)

# Handle expressions.
        # Handle expressions. This transforms an expr.DataFrame object into a set
# of transforms and an actual pd.DataFrame. After this block runs,
# self.data is either a URL or a pd.DataFrame or None.
if isinstance(self.data, expr.DataFrame):
columns = self.data._cols
calculated_cols = self.data._calculated_cols
Expand All @@ -522,10 +536,23 @@ def _finalize_data(self):
else:
self.transform_data(filter=filters)

def _validate_spec(self):
"""Validate the spec.
# If self.data is a pd.DataFrame, check to see if data has too many rows.
if isinstance(self.data, pd.DataFrame):
if len(self.data) > self.max_rows:
raise MaxRowsExceeded(
"Your dataset has too many rows and could take a long "
"time to send to the frontend or to render. To override the "
"default maximum rows (%s), set the max_rows property of "
"your Chart to an integer larger than the number of rows "
"in your dataset. Alternatively you could perform aggregations "
"or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
)


def _validate_columns(self):
        """Validate the columns in the encoding, but only if the data is a ``DataFrame``.
This has to be called after the rest of the _finalize() logic, which fills in the
This has to be called after the rest of the ``_finalize()`` logic, which fills in the
shortcut field names and also processes the expressions for computed fields.
This validates:
Expand All @@ -537,8 +564,7 @@ def _validate_spec(self):
This logic only runs when the dataset is a ``DataFrame``.
"""

# If we have a concrete dataset (not a URL) make sure field names in the encodings
# are present in the data or formulas
# Only validate columns if the data is a pd.DataFrame.
if isinstance(self.data, pd.DataFrame):
# Find columns with visual encodings
encoded_columns = set()
Expand All @@ -549,7 +575,9 @@ def _validate_spec(self):
if channel is not jst.undefined:
field = channel.field
if field is jst.undefined:
raise FieldError("Missing field/column name for channel: {}".format(channel_name))
raise FieldError(
"Missing field/column name for channel: {}".format(channel_name)
)
else:
encoded_columns.add(field)
# Find columns in the data
Expand All @@ -565,7 +593,9 @@ def _validate_spec(self):
# Find columns in the visual encoding that are not in the data
missing_columns = encoded_columns - data_columns
if missing_columns:
raise FieldError("Fields/columns not found in the data: {}".format(missing_columns))
raise FieldError(
"Fields/columns not found in the data: {}".format(missing_columns)
)


class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
Expand All @@ -577,11 +607,14 @@ class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
transform = jst.JSONInstance(Transform,
help=schema.ExtendedUnitSpec.transform.help)
mark = schema.Mark(default_value='point', help="""The mark type.""")

max_rows = T.Int(
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

@property
def data(self):
Expand All @@ -597,7 +630,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(Chart, self).__init__(**kwargs)
Expand Down Expand Up @@ -671,16 +704,6 @@ def encode(self, *args, **kwargs):
"""Define the encoding for the Chart."""
return update_subtraits(self, 'encoding', *args, **kwargs)

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(Chart, self)._finalize(**kwargs)
# After the rest of _finalize() has run, validate the spec.
# This is last as field names are not yet filled in from shortcuts until now.
self._validate_spec()

def __add__(self, other):
if isinstance(other, Chart):
lc = LayeredChart()
Expand Down Expand Up @@ -732,6 +755,10 @@ class LayeredChart(TopLevelMixin, schema.LayerSpec):
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

@property
def data(self):
Expand All @@ -747,7 +774,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(LayeredChart, self).__init__(**kwargs)
Expand All @@ -760,13 +787,6 @@ def set_layers(self, *layers):
self.layers = list(layers)
return self

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(LayeredChart, self)._finalize(**kwargs)

def __iadd__(self, layer):
if self.layers is jst.undefined:
self.layers = [layer]
Expand All @@ -789,6 +809,10 @@ class FacetedChart(TopLevelMixin, schema.FacetSpec):
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

@property
def data(self):
Expand All @@ -804,7 +828,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(FacetedChart, self).__init__(**kwargs)
Expand All @@ -817,10 +841,3 @@ def __dir__(self):
def set_facet(self, *args, **kwargs):
"""Define the facet encoding for the Chart."""
return update_subtraits(self, 'facet', *args, **kwargs)

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(FacetedChart, self)._finalize(**kwargs)
12 changes: 10 additions & 2 deletions altair/v1/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,15 +594,23 @@ def test_validate_spec():
# Make sure we catch channels with no field specified
c = make_chart()
c.encode(Color())
assert isinstance(c.to_dict(), dict)
assert isinstance(c.to_dict(validate_columns=False), dict)
with pytest.raises(FieldError):
c.to_dict()
c.to_dict(validate_columns=True)
c.validate_columns = False
assert isinstance(c.to_dict(validate_columns=True), dict)

# Make sure we catch encoded fields not in the data
c = make_chart()
c.encode(x='x', y='y', color='z')
c.encode(color='z')
assert isinstance(c.to_dict(), dict)
assert isinstance(c.to_dict(validate_columns=False), dict)
with pytest.raises(FieldError):
c.to_dict()
c.to_dict(validate_columns=True)
c.validate_columns = False
assert isinstance(c.to_dict(validate_columns=True), dict)

# Make sure we can resolve computed fields
c = make_chart()
Expand Down

0 comments on commit 952654d

Please sign in to comment.