fix: dataset extra import/export #17740

Merged, 2 commits, Dec 22, 2021
2 changes: 1 addition & 1 deletion superset/datasets/commands/export.py
@@ -31,7 +31,7 @@

logger = logging.getLogger(__name__)

-JSON_KEYS = {"params", "template_params"}
+JSON_KEYS = {"params", "template_params", "extra"}


class ExportDatasetsCommand(ExportModelsCommand):
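For context, JSON_KEYS is the set of model attributes that the export command parses out of their stored JSON-string form before writing the dataset YAML, so adding "extra" here is what turns the exported value from an opaque string into structured YAML. A minimal sketch of that step, with the surrounding export internals paraphrased rather than quoted from this diff:

import json

import yaml

JSON_KEYS = {"params", "template_params", "extra"}


def export_to_yaml(payload: dict) -> str:
    # Attributes stored on the model as JSON strings are parsed back into
    # dicts so the exported YAML holds structured data instead of a blob.
    for key in JSON_KEYS:
        value = payload.get(key)
        if isinstance(value, str) and value:
            try:
                payload[key] = json.loads(value)
            except json.JSONDecodeError:
                pass  # leave malformed values as-is
    return yaml.safe_dump(payload, sort_keys=False)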
5 changes: 2 additions & 3 deletions superset/datasets/commands/importers/v1/utils.py
@@ -36,7 +36,7 @@
CHUNKSIZE = 512
VARCHAR = re.compile(r"VARCHAR\((\d+)\)", re.IGNORECASE)

-JSON_KEYS = {"params", "template_params"}
+JSON_KEYS = {"params", "template_params", "extra"}


type_map = {
@@ -97,8 +97,7 @@ def import_dataset(
logger.info("Unable to encode `%s` field: %s", key, config[key])
for key in ("metrics", "columns"):
for attributes in config.get(key, []):
-# should be a dictionary, but in initial exports this was a string
-if isinstance(attributes.get("extra"), dict):
+if attributes.get("extra") is not None:
try:
attributes["extra"] = json.dumps(attributes["extra"])
except TypeError:
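On the import side, the schema now guarantees that a column's or metric's extra arrives as a dict (or None), so import_dataset can re-encode any non-None value into the JSON string the ORM columns store. A rough sketch of that normalization step, with the function name and error handling assumed rather than taken from this diff:

import json
import logging

logger = logging.getLogger(__name__)


def encode_extras(config: dict) -> dict:
    # Column and metric `extra` values are persisted as JSON strings, so
    # dicts coming out of the import schema are serialized back here.
    for key in ("metrics", "columns"):
        for attributes in config.get(key, []):
            if attributes.get("extra") is not None:
                try:
                    attributes["extra"] = json.dumps(attributes["extra"])
                except TypeError:
                    # the real handler is truncated in this diff; here we
                    # just log and keep the original value
                    logger.info("Unable to encode `extra` field: %s", attributes["extra"])
    return config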
31 changes: 27 additions & 4 deletions superset/datasets/schemas.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import re
from typing import Any, Dict

from flask_babel import lazy_gettext as _
-from marshmallow import fields, Schema, ValidationError
+from marshmallow import fields, pre_load, Schema, ValidationError
from marshmallow.validate import Length

get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
@@ -130,9 +132,19 @@ class DatasetRelatedObjectsResponse(Schema):


class ImportV1ColumnSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
if isinstance(data.get("extra"), str):
data["extra"] = json.loads(data["extra"])

return data

column_name = fields.String(required=True)
-# extra was initially exported incorrectly as a string
-extra = fields.Raw(allow_none=True)
+extra = fields.Dict(allow_none=True)
verbose_name = fields.String(allow_none=True)
is_dttm = fields.Boolean(default=False, allow_none=True)
is_active = fields.Boolean(default=True, allow_none=True)
@@ -156,6 +168,17 @@ class ImportV1MetricSchema(Schema):


class ImportV1DatasetSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
if isinstance(data.get("extra"), str):
data["extra"] = json.loads(data["extra"])

return data

table_name = fields.String(required=True)
main_dttm_col = fields.String(allow_none=True)
description = fields.String(allow_none=True)
@@ -168,7 +191,7 @@ class ImportV1DatasetSchema(Schema):
template_params = fields.Dict(allow_none=True)
filter_select_enabled = fields.Boolean()
fetch_values_predicate = fields.String(allow_none=True)
-extra = fields.String(allow_none=True)
+extra = fields.Dict(allow_none=True)
uuid = fields.UUID(required=True)
columns = fields.List(fields.Nested(ImportV1ColumnSchema))
metrics = fields.List(fields.Nested(ImportV1MetricSchema))
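To illustrate what the pre_load hook buys: a legacy export in which extra was dumped as a JSON string and a current export in which it is a dict now load to the same value. The column payload below is made up for the example:

from superset.datasets.schemas import ImportV1ColumnSchema

schema = ImportV1ColumnSchema()

legacy = schema.load({
    "column_name": "profit",
    "extra": '{"certified_by": "User"}',  # old exports: JSON string
})
current = schema.load({
    "column_name": "profit",
    "extra": {"certified_by": "User"},  # new exports: dict
})
assert legacy["extra"] == current["extra"] == {"certified_by": "User"}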
5 changes: 4 additions & 1 deletion tests/integration_tests/datasets/commands_tests.py
@@ -323,7 +323,10 @@ def test_import_v1_dataset(self, mock_g):
assert dataset.template_params == "{}"
assert dataset.filter_select_enabled
assert dataset.fetch_values_predicate is None
-assert dataset.extra == "dttm > sysdate() -10 "
+assert (
+    dataset.extra
+    == '{"certification": {"certified_by": "Data Platform Team", "details": "This table is the source of truth."}, "warning_markdown": "This is a warning."}'
+)

# user should be included as one of the owners
assert dataset.owners == [mock_g.user]
2 changes: 1 addition & 1 deletion tests/integration_tests/fixtures/importexport.py
@@ -373,7 +373,7 @@
"template_params": {},
"filter_select_enabled": True,
"fetch_values_predicate": None,
"extra": "dttm > sysdate() -10 ",
"extra": '{ "certification": { "certified_by": "Data Platform Team", "details": "This table is the source of truth." }, "warning_markdown": "This is a warning." }',
"metrics": [
{
"metric_name": "count",
3 changes: 2 additions & 1 deletion tests/unit_tests/datasets/commands/export_test.py
@@ -98,7 +98,8 @@ def test_export(app_context: None, session: Session) -> None:
answer: '42'
filter_select_enabled: 1
fetch_values_predicate: foo IN (1, 2)
-extra: '{{\"warning_markdown\": \"*WARNING*\"}}'
+extra:
+  warning_markdown: '*WARNING*'
uuid: null
metrics:
- metric_name: cnt
23 changes: 15 additions & 8 deletions tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -22,6 +22,8 @@

from sqlalchemy.orm.session import Session

from superset.datasets.schemas import ImportV1DatasetSchema


def test_import_(app_context: None, session: Session) -> None:
"""
@@ -56,7 +58,7 @@ def test_import_(app_context: None, session: Session) -> None:
"template_params": {"answer": "42",},
"filter_select_enabled": True,
"fetch_values_predicate": "foo IN (1, 2)",
"extra": '{"warning_markdown": "*WARNING*"}',
"extra": {"warning_markdown": "*WARNING*"},
"uuid": dataset_uuid,
"metrics": [
{
@@ -147,7 +149,8 @@ def test_import_column_extra_is_string(app_context: None, session: Session) -> None:
session.flush()

dataset_uuid = uuid.uuid4()
-config: Dict[str, Any] = {
+yaml_config: Dict[str, Any] = {
+"version": "1.0.0",
"table_name": "my_table",
"main_dttm_col": "ds",
"description": "This is the description",
@@ -171,20 +174,24 @@
{
"column_name": "profit",
"verbose_name": None,
"is_dttm": None,
"is_active": None,
"is_dttm": False,
"is_active": True,
"type": "INTEGER",
"groupby": None,
"filterable": None,
"groupby": False,
"filterable": False,
"expression": "revenue-expenses",
"description": None,
"python_date_format": None,
"extra": '{"certified_by": "User"}',
}
],
"database_uuid": database.uuid,
"database_id": database.id,
}

-sqla_table = import_dataset(session, config)
+schema = ImportV1DatasetSchema()
+dataset_config = schema.load(yaml_config)
+dataset_config["database_id"] = database.id
+sqla_table = import_dataset(session, dataset_config)

assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'