Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check field names #399

Merged
Merged 6 commits on Oct 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions altair/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
RangeFilter,
OneOfFilter,
MaxRowsExceeded,
FieldError,
enable_mime_rendering,
disable_mime_rendering
)
Expand Down
160 changes: 118 additions & 42 deletions altair/v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ class MaxRowsExceeded(Exception):
"""Raised if the number of rows in the dataset is too large."""
pass

class FieldError(Exception):
    """Error raised for problems with an encoding channel's field.

    Raised when a channel has no field name at all, or when the field name
    given does not match any column name of the ``DataFrame``.
    """




DEFAULT_MAX_ROWS = 5000

#*************************************************************************
Expand All @@ -69,7 +79,7 @@ class MaxRowsExceeded(Exception):
# This is added to TopLevelMixin as a method if MIME rendering is enabled
def _repr_mimebundle_(self, include, exclude, **kwargs):
"""Return a MIME-bundle for rich display in the Jupyter Notebook."""
spec = self.to_dict()
spec = self.to_dict(validate_columns=True)
bundle = create_vegalite_mime_bundle(spec)
return bundle

Expand Down Expand Up @@ -97,6 +107,7 @@ def disable_mime_rendering():
#*************************************************************************
# Channel Aliases
#*************************************************************************

from .schema import X, Y, Row, Column, Color, Size, Shape, Text, Label, Detail, Opacity, Order, Path
from .schema import Encoding, Facet

Expand All @@ -120,6 +131,7 @@ def decorate(f):
# - makes field a required first argument of initialization
# - allows expr trait to be an Expression and processes it properly
#*************************************************************************

class Formula(schema.Formula):
expr = jst.JSONUnion([jst.JSONString(),
jst.JSONInstance(expr.Expression)],
Expand All @@ -139,6 +151,7 @@ def _finalize(self, **kwargs):
# Transform wrapper
# - allows filter trait to be an Expression and processes it properly
#*************************************************************************

class Transform(schema.Transform):
filter = jst.JSONUnion([jst.JSONString(),
jst.JSONInstance(expr.Expression),
Expand All @@ -165,6 +178,7 @@ def _finalize(self, **kwargs):
#*************************************************************************
# Top-level Objects
#*************************************************************************

class TopLevelMixin(object):

@staticmethod
Expand Down Expand Up @@ -253,22 +267,26 @@ def to_html(self, template=None, title=None, **kwargs):
including HTML
"""
from ..utils.html import to_html
return to_html(self.to_dict(), template=template, title=title, **kwargs)
return to_html(self.to_dict(validate_columns=True), template=template, title=title, **kwargs)

def to_dict(self, data=True):
def to_dict(self, data=True, validate_columns=False):
"""Emit the JSON representation for this object as as dict.

Parameters
----------
data : bool
If True (default) then include data in the representation.
validate_columns : bool
If True (default is False) raise FieldError if there are missing or misspelled
column names. This only actually raises if self.validate_columns is also set
(it defaults to True).

Returns
-------
spec : dict
The JSON specification of the chart object.
"""
dct = super(TopLevelMixin, self.clone()).to_dict(data=data)
dct = super(TopLevelMixin, self.clone()).to_dict(data=data, validate_columns=validate_columns)
dct['$schema'] = schema.vegalite_schema_url
return dct

Expand Down Expand Up @@ -424,7 +442,7 @@ def _ipython_display_(self):
"""Use the vega package to display in the classic Jupyter Notebook."""
from IPython.display import display
from vega import VegaLite
display(VegaLite(self.to_dict()))
display(VegaLite(self.to_dict(validate_columns=True)))

def display(self):
"""Display the Chart using the Jupyter Notebook's rich output.
Expand Down Expand Up @@ -471,6 +489,21 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None,
files=files, jupyter_warning=jupyter_warning,
open_browser=open_browser, http_server=http_server)

    def _finalize(self, **kwargs):
        """Finalize the chart spec: resolve data, then optionally validate columns."""
        self._finalize_data()
        # data comes from wrappers, but self.data overrides this if defined
        if self.data is not None:
            kwargs['data'] = self.data
        super(TopLevelMixin, self)._finalize(**kwargs)

        # Validate columns after the rest of _finalize() has run. This is last as
        # field names are not yet filled in from shortcuts until now.
        validate_columns = kwargs.get('validate_columns')
        # Only do validation if it was requested as a keyword arg to `_finalize`
        # and the Chart allows it (the chart-level validate_columns trait).
        if validate_columns and self.validate_columns:
            self._validate_columns()

def _finalize_data(self):
"""
This function is called by _finalize() below.
Expand All @@ -481,19 +514,10 @@ def _finalize_data(self):
* Whether the data attribute contains expressions, and if so it extracts
the appropriate data object and generates the appropriate transforms.
"""
# Check to see if data has too many rows.
if isinstance(self.data, pd.DataFrame):
if len(self.data) > self.max_rows:
raise MaxRowsExceeded(
"Your dataset has too many rows and could take a long "
"time to send to the frontend or to render. To override the "
"default maximum rows (%s), set the max_rows property of "
"your Chart to an integer larger than the number of rows "
"in your dataset. Alternatively you could perform aggregations "
"or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
)

# Handle expressions.
# Handle expressions. This transforms expr.DataFrame object into a set
# of transforms and an actual pd.DataFrame. After this block runs,
# self.data is either a URL or a pd.DataFrame or None.
if isinstance(self.data, expr.DataFrame):
columns = self.data._cols
calculated_cols = self.data._calculated_cols
Expand All @@ -512,6 +536,68 @@ def _finalize_data(self):
else:
self.transform_data(filter=filters)

# If self.data is a pd.DataFrame, check to see if data has too many rows.
if isinstance(self.data, pd.DataFrame):
if len(self.data) > self.max_rows:
raise MaxRowsExceeded(
"Your dataset has too many rows and could take a long "
"time to send to the frontend or to render. To override the "
"default maximum rows (%s), set the max_rows property of "
"your Chart to an integer larger than the number of rows "
"in your dataset. Alternatively you could perform aggregations "
"or other data reductions before using it with Altair" % DEFAULT_MAX_ROWS
)


def _validate_columns(self):
"""Validate the columns in the encoding, but only if if the data is a ``DataFrame``.

This has to be called after the rest of the ``_finalize()`` logic, which fills in the
shortcut field names and also processes the expressions for computed fields.

This validates:

1. That each encoding channel has a field (column name).
2. That the specified field name is present the column names of the ``DataFrame`` or
computed field from transform expressions.

This logic only runs when the dataset is a ``DataFrame``.
"""

# Only validate columns if the data is a pd.DataFrame.
if isinstance(self.data, pd.DataFrame):
# Find columns with visual encodings
encoded_columns = set()
encoding = self.encoding
if encoding is not jst.undefined:
for channel_name in encoding.channel_names:
channel = getattr(encoding, channel_name)
if channel is not jst.undefined:
field = channel.field
if field is jst.undefined:
raise FieldError(
"Missing field/column name for channel: {}".format(channel_name)
)
else:
if field != '*':
encoded_columns.add(field)
# Find columns in the data
data_columns = set(self.data.columns.values)
transform = self.transform
if transform is not jst.undefined:
calculate = transform.calculate
if calculate is not jst.undefined:
for formula in calculate:
field = formula.field
if field is not jst.undefined:
data_columns.add(field)
# Find columns in the visual encoding that are not in the data
missing_columns = encoded_columns - data_columns
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to manually remove "*" from missing columns

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... if it's present

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed now!

if missing_columns:
raise FieldError(
"Fields/columns not found in the data: {}".format(missing_columns)
)


class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
_data = None
Expand All @@ -522,11 +608,14 @@ class Chart(TopLevelMixin, schema.ExtendedUnitSpec):
transform = jst.JSONInstance(Transform,
help=schema.ExtendedUnitSpec.transform.help)
mark = schema.Mark(default_value='point', help="""The mark type.""")

max_rows = T.Int(
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

def clone(self):
"""
Expand All @@ -550,7 +639,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(Chart, self).__init__(**kwargs)
Expand Down Expand Up @@ -624,13 +713,6 @@ def encode(self, *args, **kwargs):
"""Define the encoding for the Chart."""
return update_subtraits(self, 'encoding', *args, **kwargs)

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(Chart, self)._finalize(**kwargs)

def __add__(self, other):
if isinstance(other, Chart):
lc = LayeredChart()
Expand Down Expand Up @@ -682,6 +764,10 @@ class LayeredChart(TopLevelMixin, schema.LayerSpec):
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

def clone(self):
"""
Expand All @@ -705,7 +791,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(LayeredChart, self).__init__(**kwargs)
Expand All @@ -718,13 +804,6 @@ def set_layers(self, *layers):
self.layers = list(layers)
return self

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(LayeredChart, self)._finalize(**kwargs)

def __iadd__(self, layer):
if self.layers is jst.undefined:
self.layers = [layer]
Expand All @@ -747,6 +826,10 @@ class FacetedChart(TopLevelMixin, schema.FacetSpec):
default_value=DEFAULT_MAX_ROWS,
help="Maximum number of rows in the dataset to accept."
)
validate_columns = T.Bool(
default_value=True,
help="Raise FieldError if the data is a DataFrame and there are missing columns."
)

def clone(self):
"""
Expand All @@ -770,7 +853,7 @@ def data(self, new):
else:
raise TypeError('Expected DataFrame or altair.Data, got: {0}'.format(new))

_skip_on_export = ['data', '_data', 'max_rows']
_skip_on_export = ['data', '_data', 'max_rows', 'validate_columns']

def __init__(self, data=None, **kwargs):
super(FacetedChart, self).__init__(**kwargs)
Expand All @@ -783,10 +866,3 @@ def __dir__(self):
def set_facet(self, *args, **kwargs):
"""Define the facet encoding for the Chart."""
return update_subtraits(self, 'facet', *args, **kwargs)

def _finalize(self, **kwargs):
self._finalize_data()
# data comes from wrappers, but self.data overrides this if defined
if self.data is not None:
kwargs['data'] = self.data
super(FacetedChart, self)._finalize(**kwargs)
4 changes: 2 additions & 2 deletions altair/v1/examples/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def test_json_examples_round_trip(example):
filename, json_dict = example

v = Chart.from_dict(json_dict)
v_dict = v.to_dict()
v_dict = v.to_dict(validate_columns=True)
if '$schema' not in json_dict:
v_dict.pop('$schema')
assert v_dict == json_dict

# code generation discards empty function calls, and so we
# filter these out before comparison
v2 = eval(v.to_python())
v2_dict = v2.to_dict()
v2_dict = v2.to_dict(validate_columns=True)
if '$schema' not in json_dict:
v2_dict.pop('$schema')
assert v2_dict == remove_empty_fields(json_dict)
Expand Down
37 changes: 37 additions & 0 deletions altair/v1/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,40 @@ def test_enable_mime_rendering():
enable_mime_rendering()
disable_mime_rendering()
disable_mime_rendering()


def test_validate_spec():
    """Column validation: missing or unknown fields raise FieldError on demand."""

    # Make sure we catch channels with no field specified
    c = make_chart()
    c.encode(Color())
    # Validation is opt-in: plain to_dict() does not raise.
    assert isinstance(c.to_dict(), dict)
    assert isinstance(c.to_dict(validate_columns=False), dict)
    with pytest.raises(FieldError):
        c.to_dict(validate_columns=True)
    # The chart-level flag disables validation even when requested.
    c.validate_columns = False
    assert isinstance(c.to_dict(validate_columns=True), dict)

    # Make sure we catch encoded fields not in the data
    c = make_chart()
    c.encode(x='x', y='y', color='z')
    c.encode(color='z')
    assert isinstance(c.to_dict(), dict)
    assert isinstance(c.to_dict(validate_columns=False), dict)
    with pytest.raises(FieldError):
        c.to_dict(validate_columns=True)
    c.validate_columns = False
    assert isinstance(c.to_dict(validate_columns=True), dict)

    # The '*' wildcard (e.g. count(*)) is not treated as a column name.
    c = make_chart()
    c.encode(x='x', y='count(*)')
    assert isinstance(c.to_dict(validate_columns=True), dict)

    # Make sure we can resolve computed fields
    c = make_chart()
    c.encode(x='x', y='y', color='z')
    c.encode(color='z')
    c.transform_data(
        calculate=[Formula('z', 'sin(((2*PI)*datum.x))')]
    )
    assert isinstance(c.to_dict(), dict)