feat(bigquery): add support for query cost estimate #18694

Closed
wants to merge 8 commits into from
Changes from 4 commits
9 changes: 7 additions & 2 deletions superset/db_engine_specs/base.py
@@ -1025,12 +1025,15 @@ def select_star(  # pylint: disable=too-many-arguments,too-many-locals
         return sql
 
     @classmethod
-    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+    def estimate_statement_cost(
+        cls, statement: str, cursor: Any, engine: Engine
+    ) -> Dict[str, Any]:
         """
         Generate a SQL query that estimates the cost of a given statement.
 
         :param statement: A single SQL statement
         :param cursor: Cursor instance
+        :param engine: Engine instance
         :return: Dictionary with different costs
         """
         raise Exception("Database does not support cost estimation")
@@ -1095,7 +1098,9 @@ def estimate_query_cost(
                 processed_statement = cls.process_statement(
                     statement, database, user_name
                 )
-                costs.append(cls.estimate_statement_cost(processed_statement, cursor))
+                costs.append(
+                    cls.estimate_statement_cost(processed_statement, cursor, engine)
+                )
         return costs
 
     @classmethod
54 changes: 54 additions & 0 deletions superset/db_engine_specs/bigquery.py
@@ -185,6 +185,60 @@ class BigQueryEngineSpec(BaseEngineSpec):
         ),
     }
 
+    @classmethod
+    def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+        return True
+
+    @classmethod
+    def estimate_statement_cost(
+        cls, statement: str, cursor: Any, engine: Engine
Author

The only way to estimate the cost in advance in BigQuery is to run the query with dry_run. Since that is not possible with only a cursor, I added engine as an argument.

An alternative to handling bigquery.Client directly would be to configure SQLAlchemy to pass the dry_run parameter when creating the connection, but that seems more complicated.

https://github.com/googleapis/python-bigquery-sqlalchemy#connection-string-parameters

+    ) -> Dict[str, Any]:
+        try:
+            # pylint: disable=import-outside-toplevel
+            from google.cloud import bigquery
+            from google.oauth2 import service_account
+        except ImportError as ex:
+            raise Exception(
+                "Could not import libraries `google.cloud` or `google.oauth2`, "
+                "which are required to be installed in your environment in order "
+                "to estimate cost"
Member

Curious, wouldn't these necessarily be installed if the user has a BigQuery database connected?

Author

That's right, we can simply use a plain import here. I'll fix this!

+            ) from ex
+
+        creds = engine.dialect.credentials_info
+        credentials = service_account.Credentials.from_service_account_info(creds)
+        client = bigquery.Client(credentials=credentials)
+        dry_run_result = client.query(
+            statement, bigquery.job.QueryJobConfig(dry_run=True)
+        )
+
+        return {
+            "Total bytes processed": dry_run_result.total_bytes_processed,
+        }
+
+    @classmethod
+    def query_cost_formatter(
+        cls, raw_cost: List[Dict[str, Any]]
+    ) -> List[Dict[str, str]]:
+        def format_bytes_str(raw_bytes: int) -> str:
+            if not isinstance(raw_bytes, int):
+                return str(raw_bytes)
+            units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
+            index = 0
+            bytes = float(raw_bytes)
+            while bytes >= 1024 and index < len(units) - 1:
+                bytes /= 1024
+                index += 1
+
+            return "{:.1f}".format(bytes) + f" {units[index]}"
+
+        return [
+            {
+                k: format_bytes_str(v) if k == "Total bytes processed" else str(v)
+                for k, v in row.items()
+            }
+            for row in raw_cost
+        ]
Member

It seems this logic overlaps with the humanize functions in the query_cost_formatter methods in TrinoEngineSpec and PrestoEngineSpec. I wonder if we should move humanize to BaseEngineSpec so we could remove the duplication?

Author

Thanks for the review!

> so we could remove the duplication?

Going DRY may well be better, but there are a few things to consider.

The intent of this implementation was to be consistent with the official UI provided by BigQuery, both in using KiB notation and in showing one decimal place.
[Image: Screenshot 2022-02-14 20 32 57]

In particular, the current Presto and Trino implementations divide by 1000 instead of 1024, which is a problem here.
It would produce a small difference between the estimated bytes shown in BigQuery and the estimated bytes shown in Superset.
I would like to avoid using the current humanize implementation as is, because that difference would confuse users.
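
To make the size of that difference concrete, here is a small worked example. It is only an illustration: the two helper names are hypothetical and are not part of Superset or of the BigQuery client.

```python
def as_decimal_gb(raw_bytes: int) -> str:
    """Format a byte count in gigabytes, dividing by 1000 at each step."""
    return f"{raw_bytes / 1000 ** 3:.1f} GB"


def as_binary_gib(raw_bytes: int) -> str:
    """Format a byte count in gibibytes, dividing by 1024 at each step."""
    return f"{raw_bytes / 1024 ** 3:.1f} GiB"


raw_bytes = 5 * 1024 ** 3  # a dry run that reports exactly 5 GiB

print(as_decimal_gb(raw_bytes))  # 5.4 GB
print(as_binary_gib(raw_bytes))  # 5.0 GiB
```

The same estimate would read as 5.4 in one place and 5.0 in the other, which is the mismatch described above.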

Author

I think there are several possible approaches:

a. Based on the humanize implementation, provide a method that takes prefixes and to_next_prefixes as parameters

This allows a common implementation with the same results, but is somewhat complex to implement.

b. Provide two methods, humanize_number and humanize_bytes

The byte count display in Trino and Presto would change slightly.

c. Keep separate implementations (or share only between Trino and Presto)

Which do you think is best?
Any of them is fine with me, and I'm happy to try it.
However, it might be better to work on this in a separate PR.
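
As a rough sketch of what option (b) might look like: `_scale` is a hypothetical shared helper, and the prefix lists and one-decimal formatting are assumptions for illustration, not the existing Presto/Trino behavior.

```python
from typing import Sequence


def _scale(value: float, base: int, units: Sequence[str]) -> str:
    """Divide by `base` until the value fits the largest applicable unit."""
    index = 0
    while value >= base and index < len(units) - 1:
        value /= base
        index += 1
    return f"{value:.1f} {units[index]}".rstrip()


def humanize_number(value: float) -> str:
    """Row counts and similar quantities: decimal prefixes, base 1000."""
    return _scale(float(value), 1000, ["", "K", "M", "G", "T", "P"])


def humanize_bytes(raw_bytes: int) -> str:
    """Byte counts: binary prefixes, base 1024, as in the BigQuery console."""
    return _scale(float(raw_bytes), 1024, ["B", "KiB", "MiB", "GiB", "TiB", "PiB"])


print(humanize_number(1500))  # 1.5 K
print(humanize_bytes(1536))   # 1.5 KiB
```

Passing the base and unit list into one helper is essentially option (a); the two thin wrappers on top give option (b) without duplicating the loop.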

Member

Good point, being consistent with the BQ console definitely makes sense. I remember a discussion about this in the original PR where 1024 vs 1000 was debated: #8172 (comment). While being consistent with the BQ console is good, it would feel funny to have different units for BQ vs Presto/Trino. @betodealmeida thoughts?

Author

Does anyone have any suggestions?
I think the ability to estimate bytes is what matters and the format is relatively unimportant, so I don't feel strongly about how to handle it.
I think it would be better to refactor Trino and Presto to KiB notation, but if there are reasons not to do so, I'll switch the BigQuery implementation to KB notation.

Member

I think we can change the humanize function in Presto (it's also duplicated in Trino) to return bytes in 1024 increments. The only reason I used 1000 is that it's also applied to row counts and other parameters, so this way it's consistent. Ideally we'd have a single function used by Presto, Trino, BigQuery and other engine specs.

(Note that it's also possible to overwrite the formatter function using QUERY_COST_FORMATTERS_BY_ENGINE in the config. We used that at Lyft to show the query cost in dollars, estimated run time, and carbon footprint.)
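
For illustration, such an override might look roughly like the sketch below in `superset_config.py`. The price constant is a made-up placeholder, the function name and the "Estimated cost (USD)" column are hypothetical, and the `"bigquery"` key assumes the formatter is looked up by the engine spec's engine name.

```python
from typing import Any, Dict, List

# Hypothetical rate, for illustration only; substitute your own pricing.
PRICE_PER_TIB_USD = 5.0


def bigquery_cost_formatter(raw_cost: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Stringify each row and add an estimated dollar cost where possible."""
    formatted = []
    for row in raw_cost:
        out = {key: str(value) for key, value in row.items()}
        total = row.get("Total bytes processed")
        if isinstance(total, int):
            cost = total / 1024 ** 4 * PRICE_PER_TIB_USD
            out["Estimated cost (USD)"] = f"{cost:.4f}"
        formatted.append(out)
    return formatted


# Assumed to be keyed by engine name.
QUERY_COST_FORMATTERS_BY_ENGINE = {"bigquery": bigquery_cost_formatter}
```

With this in place, the rows returned by `estimate_statement_cost` would be passed through the custom function instead of the spec's own `query_cost_formatter`.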

+
     @classmethod
     def convert_dttm(
         cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
5 changes: 4 additions & 1 deletion superset/db_engine_specs/postgres.py
@@ -23,6 +23,7 @@
 from flask_babel import gettext as __
 from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
 from sqlalchemy.dialects.postgresql.base import PGInspector
+from sqlalchemy.engine.base import Engine
 from sqlalchemy.types import String
 
 from superset.db_engine_specs.base import (
@@ -197,7 +198,9 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
         return True
 
     @classmethod
-    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+    def estimate_statement_cost(
+        cls, statement: str, cursor: Any, engine: Engine
+    ) -> Dict[str, Any]:
         sql = f"EXPLAIN {statement}"
         cursor.execute(sql)
 
4 changes: 3 additions & 1 deletion superset/db_engine_specs/presto.py
@@ -637,7 +637,9 @@ def select_star(  # pylint: disable=too-many-arguments
         )
 
     @classmethod
-    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+    def estimate_statement_cost(
+        cls, statement: str, cursor: Any, engine: Engine
+    ) -> Dict[str, Any]:
         """
         Run a SQL query that estimates the cost of a given statement.
 
5 changes: 4 additions & 1 deletion superset/db_engine_specs/trino.py
@@ -21,6 +21,7 @@
 
 import simplejson as json
 from flask import current_app
+from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.url import make_url, URL
 
 from superset.db_engine_specs.base import BaseEngineSpec
@@ -118,7 +119,9 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
         return True
 
     @classmethod
-    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+    def estimate_statement_cost(
+        cls, statement: str, cursor: Any, engine: Engine
+    ) -> Dict[str, Any]:
         """
         Run a SQL query that estimates the cost of a given statement.
 
76 changes: 76 additions & 0 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -366,3 +366,79 @@ def test_calculated_column_in_order_by(self):
         }
         sql = table.get_query_str(query_obj)
         assert "ORDER BY gender_cc ASC" in sql
+
+    @mock.patch("google.cloud.bigquery.Client")
+    @mock.patch(
+        "google.oauth2.service_account.Credentials.from_service_account_info",
+        mock.Mock(),
+    )
+    def test_estimate_statement_cost_select_star(self, mocked_client_class):
+        mocked_client = mocked_client_class.return_value
+        mocked_client.query.return_value = mock.Mock()
+        mocked_client.query.return_value.total_bytes_processed = 123
+        cursor = mock.Mock()
+        engine = mock.Mock()
+        sql = "SELECT * FROM `some-project.database.table`"
+        results = BigQueryEngineSpec.estimate_statement_cost(sql, cursor, engine)
+        mocked_client.query.assert_called_once()
+        args = mocked_client.query.call_args.args
+        self.assertEqual(args[0], sql)
+        self.assertEqual(args[1].dry_run, True)
+        self.assertEqual(
+            results, {"Total bytes processed": 123},
+        )
+
+    @mock.patch("google.cloud.bigquery.Client")
+    @mock.patch(
+        "google.oauth2.service_account.Credentials.from_service_account_info",
+        mock.Mock(),
+    )
+    def test_estimate_statement_invalid_syntax(self, mocked_client_class):
+        from google.api_core.exceptions import BadRequest
+
+        cursor = mock.Mock()
+        mocked_client = mocked_client_class.return_value
+        mocked_client.query.side_effect = BadRequest(
+            """
+            POST https://bigquery.googleapis.com/bigquery/v2/projects/xxx/jobs?
+            prettyPrint=false: Table name "birth_names" missing dataset while no def
+            ault dataset is set in the request.
+
+            (job ID: xxx)
+
+            -----Query Job SQL Follows-----
+
+            | . | . |
+            1:DROP TABLE birth_names
+            | . | . |
+            """
+        )
+        engine = mock.Mock()
+        sql = "DROP TABLE birth_names"
+        with self.assertRaises(BadRequest):
+            BigQueryEngineSpec.estimate_statement_cost(sql, cursor, engine)
+
+    def test_query_cost_formatter_example_costs(self):
+        raw_cost = [
+            {"Total bytes processed": 123, "Some other column": 123,},
+            {"Total bytes processed": 1024, "Some other column": "abcde",},
+            {"Total bytes processed": 1024 * 1024 + 1024 * 512,},
+            {"Total bytes processed": 1024 ** 3,},
+            {"Total bytes processed": 1024 ** 4,},
+            {"Total bytes processed": 1024 ** 5,},
+            {"Total bytes processed": 1024 ** 6,},
+        ]
+        result = BigQueryEngineSpec.query_cost_formatter(raw_cost)
+        self.assertEqual(
+            result,
+            [
+                {"Total bytes processed": "123.0 B", "Some other column": "123",},
+                {"Total bytes processed": "1.0 KiB", "Some other column": "abcde",},
+                {"Total bytes processed": "1.5 MiB",},
+                {"Total bytes processed": "1.0 GiB",},
+                {"Total bytes processed": "1.0 TiB",},
+                {"Total bytes processed": "1.0 PiB",},
+                # Petabyte is the largest unit, but larger values can be handled
+                {"Total bytes processed": "1024.0 PiB",},
+            ],
+        )
6 changes: 4 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -176,8 +176,9 @@ def test_estimate_statement_cost_select_star(self):
         cursor.fetchone.return_value = (
             "Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
         )
+        engine = mock.Mock()
         sql = "SELECT * FROM birth_names"
-        results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
+        results = PostgresEngineSpec.estimate_statement_cost(sql, cursor, engine)
         self.assertEqual(
             results, {"Start-up cost": 0.00, "Total cost": 1537.91,},
         )
@@ -196,9 +197,10 @@ def test_estimate_statement_invalid_syntax(self):
             ^
             """
         )
+        engine = mock.Mock()
         sql = "DROP TABLE birth_names"
         with self.assertRaises(errors.SyntaxError):
-            PostgresEngineSpec.estimate_statement_cost(sql, cursor)
+            PostgresEngineSpec.estimate_statement_cost(sql, cursor, engine)
 
     def test_query_cost_formatter_example_costs(self):
         """
6 changes: 4 additions & 2 deletions tests/integration_tests/db_engine_specs/presto_tests.py
@@ -795,17 +795,19 @@ def test_estimate_statement_cost(self):
         mock_cursor.fetchone.return_value = [
             '{"a": "b"}',
         ]
+        mock_engine = mock.Mock()
         result = PrestoEngineSpec.estimate_statement_cost(
-            "SELECT * FROM brth_names", mock_cursor
+            "SELECT * FROM brth_names", mock_cursor, mock_engine
         )
         assert result == estimate_json
 
     def test_estimate_statement_cost_invalid_syntax(self):
         mock_cursor = mock.MagicMock()
         mock_cursor.execute.side_effect = Exception()
+        mock_engine = mock.Mock()
         with self.assertRaises(Exception):
             PrestoEngineSpec.estimate_statement_cost(
-                "DROP TABLE brth_names", mock_cursor
+                "DROP TABLE brth_names", mock_cursor, mock_engine
             )
 
     def test_get_all_datasource_names(self):