From 2c9d69cc1ff74591e53ef7a8155543dace34670f Mon Sep 17 00:00:00 2001 From: Yasser Elsayed Date: Fri, 3 Mar 2017 19:08:46 -0800 Subject: [PATCH] Fix query constructor to take dictionaries or lists of dependencies (#270) This PR fixes the Query constructor to take dependencies (subqueries, udfs, or external datasources) in the form of either a list or dictionary. If a list is given, it should be a list of strings, where items are names of objects defined in the environment object (user can build their own object or pass the notebook_environment() as parameter). If a dictionary is provided, it should be a mapping between dependency names and objects. All of the below examples now work: q1 = bq.Query('sql') q2 = bq.Query('sql', subqueries=['q1'], env={'q1': q1}) q2 = bq.Query('sql', subqueries=['q1'], env=utils.commands.notebook_environment()) q2 = bq.Query('sql', subqueries={'q1': q1}) --- google/datalab/bigquery/_api.py | 7 ++- google/datalab/bigquery/_query.py | 83 +++++++++++++++++++------------ tests/bigquery/query_tests.py | 12 +++-- tests/kernel/bigquery_tests.py | 8 +-- 4 files changed, 67 insertions(+), 43 deletions(-) diff --git a/google/datalab/bigquery/_api.py b/google/datalab/bigquery/_api.py index 2ff8405f9..478a630f1 100644 --- a/google/datalab/bigquery/_api.py +++ b/google/datalab/bigquery/_api.py @@ -150,7 +150,7 @@ def jobs_insert_query(self, sql, table_name=None, append=False, priority, more expensive). allow_large_results: whether to allow large results (slower with some restrictions but can handle big jobs). - table_definitions: a dictionary of JSON external table definitions for any external tables + table_definitions: a dictionary of ExternalDataSource names and objects for any external tables referenced in the query. query_params: a dictionary containing query parameter types and values, passed to BigQuery. 
Returns: @@ -177,7 +177,10 @@ def jobs_insert_query(self, sql, table_name=None, append=False, query_config = data['configuration']['query'] if table_definitions: - query_config['tableDefinitions'] = table_definitions + expanded_definitions = {} + for td in table_definitions: + expanded_definitions[td] = table_definitions[td]._to_query_json() + query_config['tableDefinitions'] = expanded_definitions if table_name: query_config['destinationTable'] = { diff --git a/google/datalab/bigquery/_query.py b/google/datalab/bigquery/_query.py index 9c110a890..434e72344 100644 --- a/google/datalab/bigquery/_query.py +++ b/google/datalab/bigquery/_query.py @@ -41,39 +41,48 @@ def __init__(self, sql, env=None, udfs=None, data_sources=None, subqueries=None) sql: the BigQuery SQL query string to execute env: a dictionary containing objects from the query execution context, used to get references to UDFs, subqueries, and external data sources referenced by the query - udfs: list of UDFs referenced in the SQL. - data_sources: list of external data sources referenced in the SQL. - subqueries: list of subqueries referenced in the SQL + udfs: list of UDF names referenced in the SQL, or dictionary of names and UDF objects + data_sources: list of external data source names referenced in the SQL, or dictionary of + names and data source objects + subqueries: list of subquery names referenced in the SQL, or dictionary of names and + Query objects Raises: Exception if expansion of any variables failed. 
""" self._sql = sql - self._udfs = udfs - self._subqueries = subqueries - self._env = env - if self._env is None: - self._env = google.datalab.utils.commands.notebook_environment() + self._udfs = {} + self._subqueries = {} self._data_sources = {} + self._env = env or {} + + # Validate that the given list or dictionary of objects are of the correct type + # and add them to the target dictionary + def _expand_objects(obj_container, obj_type, target_dict): + for item in obj_container: + # for a list of objects, we should find these objects in the given environment + if isinstance(obj_container, list): + value = self._env.get(item) + if value is None: + raise Exception('Cannot find object %s' % item) + + # for a dictionary of objects, each pair must be a string and an object of the expected type + elif isinstance(obj_container, dict): + value = obj_container[item] + if not isinstance(value, obj_type): + raise Exception('Expected type: %s, found: %s.' % (obj_type, type(value))) - def _validate_object(obj, obj_type): - item = self._env.get(obj) - if item is None: - raise Exception('Cannot find object %s.' % obj) - if not isinstance(item, obj_type): - raise Exception('Expected type: %s, found: %s.' % (obj_type, type(item))) - - # Validate subqueries, UDFs, and datasources when adding them to query - if self._subqueries: - for subquery in self._subqueries: - _validate_object(subquery, Query) - if self._udfs: - for udf in self._udfs: - _validate_object(udf, _udf.UDF) + else: + raise Exception('Unexpected container for type %s. 
Expected a list or dictionary' % obj_type) + + target_dict[item] = value + + if subqueries: + _expand_objects(subqueries, Query, self._subqueries) + if udfs: + _expand_objects(udfs, _udf.UDF, self._udfs) if data_sources: - for ds in data_sources: - _validate_object(ds, _external_data_source.ExternalDataSource) - self._data_sources[ds] = self._env[ds]._to_query_json() + _expand_objects(data_sources, _external_data_source.ExternalDataSource, self._data_sources) if len(self._data_sources) > 1: raise Exception('Only one temporary external datasource is supported in queries.') @@ -115,8 +124,8 @@ def _expanded_sql(self, sampling=None): The expanded SQL string of this object """ - udfs = set() - subqueries = set() + udfs = {} + subqueries = {} expanded_sql = '' def _recurse_subqueries(query): @@ -125,21 +134,21 @@ def _recurse_subqueries(query): if query._subqueries: subqueries.update(query._subqueries) if query._udfs: - udfs.update(set(query._udfs)) + udfs.update(query._udfs) if query._subqueries: for subquery in query._subqueries: - _recurse_subqueries(self._env[subquery]) + _recurse_subqueries(query._subqueries[subquery]) subqueries_sql = udfs_sql = '' _recurse_subqueries(self) if udfs: - expanded_sql += '\n'.join([self._env[udf]._expanded_sql() for udf in udfs]) + expanded_sql += '\n'.join([udfs[udf]._expanded_sql() for udf in udfs]) expanded_sql += '\n' if subqueries: expanded_sql += 'WITH ' + \ - ',\n'.join(['%s AS (%s)' % (sq, self._env[sq]._sql) for sq in subqueries]) + ',\n'.join(['%s AS (%s)' % (sq, subqueries[sq]._sql) for sq in subqueries]) expanded_sql += '\n' expanded_sql += sampling(self._sql) if sampling else self._sql @@ -169,9 +178,19 @@ def sql(self): @property def udfs(self): - """ Get the UDFs referenced by the query.""" + """ Get a dictionary of UDFs referenced by the query.""" return self._udfs + @property + def subqueries(self): + """ Get a dictionary of subqueries referenced by the query.""" + return self._subqueries + + @property + def 
data_sources(self): + """ Get a dictionary of external data sources referenced by the query.""" + return self._data_sources + def dry_run(self, context=None, query_params=None): """Dry run a query, to check the validity of the query and return some useful statistics. diff --git a/tests/bigquery/query_tests.py b/tests/bigquery/query_tests.py index 3e16a4e11..19e2922da 100644 --- a/tests/bigquery/query_tests.py +++ b/tests/bigquery/query_tests.py @@ -27,18 +27,20 @@ def test_parameter_validation(self): sql = 'SELECT * FROM table' with self.assertRaises(Exception) as error: q = TestCases._create_query(sql, subqueries=['subquery']) - env = {'subquery': TestCases._create_query()} + sq = TestCases._create_query() + env = {'subquery': sq} q = TestCases._create_query(sql, env=env, subqueries=['subquery']) self.assertIsNotNone(q) - self.assertEqual(q._subqueries, ['subquery']) + self.assertEqual(q._subqueries, {'subquery': sq}) self.assertEqual(q._sql, sql) with self.assertRaises(Exception) as error: q = TestCases._create_query(sql, udfs=['udf']) - env = {'udf': TestCases._create_udf('test_udf', 'code', 'TYPE')} - q = TestCases._create_query(sql, env=env, udfs=['udf']) + udf = TestCases._create_udf('test_udf', 'code', 'TYPE') + env = {'testudf': udf} + q = TestCases._create_query(sql, env=env, udfs=['testudf']) self.assertIsNotNone(q) - self.assertEqual(q._udfs, ['udf']) + self.assertEqual(q._udfs, {'testudf': udf}) self.assertEqual(q._sql, sql) @mock.patch('google.datalab.bigquery._api.Api.tabledata_list') diff --git a/tests/kernel/bigquery_tests.py b/tests/kernel/bigquery_tests.py index 8c58104c6..7025ee4dc 100644 --- a/tests/kernel/bigquery_tests.py +++ b/tests/kernel/bigquery_tests.py @@ -75,8 +75,8 @@ def test_query_cell(self, mock_default_context, mock_notebook_environment): 'datasources': None, 'subqueries': None}, q1_body) q1 = env['q1'] self.assertIsNotNone(q1) - self.assertIsNone(q1._udfs) - self.assertIsNone(q1._subqueries) + self.assertEqual(q1.udfs, {}) + 
self.assertEqual(q1.subqueries, {}) self.assertEqual(q1_body, q1._sql) self.assertEqual(q1_body, q1.sql) @@ -86,8 +86,8 @@ def test_query_cell(self, mock_default_context, mock_notebook_environment): 'datasources': None, 'subqueries': ['q1']}, q2_body) q2 = env['q2'] self.assertIsNotNone(q2) - self.assertIsNone(q2._udfs) - self.assertEqual(['q1'], q2._subqueries) + self.assertEqual(q2.udfs, {}) + self.assertEqual({'q1': q1}, q2.subqueries) expected_sql = 'WITH q1 AS (%s)\n%s' % (q1_body, q2_body) self.assertEqual(q2_body, q2._sql) self.assertEqual(expected_sql, q2.sql)