From 220ca512ce520883384020ba56abe5395605f700 Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Mon, 2 Oct 2017 14:43:41 -0700 Subject: [PATCH 1/6] Adding _validate_spec to Chart --- altair/v1/api.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/altair/v1/api.py b/altair/v1/api.py index 925d2dfa6..4894bfba7 100644 --- a/altair/v1/api.py +++ b/altair/v1/api.py @@ -58,6 +58,16 @@ class MaxRowsExceeded(Exception): """Raised if the number of rows in the dataset is too large.""" pass +class FieldError(Exception): + """Raised if a channel has a field related error. + + This is raised if a channel has no field name or if the field name is + not found as the column name of the ``DataFrame``. + """ + + + + DEFAULT_MAX_ROWS = 5000 #************************************************************************* @@ -512,6 +522,51 @@ def _finalize_data(self): else: self.transform_data(filter=filters) + def _validate_spec(self): + """Validate the spec. + + This has to be called after the rest of the _finalize() logic, which fills in the + shortcut field names and also processes the expressions for computed fields. + + This validates: + + 1. That each encoding channel has a field (column name). + 2. That the specified field name is present the column names of the ``DataFrame`` or + computed field from transform expressions. + + This logic only runs when the dataset is a ``DataFrame``. + """ + + # If we have a concrete dataset (not a URL) make sure field names in the encodings + # are present in the data or formulas + if isinstance(self.data, pd.DataFrame): + # Find columns with visual encodings + encoded_columns = set() + encoding = self.encoding + if encoding is not jst.undefined: + for channel_name in encoding.channel_names: + channel = getattr(encoding, channel_name) + if channel is not jst.undefined: + field = channel.field + if field is jst.undefined: + raise FieldError("Missing field/column name for channel: {}".format(channel_name)) + else: + encoded_columns.add(field) + # Find columns in the data + data_columns = set(self.data.columns.values) + transform = self.transform + if transform is not jst.undefined: + calculate = transform.calculate + if calculate is not jst.undefined: + for formula in calculate: + field = formula.field + if field is not jst.undefined: + data_columns.add(field) + # Find columns in the visual encoding that are not in the data + missing_columns = encoded_columns - data_columns + if missing_columns: + raise FieldError("Fields/columns not found in the data: {}".format(missing_columns)) + class Chart(TopLevelMixin, schema.ExtendedUnitSpec): _data = None @@ -630,6 +685,9 @@ def _finalize(self, **kwargs): if self.data is not None: kwargs['data'] = self.data super(Chart, self)._finalize(**kwargs) + # After the rest of _finalize() has run, validate the spec. + # This is last as field names are not yet filled in from shortcuts until now. + self._validate_spec() def __add__(self, other): if isinstance(other, Chart): From 3baba095893c3e6c10a10d98a12f732d7613d4d1 Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Mon, 2 Oct 2017 14:45:42 -0700 Subject: [PATCH 2/6] Starting tests of _validate_spec --- altair/v1/tests/test_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/altair/v1/tests/test_api.py b/altair/v1/tests/test_api.py index 75dacfce4..0c326436b 100644 --- a/altair/v1/tests/test_api.py +++ b/altair/v1/tests/test_api.py @@ -587,3 +587,9 @@ def test_enable_mime_rendering(): enable_mime_rendering() disable_mime_rendering() disable_mime_rendering() + + +def test_validate_spec(): + c = make_chart() + + From f0fb38180041542db1a8403fa8329990f3d8f76c Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Mon, 2 Oct 2017 14:58:46 -0700 Subject: [PATCH 3/6] Tested _validate_spec --- altair/v1/__init__.py | 1 + altair/v1/tests/test_api.py | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/altair/v1/__init__.py b/altair/v1/__init__.py index 62db76eb4..f67374d61 100644 --- a/altair/v1/__init__.py +++ b/altair/v1/__init__.py @@ -42,6 +42,7 @@ RangeFilter, OneOfFilter, MaxRowsExceeded, + FieldError, enable_mime_rendering, disable_mime_rendering ) diff --git a/altair/v1/tests/test_api.py b/altair/v1/tests/test_api.py index 0c326436b..626db278f 100644 --- a/altair/v1/tests/test_api.py +++ b/altair/v1/tests/test_api.py @@ -588,8 +588,27 @@ def test_enable_mime_rendering(): disable_mime_rendering() disable_mime_rendering() - + def test_validate_spec(): + + # Make sure we catch channels with no field specified c = make_chart() + c.encode(Color()) + with pytest.raises(FieldError): + c.to_dict() + # Make sure we catch encoded fields not in the data + c = make_chart() + c.encode(x='x', y='y', color='z') + c.encode(color='z') + with pytest.raises(FieldError): + c.to_dict() + # Make sure we can resolve computed fields + c = make_chart() + c.encode(x='x', y='y', color='z') + c.encode(color='z') + c.transform_data( + calculate=[Formula('z', 'sin(((2*PI)*datum.x))')] + ) + assert isinstance(c.to_dict(), dict) From 15639862129c5377587cebc25c5357dad454ac1d Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Tue, 17 Oct 2017 10:51:07 -0700 Subject: [PATCH 4/6] Add tests and new validate_column logic. --- altair/v1/api.py | 121 ++++++++++++++++++++---------------- altair/v1/tests/test_api.py | 12 +++- 2 files changed, 79 insertions(+), 54 deletions(-) diff --git a/altair/v1/api.py b/altair/v1/api.py index 4894bfba7..e616e1c57 100644 --- a/altair/v1/api.py +++ b/altair/v1/api.py @@ -79,7 +79,7 @@ class FieldError(Exception): # This is added to TopLevelMixin as a method if MIME rendering is enabled def _repr_mimebundle_(self, include, exclude, **kwargs): """Return a MIME-bundle for rich display in the Jupyter Notebook.""" - spec = self.to_dict() + spec = self.to_dict(validate_columns=True) bundle = create_vegalite_mime_bundle(spec) return bundle @@ -107,6 +107,7 @@ def disable_mime_rendering(): #************************************************************************* # Channel Aliases #************************************************************************* + from .schema import X, Y, Row, Column, Color, Size, Shape, Text, Label, Detail, Opacity, Order, Path from .schema import Encoding, Facet @@ -130,6 +131,7 @@ def decorate(f): # - makes field a required first argument of initialization # - allows expr trait to be an Expression and processes it properly #************************************************************************* + class Formula(schema.Formula): expr = jst.JSONUnion([jst.JSONString(), jst.JSONInstance(expr.Expression)], @@ -149,6 +151,7 @@ def _finalize(self, **kwargs): # Transform wrapper # - allows filter trait to be an Expression and processes it properly #************************************************************************* + class Transform(schema.Transform): filter = jst.JSONUnion([jst.JSONString(), jst.JSONInstance(expr.Expression), @@ -175,6 +178,7 @@ def _finalize(self, **kwargs): #************************************************************************* # Top-level Objects #************************************************************************* + class TopLevelMixin(object): @staticmethod @@ -263,22 +267,26 @@ def to_html(self, template=None, title=None, **kwargs): including HTML """ from ..utils.html import to_html - return to_html(self.to_dict(), template=template, title=title, **kwargs) + return to_html(self.to_dict(validate_columns=True), template=template, title=title, **kwargs) - def to_dict(self, data=True): + def to_dict(self, data=True, validate_columns=False): """Emit the JSON representation for this object as as dict. Parameters ---------- data : bool If True (default) then include data in the representation. + validate_columns : bool + If True (default is False) raise FieldError if there are missing or misspelled + column names. This only actually raises if self.validate_columns is also set + (it defaults to True). Returns ------- spec : dict The JSON specification of the chart object. """ - dct = super(TopLevelMixin, self.clone()).to_dict(data=data) + dct = super(TopLevelMixin, self.clone()).to_dict(data=data, validate_columns=validate_columns) dct['$schema'] = schema.vegalite_schema_url return dct @@ -434,7 +442,7 @@ def _ipython_display_(self): """Use the vega package to display in the classic Jupyter Notebook.""" from IPython.display import display from vega import VegaLite - display(VegaLite(self.to_dict())) + display(VegaLite(self.to_dict(validate_columns=True))) def display(self): """Display the Chart using the Jupyter Notebook's rich output. @@ -481,6 +489,21 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None, files=files, jupyter_warning=jupyter_warning, open_browser=open_browser, http_server=http_server) + def _finalize(self, **kwargs): + self._finalize_data() + # data comes from wrappers, but self.data overrides this if defined + if self.data is not None: + kwargs['data'] = self.data + super(TopLevelMixin, self)._finalize(**kwargs) + + # Validate columns after the rest of _finalize() has run. This is last as + # field names are not yet filled in from shortcuts until now. + validate_columns = kwargs.get('validate_columns') + # Only do validation if the requested as a keyword arg to `_finalize` + # and the Chart allows it. + if validate_columns and self.validate_columns: + self._validate_columns() + def _finalize_data(self): """ This function is called by _finalize() below. @@ -491,19 +514,10 @@ def _finalize_data(self): * Whether the data attribute contains expressions, and if so it extracts the appropriate data object and generates the appropriate transforms. """ - # Check to see if data has too many rows. - if isinstance(self.data, pd.DataFrame): - if len(self.data) > self.max_rows: - raise MaxRowsExceeded( - "Your dataset has too many rows and could take a long " - "time to send to the frontend or to render. To override the " - "default maximum rows (%s), set the max_rows property of " - "your Chart to an integer larger than the number of rows " - "in your dataset. Alternatively you could perform aggregations " - "or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS - ) - # Handle expressions. + # Handle expressions. This transforms expr.DataFrame object into a set + # of transforms and an actual pd.DataFrame. After this block runs, + # self.data is either a URL or a pd.DataFrame or None. if isinstance(self.data, expr.DataFrame): columns = self.data._cols calculated_cols = self.data._calculated_cols @@ -522,10 +536,23 @@ def _finalize_data(self): else: self.transform_data(filter=filters) - def _validate_spec(self): - """Validate the spec. + # If self.data is a pd.DataFrame, check to see if data has too many rows. + if isinstance(self.data, pd.DataFrame): + if len(self.data) > self.max_rows: + raise MaxRowsExceeded( + "Your dataset has too many rows and could take a long " + "time to send to the frontend or to render. To override the " + "default maximum rows (%s), set the max_rows property of " + "your Chart to an integer larger than the number of rows " + "in your dataset. Alternatively you could perform aggregations " + "or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS + ) + + + def _validate_columns(self): + """Validate the columns in the encoding, but only if if the data is a ``DataFrame``. - This has to be called after the rest of the _finalize() logic, which fills in the + This has to be called after the rest of the ``_finalize()`` logic, which fills in the shortcut field names and also processes the expressions for computed fields. This validates: @@ -537,8 +564,7 @@ def _validate_spec(self): This logic only runs when the dataset is a ``DataFrame``. """ - # If we have a concrete dataset (not a URL) make sure field names in the encodings - # are present in the data or formulas + # Only validate columns if the data is a pd.DataFrame. if isinstance(self.data, pd.DataFrame): # Find columns with visual encodings encoded_columns = set() @@ -549,7 +575,9 @@ def _validate_spec(self): if channel is not jst.undefined: field = channel.field if field is jst.undefined: - raise FieldError("Missing field/column name for channel: {}".format(channel_name)) + raise FieldError( + "Missing field/column name for channel: {}".format(channel_name) + ) else: encoded_columns.add(field) # Find columns in the data @@ -565,7 +593,9 @@ def _validate_spec(self): # Find columns in the visual encoding that are not in the data missing_columns = encoded_columns - data_columns if missing_columns: - raise FieldError("Fields/columns not found in the data: {}".format(missing_columns)) + raise FieldError( + "Fields/columns not found in the data: {}".format(missing_columns) + ) class Chart(TopLevelMixin, schema.ExtendedUnitSpec): @@ -577,11 +607,14 @@ class Chart(TopLevelMixin, schema.ExtendedUnitSpec): transform = jst.JSONInstance(Transform, help=schema.ExtendedUnitSpec.transform.help) mark = schema.Mark(default_value='point', help="""The mark type.""") - max_rows = T.Int( default_value=DEFAULT_MAX_ROWS, help="Maximum number of rows in the dataset to accept." ) + validate_columns = T.Bool( + default_value=True, + help="Raise FieldError if the data is a DataFrame and there are missing columns." + ) def clone(self): """ @@ -605,7 +638,7 @@ def data(self, new): else: raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new)) - _skip_on_export = ['data', '_data', 'max_rows'] + _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns'] def __init__(self, data=None, **kwargs): super(Chart, self).__init__(**kwargs) @@ -679,16 +712,6 @@ def encode(self, *args, **kwargs): """Define the encoding for the Chart.""" return update_subtraits(self, 'encoding', *args, **kwargs) - def _finalize(self, **kwargs): - self._finalize_data() - # data comes from wrappers, but self.data overrides this if defined - if self.data is not None: - kwargs['data'] = self.data - super(Chart, self)._finalize(**kwargs) - # After the rest of _finalize() has run, validate the spec. - # This is last as field names are not yet filled in from shortcuts until now. - self._validate_spec() - def __add__(self, other): if isinstance(other, Chart): lc = LayeredChart() @@ -740,6 +763,10 @@ class LayeredChart(TopLevelMixin, schema.LayerSpec): default_value=DEFAULT_MAX_ROWS, help="Maximum number of rows in the dataset to accept." ) + validate_columns = T.Bool( + default_value=True, + help="Raise FieldError if the data is a DataFrame and there are missing columns." + ) def clone(self): """ @@ -763,7 +790,7 @@ def data(self, new): else: raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new)) - _skip_on_export = ['data', '_data', 'max_rows'] + _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns'] def __init__(self, data=None, **kwargs): super(LayeredChart, self).__init__(**kwargs) @@ -776,13 +803,6 @@ def set_layers(self, *layers): self.layers = list(layers) return self - def _finalize(self, **kwargs): - self._finalize_data() - # data comes from wrappers, but self.data overrides this if defined - if self.data is not None: - kwargs['data'] = self.data - super(LayeredChart, self)._finalize(**kwargs) - def __iadd__(self, layer): if self.layers is jst.undefined: self.layers = [layer] @@ -805,6 +825,10 @@ class FacetedChart(TopLevelMixin, schema.FacetSpec): default_value=DEFAULT_MAX_ROWS, help="Maximum number of rows in the dataset to accept." ) + validate_columns = T.Bool( + default_value=True, + help="Raise FieldError if the data is a DataFrame and there are missing columns." + ) def clone(self): """ @@ -828,7 +852,7 @@ def data(self, new): else: raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new)) - _skip_on_export = ['data', '_data', 'max_rows'] + _skip_on_export = ['data', '_data', 'max_rows', 'validate_columns'] def __init__(self, data=None, **kwargs): super(FacetedChart, self).__init__(**kwargs) @@ -841,10 +865,3 @@ def __dir__(self): def set_facet(self, *args, **kwargs): """Define the facet encoding for the Chart.""" return update_subtraits(self, 'facet', *args, **kwargs) - - def _finalize(self, **kwargs): - self._finalize_data() - # data comes from wrappers, but self.data overrides this if defined - if self.data is not None: - kwargs['data'] = self.data - super(FacetedChart, self)._finalize(**kwargs) diff --git a/altair/v1/tests/test_api.py b/altair/v1/tests/test_api.py index 626db278f..df023511c 100644 --- a/altair/v1/tests/test_api.py +++ b/altair/v1/tests/test_api.py @@ -594,15 +594,23 @@ def test_validate_spec(): # Make sure we catch channels with no field specified c = make_chart() c.encode(Color()) + assert isinstance(c.to_dict(), dict) + assert isinstance(c.to_dict(validate_columns=False), dict) with pytest.raises(FieldError): - c.to_dict() + c.to_dict(validate_columns=True) + c.validate_columns = False + assert isinstance(c.to_dict(validate_columns=True), dict) # Make sure we catch encoded fields not in the data c = make_chart() c.encode(x='x', y='y', color='z') c.encode(color='z') + assert isinstance(c.to_dict(), dict) + assert isinstance(c.to_dict(validate_columns=False), dict) with pytest.raises(FieldError): - c.to_dict() + c.to_dict(validate_columns=True) + c.validate_columns = False + assert isinstance(c.to_dict(validate_columns=True), dict) # Make sure we can resolve computed fields c = make_chart() From 0e1a1f8de554959a5ee127b4d7fd544835471ab0 Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Tue, 17 Oct 2017 11:43:56 -0700 Subject: [PATCH 5/6] Fix * in validate_columns --- altair/v1/api.py | 3 ++- altair/v1/tests/test_api.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/altair/v1/api.py b/altair/v1/api.py index e616e1c57..d50d47239 100644 --- a/altair/v1/api.py +++ b/altair/v1/api.py @@ -579,7 +579,8 @@ def _validate_columns(self): "Missing field/column name for channel: {}".format(channel_name) ) else: - encoded_columns.add(field) + if field != '*': + encoded_columns.add(field) # Find columns in the data data_columns = set(self.data.columns.values) transform = self.transform diff --git a/altair/v1/tests/test_api.py b/altair/v1/tests/test_api.py index df023511c..4d0b27431 100644 --- a/altair/v1/tests/test_api.py +++ b/altair/v1/tests/test_api.py @@ -612,6 +612,10 @@ def test_validate_spec(): c.validate_columns = False assert isinstance(c.to_dict(validate_columns=True), dict) + c = make_chart() + c.encode(x='x', y='count(*)') + assert isinstance(c.to_dict(validate_columns=True), dict) + # Make sure we can resolve computed fields c = make_chart() c.encode(x='x', y='y', color='z') From 1a0d3ea433cddf64777e92084374ad0d58f80fa1 Mon Sep 17 00:00:00 2001 From: "Brian E. Granger" Date: Tue, 17 Oct 2017 11:45:58 -0700 Subject: [PATCH 6/6] validate_columns in our examples --- altair/v1/examples/tests/test_examples.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/altair/v1/examples/tests/test_examples.py b/altair/v1/examples/tests/test_examples.py index 745fda949..cd383298f 100644 --- a/altair/v1/examples/tests/test_examples.py +++ b/altair/v1/examples/tests/test_examples.py @@ -19,7 +19,7 @@ def test_json_examples_round_trip(example): filename, json_dict = example v = Chart.from_dict(json_dict) - v_dict = v.to_dict() + v_dict = v.to_dict(validate_columns=True) if '$schema' not in json_dict: v_dict.pop('$schema') assert v_dict == json_dict @@ -27,7 +27,7 @@ def test_json_examples_round_trip(example): # code generation discards empty function calls, and so we # filter these out before comparison v2 = eval(v.to_python()) - v2_dict = v2.to_dict() + v2_dict = v2.to_dict(validate_columns=True) if '$schema' not in json_dict: v2_dict.pop('$schema') assert v2_dict == remove_empty_fields(json_dict)