diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 764f3fde70580..96eecd358708b 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1025,12 +1025,15 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost( + cls, statement: str, cursor: Any, engine: Engine + ) -> Dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. :param statement: A single SQL statement :param cursor: Cursor instance + :param engine: Engine instance :return: Dictionary with different costs """ raise Exception("Database does not support cost estimation") @@ -1095,7 +1098,9 @@ def estimate_query_cost( processed_statement = cls.process_statement( statement, database, user_name ) - costs.append(cls.estimate_statement_cost(processed_statement, cursor)) + costs.append( + cls.estimate_statement_cost(processed_statement, cursor, engine) + ) return costs @classmethod @@ -1425,6 +1430,33 @@ def cancel_query( # pylint: disable=unused-argument def parse_sql(cls, sql: str) -> List[str]: return [str(s).strip(" ;") for s in sqlparse.parse(sql)] + @classmethod + def _humanize(cls, value: Any, suffix: str, category: Optional[str] = None) -> str: + try: + value = int(value) + except ValueError: + return str(value) + if category not in ["bytes", None]: + raise Exception(f"Unsupported value category: {category}") + + to_next_prefix = 1000 + prefixes = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"] + suffixes = [p + suffix for p in prefixes] + + if category == "bytes": + to_next_prefix = 1024 + suffixes = ["B" if p == "" else p + "iB" for p in prefixes] + + suffix = suffixes.pop(0) + while value >= to_next_prefix and suffixes: + suffix = suffixes.pop(0) + value //= to_next_prefix + + if not suffix.startswith(" "): + suffix = " " + suffix + + return "{}{}".format(value, suffix).strip() + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 30e04c4f2fe9b..fb9e04474d212 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -185,6 +185,47 @@ class BigQueryEngineSpec(BaseEngineSpec): ), } + @classmethod + def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + return True + + @classmethod + def estimate_statement_cost( + cls, statement: str, cursor: Any, engine: Engine + ) -> Dict[str, Any]: + # pylint: disable=import-outside-toplevel + from google.cloud import bigquery + from google.oauth2 import service_account + + creds = engine.dialect.credentials_info + credentials = service_account.Credentials.from_service_account_info(creds) + client = bigquery.Client(credentials=credentials) + dry_run_result = client.query( + statement, bigquery.job.QueryJobConfig(dry_run=True) + ) + + return { + "Total bytes processed": dry_run_result.total_bytes_processed, + } + + @classmethod + def query_cost_formatter( + cls, raw_cost: List[Dict[str, Any]] + ) -> List[Dict[str, str]]: + cost = [] + columns = [ + ("Total bytes processed", "", "bytes"), + ] + + for row in raw_cost: + statement_cost = {} + for key, suffix, category in columns: + if key in row: + statement_cost[key] = cls._humanize(row[key], suffix, category) + cost.append(statement_cost) + + return cost + @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None @@ -316,16 +357,9 @@ def df_to_sql( :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method """ - try: - # pylint: disable=import-outside-toplevel - import pandas_gbq - from google.oauth2 import service_account - except ImportError as ex: - raise Exception( - "Could not import libraries `pandas_gbq` or `google.oauth2`, which are " - "required to be installed in your environment in order " - "to upload data to BigQuery" - ) from ex + # pylint: disable=import-outside-toplevel + import pandas_gbq + from google.oauth2 import service_account if not table.schema: raise Exception("The table schema must be defined") diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index f6c6888ee97bb..bbf58a364b270 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -23,6 +23,7 @@ from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector +from sqlalchemy.engine.base import Engine from sqlalchemy.types import String from superset.db_engine_specs.base import ( @@ -197,7 +198,9 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost( + cls, statement: str, cursor: Any, engine: Engine + ) -> Dict[str, Any]: sql = f"EXPLAIN {statement}" cursor.execute(sql) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 376151587cdee..c49b506f446b0 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -637,7 +637,9 @@ def select_star( # pylint: disable=too-many-arguments ) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost( + cls, statement: str, cursor: Any, engine: Engine + ) -> Dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. @@ -675,35 +677,22 @@ def query_cost_formatter( :return: Human readable cost estimate """ - def humanize(value: Any, suffix: str) -> str: - try: - value = int(value) - except ValueError: - return str(value) - - prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] - prefix = "" - to_next_prefix = 1000 - while value > to_next_prefix and prefixes: - prefix = prefixes.pop(0) - value //= to_next_prefix - - return f"{value} {prefix}{suffix}" - cost = [] columns = [ - ("outputRowCount", "Output count", " rows"), - ("outputSizeInBytes", "Output size", "B"), - ("cpuCost", "CPU cost", ""), - ("maxMemory", "Max memory", "B"), - ("networkCost", "Network cost", ""), + ("outputRowCount", "Output count", " rows", None), + ("outputSizeInBytes", "Output size", "", "bytes"), + ("cpuCost", "CPU cost", "", None), + ("maxMemory", "Max memory", "", "bytes"), + ("networkCost", "Network cost", "", None), ] for row in raw_cost: estimate: Dict[str, float] = row.get("estimate", {}) statement_cost = {} - for key, label, suffix in columns: + for key, label, suffix, category in columns: if key in estimate: - statement_cost[label] = humanize(estimate[key], suffix).strip() + statement_cost[label] = cls._humanize( + estimate[key], suffix, category + ).strip() cost.append(statement_cost) return cost diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 4e5f153ad2ab2..b04e31e08472a 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -21,6 +21,7 @@ import simplejson as json from flask import current_app +from sqlalchemy.engine.base import Engine from sqlalchemy.engine.url import make_url, URL from superset.db_engine_specs.base import BaseEngineSpec @@ -118,7 +119,9 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost( + cls, statement: str, cursor: Any, engine: Engine + ) -> Dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. @@ -156,35 +159,22 @@ def query_cost_formatter( :return: Human readable cost estimate """ - def humanize(value: Any, suffix: str) -> str: - try: - value = int(value) - except ValueError: - return str(value) - - prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] - prefix = "" - to_next_prefix = 1000 - while value > to_next_prefix and prefixes: - prefix = prefixes.pop(0) - value //= to_next_prefix - - return f"{value} {prefix}{suffix}" - cost = [] columns = [ - ("outputRowCount", "Output count", " rows"), - ("outputSizeInBytes", "Output size", "B"), - ("cpuCost", "CPU cost", ""), - ("maxMemory", "Max memory", "B"), - ("networkCost", "Network cost", ""), + ("outputRowCount", "Output count", " rows", None), + ("outputSizeInBytes", "Output size", "", "bytes"), + ("cpuCost", "CPU cost", "", None), + ("maxMemory", "Max memory", "", "bytes"), + ("networkCost", "Network cost", "", None), ] for row in raw_cost: estimate: Dict[str, float] = row.get("estimate", {}) statement_cost = {} - for key, label, suffix in columns: + for key, label, suffix, category in columns: if key in estimate: - statement_cost[label] = humanize(estimate[key], suffix).strip() + statement_cost[label] = cls._humanize( + estimate[key], suffix, category + ).strip() cost.append(statement_cost) return cost diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index b7405092c5446..ae6d48a644541 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -366,3 +366,74 @@ def test_calculated_column_in_order_by(self): } sql = table.get_query_str(query_obj) assert "ORDER BY gender_cc ASC" in sql + + @mock.patch("google.cloud.bigquery.Client") + @mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + mock.Mock(), + ) + def test_estimate_statement_cost_select_star(self, mocked_client_class): + mocked_client = mocked_client_class.return_value + mocked_client.query.return_value = mock.Mock() + mocked_client.query.return_value.total_bytes_processed = 123 + cursor = mock.Mock() + engine = mock.Mock() + sql = "SELECT * FROM `some-project.database.table`" + results = BigQueryEngineSpec.estimate_statement_cost(sql, cursor, engine) + mocked_client.query.assert_called_once() + args = mocked_client.query.call_args.args + self.assertEqual(args[0], sql) + self.assertEqual(args[1].dry_run, True) + self.assertEqual( + results, {"Total bytes processed": 123}, + ) + + @mock.patch("google.cloud.bigquery.Client") + @mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + mock.Mock(), + ) + def test_estimate_statement_invalid_syntax(self, mocked_client_class): + from google.api_core.exceptions import BadRequest + + cursor = mock.Mock() + mocked_client = mocked_client_class.return_value + mocked_client.query.side_effect = BadRequest( + """ + POST https://bigquery.googleapis.com/bigquery/v2/projects/xxx/jobs? + prettyPrint=false: Table name "birth_names" missing dataset while no def + ault dataset is set in the request. + + (job ID: xxx) + + -----Query Job SQL Follows----- + + | . | . | + 1:DROP TABLE birth_names + | . | . | + """ + ) + engine = mock.Mock() + sql = "DROP TABLE birth_names" + with self.assertRaises(BadRequest): + BigQueryEngineSpec.estimate_statement_cost(sql, cursor, engine) + + def test_query_cost_formatter_example_costs(self): + raw_cost = [ + {"Total bytes processed": 123}, + {"Total bytes processed": 1024}, + {"Total bytes processed": 1024 ** 2 + 1024 * 512,}, + {"Total bytes processed": 1024 ** 3 * 100,}, + {"Total bytes processed": 1024 ** 4 * 1000,}, + ] + result = BigQueryEngineSpec.query_cost_formatter(raw_cost) + self.assertEqual( + result, + [ + {"Total bytes processed": "123 B"}, + {"Total bytes processed": "1 KiB"}, + {"Total bytes processed": "1 MiB",}, + {"Total bytes processed": "100 GiB",}, + {"Total bytes processed": "1000 TiB",}, + ], + ) diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index dcf5310fecac5..c600d2b435eab 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -176,8 +176,9 @@ def test_estimate_statement_cost_select_star(self): cursor.fetchone.return_value = ( "Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)", ) + engine = mock.Mock() sql = "SELECT * FROM birth_names" - results = PostgresEngineSpec.estimate_statement_cost(sql, cursor) + results = PostgresEngineSpec.estimate_statement_cost(sql, cursor, engine) self.assertEqual( results, {"Start-up cost": 0.00, "Total cost": 1537.91,}, ) @@ -196,9 +197,10 @@ def test_estimate_statement_invalid_syntax(self): ^ """ ) + engine = mock.Mock() sql = "DROP TABLE birth_names" with self.assertRaises(errors.SyntaxError): - PostgresEngineSpec.estimate_statement_cost(sql, cursor) + PostgresEngineSpec.estimate_statement_cost(sql, cursor, engine) def test_query_cost_formatter_example_costs(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 5833c6bdcbfcb..3579244215b7b 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -524,7 +524,7 @@ def test_query_cost_formatter(self): expected = [ { "Output count": "904 M rows", - "Output size": "354 GB", + "Output size": "329 GiB", "CPU cost": "354 G", "Max memory": "0 B", "Network cost": "354 G", @@ -795,17 +795,19 @@ def test_estimate_statement_cost(self): mock_cursor.fetchone.return_value = [ '{"a": "b"}', ] + mock_engine = mock.Mock() result = PrestoEngineSpec.estimate_statement_cost( - "SELECT * FROM brth_names", mock_cursor + "SELECT * FROM brth_names", mock_cursor, mock_engine ) assert result == estimate_json def test_estimate_statement_cost_invalid_syntax(self): mock_cursor = mock.MagicMock() mock_cursor.execute.side_effect = Exception() + mock_engine = mock.Mock() with self.assertRaises(Exception): PrestoEngineSpec.estimate_statement_cost( - "DROP TABLE brth_names", mock_cursor + "DROP TABLE brth_names", mock_cursor, mock_engine ) def test_get_all_datasource_names(self): diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 4dc27c0928f99..93fbaa2eb1512 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -17,6 +17,7 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access from textwrap import dedent +from typing import Any, Optional import pytest from flask.ctx import AppContext @@ -99,3 +100,41 @@ def test_cte_query_parsing( actual = BaseEngineSpec.get_cte_query(original) assert actual == expected + + +@pytest.mark.parametrize( + "value,suffix,category,expected", + [ + ("str", "", None, "str"), + (0, "", None, "0"), + (100, "", None, "100"), + (1000, "", None, "1 K"), + (10000, "", None, "10 K"), + (123, " rows", None, "123 rows"), + (1234, " rows", None, "1 K rows"), + (1999, " rows", None, "1 K rows"), + (2000, " rows", None, "2 K rows"), + (123, "", "bytes", "123 B"), + (1024, "", "bytes", "1 KiB"), + (1024 ** 2, "", "bytes", "1 MiB"), + (1000 ** 2, "J", None, "1 MJ"), + (1024 ** 3, "", "bytes", "1 GiB"), + (1000 ** 3, "W", None, "1 GW"), + (1024 ** 8, "", "bytes", "1 YiB"), + (1000 ** 8, "m", None, "1 Ym"), + # Yottabyte is the largest unit, but larger values can be handled + (1024 ** 9, "", "bytes", "1024 YiB"), + (1000 ** 9, "m", None, "1000 Ym"), + ], +) +def test_humanize( + app_context: AppContext, + value: Any, + suffix: str, + category: Optional[str], + expected: str, +) -> None: + from superset.db_engine_specs.base import BaseEngineSpec + + actual = BaseEngineSpec._humanize(value, suffix, category) + assert actual == expected