Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(viz): improve dtype inference logic #12933

Merged
merged 5 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions superset-frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions superset-frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
"@superset-ui/legacy-preset-chart-big-number": "^0.17.5",
"@superset-ui/legacy-preset-chart-deckgl": "^0.4.1",
"@superset-ui/legacy-preset-chart-nvd3": "^0.17.5",
"@superset-ui/plugin-chart-echarts": "^0.17.5",
"@superset-ui/plugin-chart-table": "^0.17.5",
"@superset-ui/plugin-chart-echarts": "^0.17.6",
"@superset-ui/plugin-chart-table": "^0.17.6",
"@superset-ui/plugin-chart-word-cloud": "^0.17.5",
"@superset-ui/preset-chart-xy": "^0.17.5",
"@vx/responsive": "^0.0.195",
Expand Down
2 changes: 1 addition & 1 deletion superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_single_payload(
status = payload["status"]
if status != utils.QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["coltypes"] = utils.serialize_pandas_dtypes(df.dtypes)
payload["coltypes"] = utils.extract_dataframe_dtypes(df)
payload["data"] = self.get_data(df)
del payload["df"]

Expand Down
35 changes: 23 additions & 12 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from flask_appbuilder.security.sqla.models import Role, User
from flask_babel import gettext as __
from flask_babel.speaklater import LazyString
from pandas.api.types import infer_dtype
from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
Expand Down Expand Up @@ -1401,19 +1402,29 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
return columns


def serialize_pandas_dtypes(dtypes: List[np.dtype]) -> List[GenericDataType]:
    """Serialize pandas/numpy dtypes to JavaScript types

    Each dtype is looked up by its string form (``str(dtype)``); any name that
    is not in the mapping falls back to ``GenericDataType.STRING``.

    :param dtypes: the dtypes to convert, e.g. ``df.dtypes``
    :returns: one generic type per dtype, in the same order
    """
    mapping = {
        "object": GenericDataType.STRING,
        "category": GenericDataType.STRING,
        "datetime64[ns]": GenericDataType.TEMPORAL,
        "int64": GenericDataType.NUMERIC,
        # This key used to be spelled "in32", which never equals the string
        # form of a dtype, so int32 columns were reported as STRING.
        "int32": GenericDataType.NUMERIC,
        "float64": GenericDataType.NUMERIC,
        "float32": GenericDataType.NUMERIC,
        "bool": GenericDataType.BOOLEAN,
    }
    return [mapping.get(str(x), GenericDataType.STRING) for x in dtypes]


def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
    """Serialize pandas/numpy dtypes to generic types

    The type of every column is inferred from the column itself with
    ``pandas.api.types.infer_dtype`` and the inferred name is translated to a
    generic type. Inferred names that are not in the map below fall back to
    ``GenericDataType.STRING``.

    :param df: the dataframe whose columns should be classified
    :returns: one generic type per column, in column order
    """

    # omitting string types as those will be the default type
    inferred_type_map: Dict[str, GenericDataType] = {
        "floating": GenericDataType.NUMERIC,
        "integer": GenericDataType.NUMERIC,
        "mixed-integer-float": GenericDataType.NUMERIC,
        "decimal": GenericDataType.NUMERIC,
        "boolean": GenericDataType.BOOLEAN,
        "datetime64": GenericDataType.TEMPORAL,
        "datetime": GenericDataType.TEMPORAL,
        "date": GenericDataType.TEMPORAL,
    }

    # Walk the columns positionally instead of indexing with ``df[label]``:
    # when a label occurs more than once, ``df[label]`` returns a DataFrame
    # rather than a single Series, and the inference would no longer be made
    # on that column's values.
    return [
        inferred_type_map.get(infer_dtype(series), GenericDataType.STRING)
        for _, series in df.items()
    ]


def indexed(
Expand Down
43 changes: 39 additions & 4 deletions tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import json
import os
import re
from typing import Any, Tuple, List
from unittest.mock import Mock, patch
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

import numpy
import numpy as np
import pandas as pd
import pytest
from flask import Flask, g
import marshmallow
Expand All @@ -44,6 +46,7 @@
convert_legacy_filters_into_adhoc,
create_ssl_cert_file,
format_timedelta,
GenericDataType,
get_form_data_token,
get_iterable,
get_email_address_list,
Expand All @@ -57,6 +60,7 @@
merge_request_params,
parse_ssl_cert,
parse_js_uri_path_item,
extract_dataframe_dtypes,
split,
TimeRangeEndpoint,
validate_json,
Expand Down Expand Up @@ -113,9 +117,9 @@ def test_json_iso_dttm_ser(self):
json_iso_dttm_ser("this is not a date")

def test_base_json_conv(self):
assert isinstance(base_json_conv(numpy.bool_(1)), bool) is True
assert isinstance(base_json_conv(numpy.int64(1)), int) is True
assert isinstance(base_json_conv(numpy.array([1, 2, 3])), list) is True
assert isinstance(base_json_conv(np.bool_(1)), bool) is True
assert isinstance(base_json_conv(np.int64(1)), int) is True
assert isinstance(base_json_conv(np.array([1, 2, 3])), list) is True
assert isinstance(base_json_conv(set([1])), list) is True
assert isinstance(base_json_conv(Decimal("1.0")), float) is True
assert isinstance(base_json_conv(uuid.uuid4()), str) is True
Expand Down Expand Up @@ -1066,3 +1070,34 @@ def test_get_form_data_token(self):
assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1"
generated_token = get_form_data_token({})
assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None

def test_extract_dataframe_dtypes(self):
    """Build one sample column per case and compare the extracted generic
    types against the expected ones, in column order."""
    day = date(2021, 2, 4)
    moment = datetime(2021, 2, 4, 1, 1, 1)

    # Each entry is (column name, expected generic type, column values).
    cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
        ("dt", GenericDataType.TEMPORAL, [day, day]),
        ("dttm", GenericDataType.TEMPORAL, [moment, moment]),
        ("str", GenericDataType.STRING, ["foo", "foo"]),
        ("int", GenericDataType.NUMERIC, [1, 1]),
        ("float", GenericDataType.NUMERIC, [0.5, 0.5]),
        ("mixed-int-float", GenericDataType.NUMERIC, [0.5, 1.0]),
        ("bool", GenericDataType.BOOLEAN, [True, False]),
        ("mixed-str-int", GenericDataType.STRING, ["abc", 1.0]),
        ("obj", GenericDataType.STRING, [{"a": 1}, {"a": 1}]),
        # The same kinds of values again, this time with a leading null.
        ("dt_null", GenericDataType.TEMPORAL, [None, day]),
        ("dttm_null", GenericDataType.TEMPORAL, [None, moment]),
        ("str_null", GenericDataType.STRING, [None, "foo"]),
        ("int_null", GenericDataType.NUMERIC, [None, 1]),
        ("float_null", GenericDataType.NUMERIC, [None, 0.5]),
        ("bool_null", GenericDataType.BOOLEAN, [None, False]),
        ("obj_null", GenericDataType.STRING, [None, {"a": 1}]),
    )

    frame = pd.DataFrame(data={name: values for name, _, values in cols})
    expected = [generic_type for _, generic_type, _ in cols]
    assert extract_dataframe_dtypes(frame) == expected