From 952654d962c67433e87206db80cb5a67f99ac030 Mon Sep 17 00:00:00 2001
From: "Brian E. Granger"
Date: Tue, 17 Oct 2017 10:51:07 -0700
Subject: [PATCH] Add tests and new validate_column logic.

---
 altair/v1/api.py            | 121 ++++++++++++++++++++----------------
 altair/v1/tests/test_api.py |  12 +++-
 2 files changed, 79 insertions(+), 54 deletions(-)

diff --git a/altair/v1/api.py b/altair/v1/api.py
index 20aa4f8daf..0befd294c3 100644
--- a/altair/v1/api.py
+++ b/altair/v1/api.py
@@ -79,7 +79,7 @@ class FieldError(Exception):
 # This is added to TopLevelMixin as a method if MIME rendering is enabled
 def _repr_mimebundle_(self, include, exclude, **kwargs):
     """Return a MIME-bundle for rich display in the Jupyter Notebook."""
-    spec = self.to_dict()
+    spec = self.to_dict(validate_columns=True)
     bundle = create_vegalite_mime_bundle(spec)
     return bundle

@@ -107,6 +107,7 @@ def disable_mime_rendering():
 #*************************************************************************
 # Channel Aliases
 #*************************************************************************
+
 from .schema import X, Y, Row, Column, Color, Size, Shape, Text, Label, Detail, Opacity, Order, Path
 from .schema import Encoding, Facet

@@ -130,6 +131,7 @@ def decorate(f):
 #  - makes field a required first argument of initialization
 #  - allows expr trait to be an Expression and processes it properly
 #*************************************************************************
+
 class Formula(schema.Formula):
     expr = jst.JSONUnion([jst.JSONString(),
                           jst.JSONInstance(expr.Expression)],
@@ -149,6 +151,7 @@ def _finalize(self, **kwargs):
 # Transform wrapper
 #  - allows filter trait to be an Expression and processes it properly
 #*************************************************************************
+
 class Transform(schema.Transform):
     filter = jst.JSONUnion([jst.JSONString(),
                             jst.JSONInstance(expr.Expression),
@@ -175,6 +178,7 @@ def _finalize(self, **kwargs):
 #*************************************************************************
 # Top-level Objects
 #*************************************************************************
+
 class TopLevelMixin(object):

     @staticmethod
@@ -263,22 +267,26 @@ def to_html(self, template=None, title=None, **kwargs):
         including HTML
         """
         from ..utils.html import to_html
-        return to_html(self.to_dict(), template=template, title=title, **kwargs)
+        return to_html(self.to_dict(validate_columns=True), template=template, title=title, **kwargs)

-    def to_dict(self, data=True):
+    def to_dict(self, data=True, validate_columns=False):
         """Emit the JSON representation for this object as as dict.

         Parameters
         ----------
         data : bool
             If True (default) then include data in the representation.
+        validate_columns : bool
+            If True (default False), raise a FieldError when encoded column names are
+            missing from the data or misspelled. The error is only raised if
+            ``self.validate_columns`` is also True (its default).

         Returns
         -------
         spec : dict
             The JSON specification of the chart object.
         """
-        dct = super(TopLevelMixin, self).to_dict(data=data)
+        dct = super(TopLevelMixin, self).to_dict(data=data, validate_columns=validate_columns)
         dct['$schema'] = schema.vegalite_schema_url
         return dct

@@ -434,7 +442,7 @@ def _ipython_display_(self):
         """Use the vega package to display in the classic Jupyter Notebook."""
         from IPython.display import display
         from vega import VegaLite
-        display(VegaLite(self.to_dict()))
+        display(VegaLite(self.to_dict(validate_columns=True)))

     def display(self):
         """Display the Chart using the Jupyter Notebook's rich output.
@@ -481,6 +489,21 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None,
               files=files, jupyter_warning=jupyter_warning,
               open_browser=open_browser, http_server=http_server)

+    def _finalize(self, **kwargs):
+        self._finalize_data()
+        # data comes from wrappers, but self.data overrides this if defined
+        if self.data is not None:
+            kwargs['data'] = self.data
+        super(TopLevelMixin, self)._finalize(**kwargs)
+
+        # Validate columns after the rest of _finalize() has run. This runs last
+        # because field names are not filled in from shortcuts until now.
+        validate_columns = kwargs.get('validate_columns')
+        # Only do validation if it was requested as a keyword arg to `_finalize`
+        # and the Chart allows it.
+        if validate_columns and self.validate_columns:
+            self._validate_columns()
+
     def _finalize_data(self):
         """
         This function is called by _finalize() below.
@@ -491,19 +514,10 @@ def _finalize_data(self):
         * Whether the data attribute contains expressions, and if so it extracts
           the appropriate data object and generates the appropriate transforms.
         """
-        # Check to see if data has too many rows.
-        if isinstance(self.data, pd.DataFrame):
-            if len(self.data) > self.max_rows:
-                raise MaxRowsExceeded(
-                    "Your dataset has too many rows and could take a long "
-                    "time to send to the frontend or to render. To override the "
-                    "default maximum rows (%s), set the max_rows property of "
-                    "your Chart to an integer larger than the number of rows "
-                    "in your dataset. Alternatively you could perform aggregations "
-                    "or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
-                )

-        # Handle expressions.
+        # Handle expressions. This transforms an expr.DataFrame object into a set
+        # of transforms and an actual pd.DataFrame. After this block runs,
+        # self.data is either a URL, a pd.DataFrame, or None.
         if isinstance(self.data, expr.DataFrame):
             columns = self.data._cols
             calculated_cols = self.data._calculated_cols
@@ -522,10 +536,23 @@
         else:
             self.transform_data(filter=filters)

-    def _validate_spec(self):
-        """Validate the spec.
+        # If self.data is a pd.DataFrame, check to see if data has too many rows.
+        if isinstance(self.data, pd.DataFrame):
+            if len(self.data) > self.max_rows:
+                raise MaxRowsExceeded(
+                    "Your dataset has too many rows and could take a long "
+                    "time to send to the frontend or to render. To override the "
+                    "default maximum rows (%s), set the max_rows property of "
+                    "your Chart to an integer larger than the number of rows "
+                    "in your dataset. Alternatively you could perform aggregations "
+                    "or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
+                )
+
+
+    def _validate_columns(self):
+        """Validate the columns in the encoding, but only if the data is a ``DataFrame``.

-        This has to be called after the rest of the _finalize() logic, which fills in the
+        This has to be called after the rest of the ``_finalize()`` logic, which fills in the
         shortcut field names and also processes the expressions for computed fields.

         This validates:
@@ -537,8 +564,7 @@

         This logic only runs when the dataset is a ``DataFrame``.
         """
-        # If we have a concrete dataset (not a URL) make sure field names in the encodings
-        # are present in the data or formulas
+        # Only validate columns if the data is a pd.DataFrame.
         if isinstance(self.data, pd.DataFrame):
             # Find columns with visual encodings
             encoded_columns = set()
@@ -549,7 +575,9 @@
                     if channel is not jst.undefined:
                         field = channel.field
                         if field is jst.undefined:
-                            raise FieldError("Missing field/column name for channel: {}".format(channel_name))
+                            raise FieldError(
+                                "Missing field/column name for channel: {}".format(channel_name)
+                            )
                         else:
                             encoded_columns.add(field)
             # Find columns in the data
@@ -565,7 +593,9 @@
             # Find columns in the visual encoding that are not in the data
             missing_columns = encoded_columns - data_columns
             if missing_columns:
-                raise FieldError("Fields/columns not found in the data: {}".format(missing_columns))
+                raise FieldError(
+                    "Fields/columns not found in the data: {}".format(missing_columns)
+                )


 class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
@@ -577,11 +607,14 @@ class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
     transform = jst.JSONInstance(Transform,
                                  help=schema.ExtendedUnitSpec.transform.help)
     mark = schema.Mark(default_value='point',
                        help="""The mark type.""")
-
     max_rows = T.Int(
         default_value=DEFAULT_MAX_ROWS,
         help="Maximum number of rows in the dataset to accept."
     )
+    validate_columns = T.Bool(
+        default_value=True,
+        help="Raise FieldError if the data is a DataFrame and there are missing columns."
+    )

     @property
     def data(self):
@@ -597,7 +630,7 @@ def data(self, new):
         else:
             raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

-    _skip_on_export = ['data', '_data', 'max_rows']
+    _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

     def __init__(self, data=None, **kwargs):
         super(Chart, self).__init__(**kwargs)
@@ -671,16 +704,6 @@ def encode(self, *args, **kwargs):
         """Define the encoding for the Chart."""
         return update_subtraits(self, 'encoding', *args, **kwargs)

-    def _finalize(self, **kwargs):
-        self._finalize_data()
-        # data comes from wrappers, but self.data overrides this if defined
-        if self.data is not None:
-            kwargs['data'] = self.data
-        super(Chart, self)._finalize(**kwargs)
-        # After the rest of _finalize() has run, validate the spec.
-        # This is last as field names are not yet filled in from shortcuts until now.
-        self._validate_spec()
-
     def __add__(self, other):
         if isinstance(other, Chart):
             lc = LayeredChart()
@@ -732,6 +755,10 @@ class LayeredChart(TopLevelMixin, schema.LayerSpec):
         default_value=DEFAULT_MAX_ROWS,
         help="Maximum number of rows in the dataset to accept."
     )
+    validate_columns = T.Bool(
+        default_value=True,
+        help="Raise FieldError if the data is a DataFrame and there are missing columns."
+    )

     @property
     def data(self):
@@ -747,7 +774,7 @@ def data(self, new):
         else:
             raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

-    _skip_on_export = ['data', '_data', 'max_rows']
+    _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

     def __init__(self, data=None, **kwargs):
         super(LayeredChart, self).__init__(**kwargs)
@@ -760,13 +787,6 @@ def set_layers(self, *layers):
         self.layers = list(layers)
         return self

-    def _finalize(self, **kwargs):
-        self._finalize_data()
-        # data comes from wrappers, but self.data overrides this if defined
-        if self.data is not None:
-            kwargs['data'] = self.data
-        super(LayeredChart, self)._finalize(**kwargs)
-
     def __iadd__(self, layer):
         if self.layers is jst.undefined:
             self.layers = [layer]
@@ -789,6 +809,10 @@ class FacetedChart(TopLevelMixin, schema.FacetSpec):
         default_value=DEFAULT_MAX_ROWS,
         help="Maximum number of rows in the dataset to accept."
     )
+    validate_columns = T.Bool(
+        default_value=True,
+        help="Raise FieldError if the data is a DataFrame and there are missing columns."
+    )

     @property
     def data(self):
@@ -804,7 +828,7 @@ def data(self, new):
         else:
             raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

-    _skip_on_export = ['data', '_data', 'max_rows']
+    _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

     def __init__(self, data=None, **kwargs):
         super(FacetedChart, self).__init__(**kwargs)
@@ -817,10 +841,3 @@ def __dir__(self):
     def set_facet(self, *args, **kwargs):
         """Define the facet encoding for the Chart."""
         return update_subtraits(self, 'facet', *args, **kwargs)
-
-    def _finalize(self, **kwargs):
-        self._finalize_data()
-        # data comes from wrappers, but self.data overrides this if defined
-        if self.data is not None:
-            kwargs['data'] = self.data
-        super(FacetedChart, self)._finalize(**kwargs)
diff --git a/altair/v1/tests/test_api.py b/altair/v1/tests/test_api.py
index 626db278f2..df023511c0 100644
--- a/altair/v1/tests/test_api.py
+++ b/altair/v1/tests/test_api.py
@@ -594,15 +594,23 @@ def test_validate_spec():
     # Make sure we catch channels with no field specified
     c = make_chart()
     c.encode(Color())
+    assert isinstance(c.to_dict(), dict)
+    assert isinstance(c.to_dict(validate_columns=False), dict)
     with pytest.raises(FieldError):
-        c.to_dict()
+        c.to_dict(validate_columns=True)
+    c.validate_columns = False
+    assert isinstance(c.to_dict(validate_columns=True), dict)

     # Make sure we catch encoded fields not in the data
     c = make_chart()
     c.encode(x='x', y='y', color='z')
     c.encode(color='z')
+    assert isinstance(c.to_dict(), dict)
+    assert isinstance(c.to_dict(validate_columns=False), dict)
     with pytest.raises(FieldError):
-        c.to_dict()
+        c.to_dict(validate_columns=True)
+    c.validate_columns = False
+    assert isinstance(c.to_dict(validate_columns=True), dict)

     # Make sure we can resolve computed fields
     c = make_chart()
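
A minimal usage sketch of the behavior this patch introduces (illustrative only, not part of the patch; the import paths, the sample DataFrame, and the exact error text are assumptions based on the diff above):

# Column validation is opt-in per call and gated by the chart-level trait.
import pandas as pd
from altair.v1 import Chart
from altair.v1.api import FieldError

df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})

# 'z' is not a column of df, so validation should fail when requested.
chart = Chart(df).mark_point().encode(x='x', y='z')

spec = chart.to_dict()                       # validate_columns defaults to False: no error
try:
    chart.to_dict(validate_columns=True)     # explicit request: raises FieldError
except FieldError as err:
    print(err)                               # e.g. "Fields/columns not found in the data: {'z'}"

chart.validate_columns = False               # chart-level trait disables validation entirely
spec = chart.to_dict(validate_columns=True)  # no error even when requested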