Merge pull request #1452 from criccomini/improve-gcp-hooks
AIRFLOW-16: Update Google cloud hooks to use new Google cloud platfor…
criccomini committed Apr 29, 2016
2 parents 86e3957 + bfdd1ca commit f657c16
Showing 8 changed files with 30 additions and 49 deletions.
airflow/contrib/hooks/bigquery_hook.py (25 changes: 11 additions & 14 deletions)
@@ -24,7 +24,10 @@
 from airflow.contrib.hooks.gc_base_hook import GoogleCloudBaseHook
 from airflow.hooks.dbapi_hook import DbApiHook
 from apiclient.discovery import build
-from pandas.io.gbq import GbqConnector, _parse_data as gbq_parse_data
+from pandas.io.gbq import GbqConnector, \
+    _parse_data as gbq_parse_data, \
+    _check_google_client_version as gbq_check_google_client_version, \
+    _test_google_api_imports as gbq_test_google_api_imports
 from pandas.tools.merge import concat

 logging.getLogger("bigquery").setLevel(logging.INFO)
@@ -48,15 +51,9 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
     conn_name_attr = 'bigquery_conn_id'

     def __init__(self,
-                 scope='https://www.googleapis.com/auth/bigquery',
                  bigquery_conn_id='bigquery_default',
                  delegate_to=None):
-        """
-        :param scope: The scope of the hook.
-        :type scope: string
-        """
         super(BigQueryHook, self).__init__(
-            scope=scope,
             conn_id=bigquery_conn_id,
             delegate_to=delegate_to)

@@ -65,8 +62,7 @@ def get_conn(self):
         Returns a BigQuery PEP 249 connection object.
         """
         service = self.get_service()
-        connection_extras = self._extras_dejson()
-        project = connection_extras['project']
+        project = self._get_field('project')
         return BigQueryConnection(service=service, project_id=project)

     def get_service(self):
@@ -97,10 +93,9 @@ def get_pandas_df(self, bql, parameters=None):
         :type bql: string
         """
         service = self.get_service()
-        connection_extras = self._extras_dejson()
-        project = connection_extras['project']
+        project = self._get_field('project')
         connector = BigQueryPandasConnector(project, service)
-        schema, pages = connector.run_query(bql, verbose=False)
+        schema, pages = connector.run_query(bql)
         dataframe_list = []

         while len(pages) > 0:
@@ -121,11 +116,13 @@ class BigQueryPandasConnector(GbqConnector):
     without forcing a three legged OAuth connection. Instead, we can inject
     service account credentials into the binding.
     """
-    def __init__(self, project_id, service, reauth=False):
-        self.test_google_api_imports()
+    def __init__(self, project_id, service, reauth=False, verbose=False):
+        gbq_check_google_client_version()
+        gbq_test_google_api_imports()
         self.project_id = project_id
         self.reauth = reauth
         self.service = service
+        self.verbose = verbose


 class BigQueryConnection(object):
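With this change, BigQueryHook no longer takes a scope argument; the project and credentials come from the connection's extras instead. A minimal usage sketch, not part of the commit (the connection ID and query are placeholders):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    # 'bigquery_default' must now be a google_cloud_platform connection;
    # the hook pulls the project from its extras via _get_field('project').
    hook = BigQueryHook(bigquery_conn_id='bigquery_default')
    df = hook.get_pandas_df('SELECT 1')  # returns a pandas DataFrame
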
airflow/contrib/hooks/datastore_hook.py (10 changes: 2 additions & 8 deletions)
@@ -35,18 +35,12 @@ class DatastoreHook(GoogleCloudBaseHook):
     simultaniously, you will need to create a hook per thread.
     """

-    conn_name_attr = 'datastore_conn_id'
-
     def __init__(self,
-                 scope=None,
                  datastore_conn_id='google_cloud_datastore_default',
                  delegate_to=None):
-        scope = scope or [
-            'https://www.googleapis.com/auth/datastore',
-            'https://www.googleapis.com/auth/userinfo.email']
-        super(DatastoreHook, self).__init__(scope, datastore_conn_id, delegate_to)
+        super(DatastoreHook, self).__init__(datastore_conn_id, delegate_to)
         # datasetId is the same as the project name
-        self.dataset_id = self._extras_dejson().get('project')
+        self.dataset_id = self._get_field('project')
         self.connection = self.get_conn()

     def get_conn(self):
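DatastoreHook follows the same pattern: scopes move out of the constructor, and dataset_id now comes from the connection's project field. A hypothetical sketch:

    from airflow.contrib.hooks.datastore_hook import DatastoreHook

    hook = DatastoreHook(datastore_conn_id='google_cloud_datastore_default')
    # dataset_id was resolved from the connection's project extra at init
    print(hook.dataset_id)
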
airflow/contrib/hooks/gc_base_hook.py (28 changes: 15 additions & 13 deletions)
@@ -31,30 +31,26 @@ class GoogleCloudBaseHook(BaseHook):
     The class also contains some miscellaneous helper functions.
     """

-    def __init__(self, scope, conn_id, delegate_to=None):
+    def __init__(self, conn_id, delegate_to=None):
         """
-        :param scope: The scope of the hook.
-        :type scope: string or an iterable of strings.
         :param conn_id: The connection ID to use when fetching connection info.
         :type conn_id: string
         :param delegate_to: The account to impersonate, if any.
             For this to work, the service account making the request must have domain-wide delegation enabled.
         :type delegate_to: string
         """
-        self.scope = scope
         self.conn_id = conn_id
         self.delegate_to = delegate_to
+        self.extras = self.get_connection(conn_id).extra_dejson

     def _authorize(self):
         """
         Returns an authorized HTTP object to be used to build a Google cloud
         service hook connection.
         """
-        connection_info = self.get_connection(self.conn_id)
-        connection_extras = connection_info.extra_dejson
-        service_account = connection_extras.get('service_account', False)
-        key_path = connection_extras.get('key_path', False)
+        service_account = self._get_field('service_account', False)
+        key_path = self._get_field('key_path', False)
+        scope = self._get_field('scope', False)

         kwargs = {}
         if self.delegate_to:
@@ -77,9 +73,15 @@ def _authorize(self):
         http = httplib2.Http()
         return credentials.authorize(http)

-    def _extras_dejson(self):
+    def _get_field(self, f, default=None):
         """
-        A little helper method that returns the JSON-deserialized extras in a
-        single call.
+        Fetches a field from extras, and returns it. This is some Airflow
+        magic. The google_cloud_platform hook type adds custom UI elements
+        to the hook page, which allow admins to specify service_account,
+        key_path, etc. They get formatted as shown below.
         """
-        return self.get_connection(self.conn_id).extra_dejson
+        long_f = 'extra__google_cloud_platform__{}'.format(f)
+        if long_f in self.extras:
+            return self.extras[long_f]
+        else:
+            return default
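
The docstring's "formatted as shown below" refers to the extra__google_cloud_platform__ prefix the connection form prepends to each custom field. A sketch of the lookup against a hypothetical extras payload (values invented):

    # What extra_dejson might hold for a google_cloud_platform connection
    extras = {
        'extra__google_cloud_platform__project': 'my-gcp-project',
        'extra__google_cloud_platform__key_path': '/keys/service-account.json',
    }

    # _get_field('project') builds the long key and looks it up:
    long_f = 'extra__google_cloud_platform__{}'.format('project')
    assert extras[long_f] == 'my-gcp-project'
    # A missing field falls back to the default, e.g.
    # _get_field('scope', False) would return False here.
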
airflow/contrib/hooks/gcs_hook.py (9 changes: 1 addition & 8 deletions)
@@ -39,18 +39,11 @@ class GoogleCloudStorageHook(GoogleCloudBaseHook):
     running Airflow, you can exclude the service_account and key_path
     parameters.
     """
-    conn_name_attr = 'google_cloud_storage_conn_id'

     def __init__(self,
-                 scope='https://www.googleapis.com/auth/devstorage.read_only',
                  google_cloud_storage_conn_id='google_cloud_storage_default',
                  delegate_to=None):
-        """
-        :param scope: The scope of the hook (read only, read write, etc). See:
-        https://cloud.google.com/storage/docs/authentication?hl=en#oauth-scopes
-        :type scope: string
-        """
-        super(GoogleCloudStorageHook, self).__init__(scope, google_cloud_storage_conn_id, delegate_to)
+        super(GoogleCloudStorageHook, self).__init__(google_cloud_storage_conn_id, delegate_to)

     def get_conn(self):
         """
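Since _authorize() now reads scope from the connection (see _get_field above), callers pick read-only versus read-write access by setting the connection's scope extra rather than a constructor argument. A hypothetical sketch:

    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

    # No scope parameter anymore; set, e.g.,
    # extra__google_cloud_platform__scope on the connection instead.
    hook = GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_cloud_storage_default')
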
airflow/contrib/operators/mysql_to_gcs.py (1 change: 0 additions & 1 deletion)
@@ -166,7 +166,6 @@ def _upload_to_gcs(self, files_to_upload):
         Google cloud storage.
         """
         hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
-                                      scope='https://www.googleapis.com/auth/devstorage.read_write',
                                       delegate_to=self.delegate_to)
         for object, tmp_file_handle in files_to_upload.items():
             hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json')
airflow/models.py (2 changes: 1 addition & 1 deletion)
@@ -555,7 +555,7 @@ def get_hook(self):
         try:
             if self.conn_type == 'mysql':
                 return hooks.MySqlHook(mysql_conn_id=self.conn_id)
-            elif self.conn_type == 'bigquery':
+            elif self.conn_type == 'google_cloud_platform':
                 return contrib_hooks.BigQueryHook(bigquery_conn_id=self.conn_id)
             elif self.conn_type == 'postgres':
                 return hooks.PostgresHook(postgres_conn_id=self.conn_id)
airflow/utils/logging.py (1 change: 0 additions & 1 deletion)
@@ -130,7 +130,6 @@ def __init__(self):
         try:
             from airflow.contrib.hooks import GoogleCloudStorageHook
             self.hook = GoogleCloudStorageHook(
-                scope='https://www.googleapis.com/auth/devstorage.read_write',
                 google_cloud_storage_conn_id=remote_conn_id)
         except:
             logging.error(
airflow/www/views.py (3 changes: 0 additions & 3 deletions)
@@ -2192,10 +2192,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
         }
         form_choices = {
             'conn_type': [
-                ('bigquery', 'BigQuery',),
-                ('datastore', 'Google Datastore'),
                 ('ftp', 'FTP',),
-                ('google_cloud_storage', 'Google Cloud Storage'),
                 ('google_cloud_platform', 'Google Cloud Platform'),
                 ('hdfs', 'HDFS',),
                 ('http', 'HTTP',),
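The three separate connection types collapse into a single google_cloud_platform type, so one connection can back all of the hooks above. A hypothetical sketch (the conn_id is invented):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook
    from airflow.contrib.hooks.datastore_hook import DatastoreHook
    from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

    conn_id = 'my_gcp_conn'  # one google_cloud_platform connection
    bq = BigQueryHook(bigquery_conn_id=conn_id)
    ds = DatastoreHook(datastore_conn_id=conn_id)
    gcs = GoogleCloudStorageHook(google_cloud_storage_conn_id=conn_id)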
