Fix query constructor to take dictionaries or lists of dependencies (googledatalab#270)

This PR fixes the Query constructor so that dependencies (subqueries, UDFs, or external data sources) can be passed as either a list or a dictionary. If a list is given, it should be a list of strings naming objects defined in the environment object (the user can build their own dictionary or pass notebook_environment() as the parameter). If a dictionary is given, it should map dependency names to the corresponding objects.

All of the examples below now work:

q1 = bq.Query('sql')
q2 = bq.Query('sql', subqueries=['q1'], env={'q1': q1})
q2 = bq.Query('sql', subqueries=['q1'], env=utils.commands.notebook_environment())
q2 = bq.Query('sql', subqueries={'q1': q1})
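
The same two forms apply to UDFs and external data sources. A hedged sketch (my_udf and my_ds are hypothetical objects assumed to already exist in the notebook; they are not part of this commit):

q3 = bq.Query('sql', udfs=['my_udf'], env=utils.commands.notebook_environment())
q3 = bq.Query('sql', udfs={'my_udf': my_udf})
q4 = bq.Query('sql', data_sources={'my_source': my_ds})  # only one temporary external data source is allowed per query
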
yebrahim authored Mar 4, 2017
1 parent 5415309 commit 2c9d69c
Showing 4 changed files with 67 additions and 43 deletions.
7 changes: 5 additions & 2 deletions google/datalab/bigquery/_api.py
@@ -150,7 +150,7 @@ def jobs_insert_query(self, sql, table_name=None, append=False,
priority, more expensive).
allow_large_results: whether to allow large results (slower with some restrictions but
can handle big jobs).
table_definitions: a dictionary of JSON external table definitions for any external tables
table_definitions: a dictionary of ExternalDataSource names and objects for any external tables
referenced in the query.
query_params: a dictionary containing query parameter types and values, passed to BigQuery.
Returns:
@@ -177,7 +177,10 @@ def jobs_insert_query(self, sql, table_name=None, append=False,
query_config = data['configuration']['query']

if table_definitions:
query_config['tableDefinitions'] = table_definitions
expanded_definitions = {}
for td in table_definitions:
expanded_definitions[td] = table_definitions[td]._to_query_json()
query_config['tableDefinitions'] = expanded_definitions

if table_name:
query_config['destinationTable'] = {
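
For reference, the new expansion loop above is equivalent to this one-line sketch (using only names that appear in the diff):

expanded_definitions = {name: table_definitions[name]._to_query_json() for name in table_definitions}
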
83 changes: 51 additions & 32 deletions google/datalab/bigquery/_query.py
@@ -41,39 +41,48 @@ def __init__(self, sql, env=None, udfs=None, data_sources=None, subqueries=None)
sql: the BigQuery SQL query string to execute
env: a dictionary containing objects from the query execution context, used to get references
to UDFs, subqueries, and external data sources referenced by the query
udfs: list of UDFs referenced in the SQL.
data_sources: list of external data sources referenced in the SQL.
subqueries: list of subqueries referenced in the SQL
udfs: list of UDF names referenced in the SQL, or a dictionary of names and UDF objects
data_sources: list of external data source names referenced in the SQL, or a dictionary of
names and data source objects
subqueries: list of subquery names referenced in the SQL, or a dictionary of names and
Query objects
Raises:
Exception if expansion of any variables failed.
"""
self._sql = sql
self._udfs = udfs
self._subqueries = subqueries
self._env = env
if self._env is None:
self._env = google.datalab.utils.commands.notebook_environment()
self._udfs = {}
self._subqueries = {}
self._data_sources = {}
self._env = env or {}

# Validate that the objects in the given list or dictionary are of the correct type,
# and add them to the target dictionary
def _expand_objects(obj_container, obj_type, target_dict):
for item in obj_container:
# for a list of objects, we should find these objects in the given environment
if isinstance(obj_container, list):
value = self._env.get(item)
if value is None:
raise Exception('Cannot find object %s' % item)

# for a dictionary of objects, each pair must be a string and an object of the expected type
elif isinstance(obj_container, dict):
value = obj_container[item]
if not isinstance(value, obj_type):
raise Exception('Expected type: %s, found: %s.' % (obj_type, type(value)))

def _validate_object(obj, obj_type):
item = self._env.get(obj)
if item is None:
raise Exception('Cannot find object %s.' % obj)
if not isinstance(item, obj_type):
raise Exception('Expected type: %s, found: %s.' % (obj_type, type(item)))

# Validate subqueries, UDFs, and datasources when adding them to query
if self._subqueries:
for subquery in self._subqueries:
_validate_object(subquery, Query)
if self._udfs:
for udf in self._udfs:
_validate_object(udf, _udf.UDF)
else:
raise Exception('Unexpected container for type %s. Expected a list or dictionary' % obj_type)

target_dict[item] = value

if subqueries:
_expand_objects(subqueries, Query, self._subqueries)
if udfs:
_expand_objects(udfs, _udf.UDF, self._udfs)
if data_sources:
for ds in data_sources:
_validate_object(ds, _external_data_source.ExternalDataSource)
self._data_sources[ds] = self._env[ds]._to_query_json()
_expand_objects(data_sources, _external_data_source.ExternalDataSource, self._data_sources)

if len(self._data_sources) > 1:
raise Exception('Only one temporary external datasource is supported in queries.')
@@ -115,8 +124,8 @@ def _expanded_sql(self, sampling=None):
The expanded SQL string of this object
"""

udfs = set()
subqueries = set()
udfs = {}
subqueries = {}
expanded_sql = ''

def _recurse_subqueries(query):
@@ -125,21 +134,21 @@ def _recurse_subqueries(query):
if query._subqueries:
subqueries.update(query._subqueries)
if query._udfs:
udfs.update(set(query._udfs))
udfs.update(query._udfs)
if query._subqueries:
for subquery in query._subqueries:
_recurse_subqueries(self._env[subquery])
_recurse_subqueries(query._subqueries[subquery])

subqueries_sql = udfs_sql = ''
_recurse_subqueries(self)

if udfs:
expanded_sql += '\n'.join([self._env[udf]._expanded_sql() for udf in udfs])
expanded_sql += '\n'.join([udfs[udf]._expanded_sql() for udf in udfs])
expanded_sql += '\n'

if subqueries:
expanded_sql += 'WITH ' + \
',\n'.join(['%s AS (%s)' % (sq, self._env[sq]._sql) for sq in subqueries])
',\n'.join(['%s AS (%s)' % (sq, subqueries[sq]._sql) for sq in subqueries])
expanded_sql += '\n'

expanded_sql += sampling(self._sql) if sampling else self._sql
@@ -169,9 +178,19 @@ def sql(self):

@property
def udfs(self):
""" Get the UDFs referenced by the query."""
""" Get a dictionary of UDFs referenced by the query."""
return self._udfs

@property
def subqueries(self):
""" Get a dictionary of subqueries referenced by the query."""
return self._subqueries

@property
def data_sources(self):
""" Get a dictionary of external data sources referenced by the query."""
return self._data_sources

def dry_run(self, context=None, query_params=None):
"""Dry run a query, to check the validity of the query and return some useful statistics.
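
To illustrate how the dictionary-based subqueries compose through _expanded_sql, a hedged sketch with made-up names (not part of the commit); nested dependencies are collected recursively into a single WITH clause:

q1 = bq.Query('SELECT * FROM logs')
q2 = bq.Query('SELECT field FROM q1', subqueries={'q1': q1})
q3 = bq.Query('SELECT COUNT(*) FROM q2', subqueries={'q2': q2})
# q3.sql now expands to a single query whose WITH clause defines both q1 and q2, roughly:
#   WITH q1 AS (SELECT * FROM logs),
#   q2 AS (SELECT field FROM q1)
#   SELECT COUNT(*) FROM q2
# (exact ordering and formatting of the WITH entries may differ)
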
12 changes: 7 additions & 5 deletions tests/bigquery/query_tests.py
@@ -27,18 +27,20 @@ def test_parameter_validation(self):
sql = 'SELECT * FROM table'
with self.assertRaises(Exception) as error:
q = TestCases._create_query(sql, subqueries=['subquery'])
env = {'subquery': TestCases._create_query()}
sq = TestCases._create_query()
env = {'subquery': sq}
q = TestCases._create_query(sql, env=env, subqueries=['subquery'])
self.assertIsNotNone(q)
self.assertEqual(q._subqueries, ['subquery'])
self.assertEqual(q._subqueries, {'subquery': sq})
self.assertEqual(q._sql, sql)

with self.assertRaises(Exception) as error:
q = TestCases._create_query(sql, udfs=['udf'])
env = {'udf': TestCases._create_udf('test_udf', 'code', 'TYPE')}
q = TestCases._create_query(sql, env=env, udfs=['udf'])
udf = TestCases._create_udf('test_udf', 'code', 'TYPE')
env = {'testudf': udf}
q = TestCases._create_query(sql, env=env, udfs=['testudf'])
self.assertIsNotNone(q)
self.assertEqual(q._udfs, ['udf'])
self.assertEqual(q._udfs, {'testudf': udf})
self.assertEqual(q._sql, sql)

@mock.patch('google.datalab.bigquery._api.Api.tabledata_list')
8 changes: 4 additions & 4 deletions tests/kernel/bigquery_tests.py
@@ -75,8 +75,8 @@ def test_query_cell(self, mock_default_context, mock_notebook_environment):
'datasources': None, 'subqueries': None}, q1_body)
q1 = env['q1']
self.assertIsNotNone(q1)
self.assertIsNone(q1._udfs)
self.assertIsNone(q1._subqueries)
self.assertEqual(q1.udfs, {})
self.assertEqual(q1.subqueries, {})
self.assertEqual(q1_body, q1._sql)
self.assertEqual(q1_body, q1.sql)

@@ -86,8 +86,8 @@ def test_query_cell(self, mock_default_context, mock_notebook_environment):
'datasources': None, 'subqueries': ['q1']}, q2_body)
q2 = env['q2']
self.assertIsNotNone(q2)
self.assertIsNone(q2._udfs)
self.assertEqual(['q1'], q2._subqueries)
self.assertEqual(q2.udfs, {})
self.assertEqual({'q1': q1}, q2.subqueries)
expected_sql = 'WITH q1 AS (%s)\n%s' % (q1_body, q2_body)
self.assertEqual(q2_body, q2._sql)
self.assertEqual(expected_sql, q2.sql)
