fix(bigquery): add close() method to client for releasing open sockets (#9894)

* Add close() method to Client

* Add psutil as an extra test dependency

* Fix open sockets leak in IPython magics

* Move psutil test dependency to noxfile

* Wrap entire cell magic into try-finally block

A single common cleanup point at the end makes it much less likely
to accidentally re-introduce an open socket leak.
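The cleanup point described above boils down to binding the close helper once and invoking it in a single finally block. A condensed, illustrative sketch of that pattern follows (the function name run_query_and_cleanup is hypothetical and the local _close_transports merely mirrors the new helper added to magics.py in this commit; it assumes a google-cloud-bigquery release containing this change, pandas installed, and default credentials — the real implementation is in the magics.py diff below):

```python
import functools

from google.cloud import bigquery


def _close_transports(client, bqstorage_client):
    # Mirrors the helper introduced in magics.py: release the HTTP session
    # sockets and, if present, the BigQuery Storage gRPC channel.
    client.close()
    if bqstorage_client is not None:
        bqstorage_client.transport.channel.close()


def run_query_and_cleanup(query, bqstorage_client=None):
    client = bigquery.Client()
    # Bind the cleanup once, so early returns and exceptions all funnel
    # through the same finally block.
    close_transports = functools.partial(_close_transports, client, bqstorage_client)
    try:
        return client.query(query).to_dataframe(bqstorage_client=bqstorage_client)
    finally:
        close_transports()
```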
plamut authored and tswast committed Nov 27, 2019
1 parent b7ba918 commit 9360057
Showing 6 changed files with 211 additions and 69 deletions.
12 changes: 12 additions & 0 deletions bigquery/google/cloud/bigquery/client.py
@@ -194,6 +194,18 @@ def location(self):
"""Default location for jobs / datasets / tables."""
return self._location

def close(self):
"""Close the underlying transport objects, releasing system resources.
.. note::
The client instance can be used for making additional requests even
after closing, in which case the underlying connections are
automatically re-created.
"""
self._http._auth_request.session.close()
self._http.close()

def get_service_account_email(self, project=None):
"""Get the email address of the project's BigQuery service account
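For reference, a minimal usage sketch of the new method (not part of the diff; it assumes a library version that includes this change and default application credentials):

```python
from google.cloud import bigquery

client = bigquery.Client()
print(list(client.query("SELECT 1 AS x").result()))

# Release the underlying HTTP session sockets.
client.close()

# Per the docstring note above, the client remains usable after close();
# the transports are re-created automatically on the next request.
print(list(client.query("SELECT 2 AS y").result()))
```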
160 changes: 92 additions & 68 deletions bigquery/google/cloud/bigquery/magics.py
@@ -137,6 +137,7 @@

import re
import ast
import functools
import sys
import time
from concurrent import futures
@@ -494,86 +495,91 @@ def _cell_magic(line, query):
    bqstorage_client = _make_bqstorage_client(
        args.use_bqstorage_api or context.use_bqstorage_api, context.credentials
    )

    close_transports = functools.partial(_close_transports, client, bqstorage_client)

    try:
        if args.max_results:
            max_results = int(args.max_results)
        else:
            max_results = None

        query = query.strip()

        # Any query that does not contain whitespace (aside from leading and trailing whitespace)
        # is assumed to be a table id
        if not re.search(r"\s", query):
            try:
                rows = client.list_rows(query, max_results=max_results)
            except Exception as ex:
                _handle_error(ex, args.destination_var)
                return

            result = rows.to_dataframe(bqstorage_client=bqstorage_client)
            if args.destination_var:
                IPython.get_ipython().push({args.destination_var: result})
                return
            else:
                return result

        job_config = bigquery.job.QueryJobConfig()
        job_config.query_parameters = params
        job_config.use_legacy_sql = args.use_legacy_sql
        job_config.dry_run = args.dry_run

        if args.destination_table:
            split = args.destination_table.split(".")
            if len(split) != 2:
                raise ValueError(
                    "--destination_table should be in a <dataset_id>.<table_id> format."
                )
            dataset_id, table_id = split
            job_config.allow_large_results = True
            dataset_ref = client.dataset(dataset_id)
            destination_table_ref = dataset_ref.table(table_id)
            job_config.destination = destination_table_ref
            job_config.create_disposition = "CREATE_IF_NEEDED"
            job_config.write_disposition = "WRITE_TRUNCATE"
            _create_dataset_if_necessary(client, dataset_id)

        if args.maximum_bytes_billed == "None":
            job_config.maximum_bytes_billed = 0
        elif args.maximum_bytes_billed is not None:
            value = int(args.maximum_bytes_billed)
            job_config.maximum_bytes_billed = value

        try:
            query_job = _run_query(client, query, job_config=job_config)
        except Exception as ex:
            _handle_error(ex, args.destination_var)
            return

        if not args.verbose:
            display.clear_output()

        if args.dry_run and args.destination_var:
            IPython.get_ipython().push({args.destination_var: query_job})
            return
        elif args.dry_run:
            print(
                "Query validated. This query will process {} bytes.".format(
                    query_job.total_bytes_processed
                )
            )
            return query_job

        if max_results:
            result = query_job.result(max_results=max_results).to_dataframe(
                bqstorage_client=bqstorage_client
            )
        else:
            result = query_job.to_dataframe(bqstorage_client=bqstorage_client)

        if args.destination_var:
            IPython.get_ipython().push({args.destination_var: result})
        else:
            return result
    finally:
        close_transports()


def _make_bqstorage_client(use_bqstorage_api, credentials):
@@ -601,3 +607,21 @@ def _make_bqstorage_client(use_bqstorage_api, credentials):
credentials=credentials,
client_info=gapic_client_info.ClientInfo(user_agent=IPYTHON_USER_AGENT),
)


def _close_transports(client, bqstorage_client):
    """Close the given clients' underlying transport channels.

    Closing the transport is needed to release system resources, namely open
    sockets.

    Args:
        client (:class:`~google.cloud.bigquery.client.Client`):
        bqstorage_client
            (Optional[:class:`~google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient`]):
            A client for the BigQuery Storage API.
    """
    client.close()
    if bqstorage_client is not None:
        bqstorage_client.transport.channel.close()
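A rough way to observe the effect outside the test suite, using the same psutil approach as the system test added below (illustrative only; it reaches into the private magics helper, so it assumes psutil, IPython, and a library version with this change are installed, plus default credentials):

```python
import psutil

from google.cloud import bigquery
from google.cloud.bigquery.magics import _close_transports

process = psutil.Process()
before = len(process.connections())

client = bigquery.Client()
client.query("SELECT 17 AS num").result()

# No BigQuery Storage client in this sketch, so pass None for it.
_close_transports(client, bqstorage_client=None)

after = len(process.connections())
print(before, after)  # expected to match once the transports are closed
```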
2 changes: 1 addition & 1 deletion bigquery/noxfile.py
@@ -81,7 +81,7 @@ def system(session):
session.install("--pre", "grpcio")

# Install all test dependencies, then install local packages in place.
session.install("mock", "pytest")
session.install("mock", "pytest", "psutil")
for local_dep in LOCAL_DEPS:
session.install("-e", local_dep)
session.install("-e", os.path.join("..", "storage"))
28 changes: 28 additions & 0 deletions bigquery/tests/system.py
@@ -27,6 +27,7 @@
import re

import six
import psutil
import pytest
import pytz

@@ -203,6 +204,27 @@ def _create_bucket(self, bucket_name, location=None):

return bucket

def test_close_releases_open_sockets(self):
current_process = psutil.Process()
conn_count_start = len(current_process.connections())

client = Config.CLIENT
client.query(
"""
SELECT
source_year AS year, COUNT(is_male) AS birth_count
FROM `bigquery-public-data.samples.natality`
GROUP BY year
ORDER BY year DESC
LIMIT 15
"""
)

client.close()

conn_count_end = len(current_process.connections())
self.assertEqual(conn_count_end, conn_count_start)

def test_create_dataset(self):
DATASET_ID = _make_dataset_id("create_dataset")
dataset = self.temp_dataset(DATASET_ID)
@@ -2417,6 +2439,9 @@ def temp_dataset(self, dataset_id, location=None):
@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic():
ip = IPython.get_ipython()
current_process = psutil.Process()
conn_count_start = len(current_process.connections())

ip.extension_manager.load_extension("google.cloud.bigquery")
sql = """
SELECT
@@ -2432,6 +2457,8 @@ def test_bigquery_magic():
with io.capture_output() as captured:
result = ip.run_cell_magic("bigquery", "", sql)

conn_count_end = len(current_process.connections())

lines = re.split("\n|\r", captured.stdout)
# Removes blanks & terminal code (result of display clearing)
updates = list(filter(lambda x: bool(x) and x != "\x1b[2K", lines))
@@ -2441,6 +2468,7 @@
assert isinstance(result, pandas.DataFrame)
assert len(result) == 10 # verify row count
assert list(result) == ["url", "view_count"] # verify column names
assert conn_count_end == conn_count_start # system resources are released


def _job_done(instance):
11 changes: 11 additions & 0 deletions bigquery/tests/unit/test_client.py
@@ -1398,6 +1398,17 @@ def test_create_table_alreadyexists_w_exists_ok_true(self):
]
)

def test_close(self):
creds = _make_credentials()
http = mock.Mock()
http._auth_request.session = mock.Mock()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)

client.close()

http.close.assert_called_once()
http._auth_request.session.close.assert_called_once()

def test_get_model(self):
path = "projects/%s/datasets/%s/models/%s" % (
self.PROJECT,
67 changes: 67 additions & 0 deletions bigquery/tests/unit/test_magics.py
@@ -545,6 +545,7 @@ def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch):
bqstorage_instance_mock = mock.create_autospec(
bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
)
bqstorage_instance_mock.transport = mock.Mock()
bqstorage_mock.return_value = bqstorage_instance_mock
bqstorage_client_patch = mock.patch(
"google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -601,6 +602,7 @@ def test_bigquery_magic_with_bqstorage_from_context(monkeypatch):
bqstorage_instance_mock = mock.create_autospec(
bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
)
bqstorage_instance_mock.transport = mock.Mock()
bqstorage_mock.return_value = bqstorage_instance_mock
bqstorage_client_patch = mock.patch(
"google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -728,6 +730,41 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result():
query_job_mock.result.assert_called_with(max_results=5)


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_max_results_query_job_results_fails():
ip = IPython.get_ipython()
ip.extension_manager.load_extension("google.cloud.bigquery")
magics.context._project = None

credentials_mock = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
)
default_patch = mock.patch(
"google.auth.default", return_value=(credentials_mock, "general-project")
)
client_query_patch = mock.patch(
"google.cloud.bigquery.client.Client.query", autospec=True
)
close_transports_patch = mock.patch(
"google.cloud.bigquery.magics._close_transports", autospec=True,
)

sql = "SELECT 17 AS num"

query_job_mock = mock.create_autospec(
google.cloud.bigquery.job.QueryJob, instance=True
)
query_job_mock.result.side_effect = [[], OSError]

with pytest.raises(
OSError
), client_query_patch as client_query_mock, default_patch, close_transports_patch as close_transports:
client_query_mock.return_value = query_job_mock
ip.run_cell_magic("bigquery", "--max_results=5", sql)

assert close_transports.called


def test_bigquery_magic_w_table_id_invalid():
ip = IPython.get_ipython()
ip.extension_manager.load_extension("google.cloud.bigquery")
@@ -820,6 +857,7 @@ def test_bigquery_magic_w_table_id_and_bqstorage_client():
bqstorage_instance_mock = mock.create_autospec(
bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
)
bqstorage_instance_mock.transport = mock.Mock()
bqstorage_mock.return_value = bqstorage_instance_mock
bqstorage_client_patch = mock.patch(
"google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -1290,3 +1328,32 @@ def test_bigquery_magic_w_destination_table():
assert job_config_used.write_disposition == "WRITE_TRUNCATE"
assert job_config_used.destination.dataset_id == "dataset_id"
assert job_config_used.destination.table_id == "table_id"


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_create_dataset_fails():
ip = IPython.get_ipython()
ip.extension_manager.load_extension("google.cloud.bigquery")
magics.context.credentials = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
)

create_dataset_if_necessary_patch = mock.patch(
"google.cloud.bigquery.magics._create_dataset_if_necessary",
autospec=True,
side_effect=OSError,
)
close_transports_patch = mock.patch(
"google.cloud.bigquery.magics._close_transports", autospec=True,
)

with pytest.raises(
OSError
), create_dataset_if_necessary_patch, close_transports_patch as close_transports:
ip.run_cell_magic(
"bigquery",
"--destination_table dataset_id.table_id",
"SELECT foo FROM WHERE LIMIT bar",
)

assert close_transports.called
