Commit 9360057

plamut authored and tswast committed
fix(bigquery): add close() method to client for releasing open sockets (#9894)
* Add close() method to Client
* Add psutil as an extra test dependency
* Fix open sockets leak in IPython magics
* Move psutil test dependency to noxfile
* Wrap entire cell magic into try-finally block

A single common cleanup point at the end makes it much less likely to accidentally re-introduce an open socket leak.
1 parent b7ba918 commit 9360057

File tree

6 files changed: +211 -69 lines changed

bigquery/google/cloud/bigquery/client.py

+12

@@ -194,6 +194,18 @@ def location(self):
         """Default location for jobs / datasets / tables."""
         return self._location
 
+    def close(self):
+        """Close the underlying transport objects, releasing system resources.
+
+        .. note::
+
+            The client instance can be used for making additional requests even
+            after closing, in which case the underlying connections are
+            automatically re-created.
+        """
+        self._http._auth_request.session.close()
+        self._http.close()
+
     def get_service_account_email(self, project=None):
         """Get the email address of the project's BigQuery service account
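The note in the docstring is the interesting part: close() is safe to call mid-lifecycle. A minimal usage sketch (the project name and queries are illustrative, and application default credentials are assumed, neither is part of this commit):

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # hypothetical project
    list(client.query("SELECT 1").result())

    # Release the cached HTTP sessions' sockets: both the API session and
    # the auth refresh session, per the implementation above.
    client.close()

    # Still usable: the underlying connections are re-created on demand.
    list(client.query("SELECT 2").result())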

bigquery/google/cloud/bigquery/magics.py

+92 -68

@@ -137,6 +137,7 @@
 
 import re
 import ast
+import functools
 import sys
 import time
 from concurrent import futures
@@ -494,86 +495,91 @@ def _cell_magic(line, query):
         args.use_bqstorage_api or context.use_bqstorage_api, context.credentials
     )
 
-    if args.max_results:
-        max_results = int(args.max_results)
-    else:
-        max_results = None
-
-    query = query.strip()
-
-    # Any query that does not contain whitespace (aside from leading and trailing whitespace)
-    # is assumed to be a table id
-    if not re.search(r"\s", query):
-        try:
-            rows = client.list_rows(query, max_results=max_results)
-        except Exception as ex:
-            _handle_error(ex, args.destination_var)
-            return
-
-        result = rows.to_dataframe(bqstorage_client=bqstorage_client)
-        if args.destination_var:
-            IPython.get_ipython().push({args.destination_var: result})
-            return
-        else:
-            return result
-
-    job_config = bigquery.job.QueryJobConfig()
-    job_config.query_parameters = params
-    job_config.use_legacy_sql = args.use_legacy_sql
-    job_config.dry_run = args.dry_run
-
-    if args.destination_table:
-        split = args.destination_table.split(".")
-        if len(split) != 2:
-            raise ValueError(
-                "--destination_table should be in a <dataset_id>.<table_id> format."
-            )
-        dataset_id, table_id = split
-        job_config.allow_large_results = True
-        dataset_ref = client.dataset(dataset_id)
-        destination_table_ref = dataset_ref.table(table_id)
-        job_config.destination = destination_table_ref
-        job_config.create_disposition = "CREATE_IF_NEEDED"
-        job_config.write_disposition = "WRITE_TRUNCATE"
-        _create_dataset_if_necessary(client, dataset_id)
-
-    if args.maximum_bytes_billed == "None":
-        job_config.maximum_bytes_billed = 0
-    elif args.maximum_bytes_billed is not None:
-        value = int(args.maximum_bytes_billed)
-        job_config.maximum_bytes_billed = value
-
-    try:
-        query_job = _run_query(client, query, job_config=job_config)
-    except Exception as ex:
-        _handle_error(ex, args.destination_var)
-        return
-
-    if not args.verbose:
-        display.clear_output()
-
-    if args.dry_run and args.destination_var:
-        IPython.get_ipython().push({args.destination_var: query_job})
-        return
-    elif args.dry_run:
-        print(
-            "Query validated. This query will process {} bytes.".format(
-                query_job.total_bytes_processed
-            )
-        )
-        return query_job
-
-    if max_results:
-        result = query_job.result(max_results=max_results).to_dataframe(
-            bqstorage_client=bqstorage_client
-        )
-    else:
-        result = query_job.to_dataframe(bqstorage_client=bqstorage_client)
-
-    if args.destination_var:
-        IPython.get_ipython().push({args.destination_var: result})
-    else:
-        return result
+    close_transports = functools.partial(_close_transports, client, bqstorage_client)
+
+    try:
+        if args.max_results:
+            max_results = int(args.max_results)
+        else:
+            max_results = None
+
+        query = query.strip()
+
+        # Any query that does not contain whitespace (aside from leading and trailing whitespace)
+        # is assumed to be a table id
+        if not re.search(r"\s", query):
+            try:
+                rows = client.list_rows(query, max_results=max_results)
+            except Exception as ex:
+                _handle_error(ex, args.destination_var)
+                return
+
+            result = rows.to_dataframe(bqstorage_client=bqstorage_client)
+            if args.destination_var:
+                IPython.get_ipython().push({args.destination_var: result})
+                return
+            else:
+                return result
+
+        job_config = bigquery.job.QueryJobConfig()
+        job_config.query_parameters = params
+        job_config.use_legacy_sql = args.use_legacy_sql
+        job_config.dry_run = args.dry_run
+
+        if args.destination_table:
+            split = args.destination_table.split(".")
+            if len(split) != 2:
+                raise ValueError(
+                    "--destination_table should be in a <dataset_id>.<table_id> format."
+                )
+            dataset_id, table_id = split
+            job_config.allow_large_results = True
+            dataset_ref = client.dataset(dataset_id)
+            destination_table_ref = dataset_ref.table(table_id)
+            job_config.destination = destination_table_ref
+            job_config.create_disposition = "CREATE_IF_NEEDED"
+            job_config.write_disposition = "WRITE_TRUNCATE"
+            _create_dataset_if_necessary(client, dataset_id)
+
+        if args.maximum_bytes_billed == "None":
+            job_config.maximum_bytes_billed = 0
+        elif args.maximum_bytes_billed is not None:
+            value = int(args.maximum_bytes_billed)
+            job_config.maximum_bytes_billed = value
+
+        try:
+            query_job = _run_query(client, query, job_config=job_config)
+        except Exception as ex:
+            _handle_error(ex, args.destination_var)
+            return
+
+        if not args.verbose:
+            display.clear_output()
+
+        if args.dry_run and args.destination_var:
+            IPython.get_ipython().push({args.destination_var: query_job})
+            return
+        elif args.dry_run:
+            print(
+                "Query validated. This query will process {} bytes.".format(
+                    query_job.total_bytes_processed
+                )
+            )
+            return query_job
+
+        if max_results:
+            result = query_job.result(max_results=max_results).to_dataframe(
+                bqstorage_client=bqstorage_client
+            )
+        else:
+            result = query_job.to_dataframe(bqstorage_client=bqstorage_client)
+
+        if args.destination_var:
+            IPython.get_ipython().push({args.destination_var: result})
+        else:
+            return result
+    finally:
+        close_transports()
 
 
 def _make_bqstorage_client(use_bqstorage_api, credentials):
@@ -601,3 +607,21 @@ def _make_bqstorage_client(use_bqstorage_api, credentials):
         credentials=credentials,
         client_info=gapic_client_info.ClientInfo(user_agent=IPYTHON_USER_AGENT),
     )
+
+
+def _close_transports(client, bqstorage_client):
+    """Close the given clients' underlying transport channels.
+
+    Closing the transport is needed to release system resources, namely open
+    sockets.
+
+    Args:
+        client (:class:`~google.cloud.bigquery.client.Client`):
+        bqstorage_client
+            (Optional[:class:`~google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient`]):
+            A client for the BigQuery Storage API.
+
+    """
+    client.close()
+    if bqstorage_client is not None:
+        bqstorage_client.transport.channel.close()
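The functools.partial at the top of _cell_magic is what makes the single finally block workable: the clients are bound into a zero-argument callable before any work starts, so every early return and every exception path shares one cleanup point. A self-contained sketch of the pattern (the names below are illustrative, not from the library):

    import functools

    def close_transport(transport):
        # Stand-in for magics._close_transports(client, bqstorage_client).
        print("closing", transport)

    def run(transport, fail=False):
        # Bind the cleanup arguments up front, before any early returns exist.
        cleanup = functools.partial(close_transport, transport)
        try:
            if fail:
                raise RuntimeError("query failed")
            return "result"  # an early return still triggers cleanup
        finally:
            cleanup()  # runs on normal return and on exception alike

    run("api-transport")  # prints "closing api-transport", returns "result"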

bigquery/noxfile.py

+1 -1

@@ -81,7 +81,7 @@ def system(session):
     session.install("--pre", "grpcio")
 
     # Install all test dependencies, then install local packages in place.
-    session.install("mock", "pytest")
+    session.install("mock", "pytest", "psutil")
     for local_dep in LOCAL_DEPS:
         session.install("-e", local_dep)
     session.install("-e", os.path.join("..", "storage"))

bigquery/tests/system.py

+28

@@ -27,6 +27,7 @@
 import re
 
 import six
+import psutil
 import pytest
 import pytz
 
@@ -203,6 +204,27 @@ def _create_bucket(self, bucket_name, location=None):
 
         return bucket
 
+    def test_close_releases_open_sockets(self):
+        current_process = psutil.Process()
+        conn_count_start = len(current_process.connections())
+
+        client = Config.CLIENT
+        client.query(
+            """
+            SELECT
+                source_year AS year, COUNT(is_male) AS birth_count
+            FROM `bigquery-public-data.samples.natality`
+            GROUP BY year
+            ORDER BY year DESC
+            LIMIT 15
+            """
+        )
+
+        client.close()
+
+        conn_count_end = len(current_process.connections())
+        self.assertEqual(conn_count_end, conn_count_start)
+
     def test_create_dataset(self):
         DATASET_ID = _make_dataset_id("create_dataset")
         dataset = self.temp_dataset(DATASET_ID)
@@ -2417,6 +2439,9 @@ def temp_dataset(self, dataset_id, location=None):
 @pytest.mark.usefixtures("ipython_interactive")
 def test_bigquery_magic():
     ip = IPython.get_ipython()
+    current_process = psutil.Process()
+    conn_count_start = len(current_process.connections())
+
     ip.extension_manager.load_extension("google.cloud.bigquery")
     sql = """
         SELECT
@@ -2432,6 +2457,8 @@ def test_bigquery_magic():
     with io.capture_output() as captured:
         result = ip.run_cell_magic("bigquery", "", sql)
 
+    conn_count_end = len(current_process.connections())
+
     lines = re.split("\n|\r", captured.stdout)
     # Removes blanks & terminal code (result of display clearing)
     updates = list(filter(lambda x: bool(x) and x != "\x1b[2K", lines))
@@ -2441,6 +2468,7 @@ def test_bigquery_magic():
     assert isinstance(result, pandas.DataFrame)
     assert len(result) == 10  # verify row count
     assert list(result) == ["url", "view_count"]  # verify column names
+    assert conn_count_end == conn_count_start  # system resources are released
 
 
 def _job_done(instance):
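The new system test brackets the client's lifetime with psutil connection counts. The same measurement can be sketched standalone (requires psutil and network access; the host and port are illustrative, and counts can be noisy if other threads open sockets concurrently):

    import socket

    import psutil

    proc = psutil.Process()  # the current process
    before = len(proc.connections())

    # Open a TCP connection; it shows up in the process's connection table.
    sock = socket.create_connection(("example.com", 80))
    assert len(proc.connections()) == before + 1

    # Closing the file descriptor removes it from the table again.
    sock.close()
    assert len(proc.connections()) == before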

bigquery/tests/unit/test_client.py

+11

@@ -1398,6 +1398,17 @@ def test_create_table_alreadyexists_w_exists_ok_true(self):
             ]
         )
 
+    def test_close(self):
+        creds = _make_credentials()
+        http = mock.Mock()
+        http._auth_request.session = mock.Mock()
+        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
+
+        client.close()
+
+        http.close.assert_called_once()
+        http._auth_request.session.close.assert_called_once()
+
     def test_get_model(self):
         path = "projects/%s/datasets/%s/models/%s" % (
             self.PROJECT,
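This unit test works because Mock creates attribute chains lazily: http._auth_request.session springs into existence on first access, and every call on it is recorded. A standalone sketch of that behavior:

    from unittest import mock

    http = mock.Mock()

    # Mirrors what Client.close() does to its transport.
    http._auth_request.session.close()
    http.close()

    # Each auto-created child mock remembers its calls independently.
    http.close.assert_called_once()
    http._auth_request.session.close.assert_called_once()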

bigquery/tests/unit/test_magics.py

+67

@@ -545,6 +545,7 @@ def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch):
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -601,6 +602,7 @@ def test_bigquery_magic_with_bqstorage_from_context(monkeypatch):
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -728,6 +730,41 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result():
     query_job_mock.result.assert_called_with(max_results=5)
 
 
+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_w_max_results_query_job_results_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context._project = None
+
+    credentials_mock = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+    default_patch = mock.patch(
+        "google.auth.default", return_value=(credentials_mock, "general-project")
+    )
+    client_query_patch = mock.patch(
+        "google.cloud.bigquery.client.Client.query", autospec=True
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    sql = "SELECT 17 AS num"
+
+    query_job_mock = mock.create_autospec(
+        google.cloud.bigquery.job.QueryJob, instance=True
+    )
+    query_job_mock.result.side_effect = [[], OSError]
+
+    with pytest.raises(
+        OSError
+    ), client_query_patch as client_query_mock, default_patch, close_transports_patch as close_transports:
+        client_query_mock.return_value = query_job_mock
+        ip.run_cell_magic("bigquery", "--max_results=5", sql)
+
+    assert close_transports.called
+
+
 def test_bigquery_magic_w_table_id_invalid():
     ip = IPython.get_ipython()
     ip.extension_manager.load_extension("google.cloud.bigquery")
@@ -820,6 +857,7 @@ def test_bigquery_magic_w_table_id_and_bqstorage_client():
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -1290,3 +1328,32 @@ def test_bigquery_magic_w_destination_table():
     assert job_config_used.write_disposition == "WRITE_TRUNCATE"
     assert job_config_used.destination.dataset_id == "dataset_id"
     assert job_config_used.destination.table_id == "table_id"
+
+
+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_create_dataset_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context.credentials = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+
+    create_dataset_if_necessary_patch = mock.patch(
+        "google.cloud.bigquery.magics._create_dataset_if_necessary",
+        autospec=True,
+        side_effect=OSError,
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    with pytest.raises(
+        OSError
+    ), create_dataset_if_necessary_patch, close_transports_patch as close_transports:
+        ip.run_cell_magic(
+            "bigquery",
+            "--destination_table dataset_id.table_id",
+            "SELECT foo FROM WHERE LIMIT bar",
+        )
+
+    assert close_transports.called
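The query_job_mock.result.side_effect = [[], OSError] line in the first new test leans on mock's sequence semantics: each call to the mock consumes the next element, returning plain values and raising anything that is an exception class or instance. A minimal sketch:

    from unittest import mock

    result = mock.Mock(side_effect=[[], OSError])

    assert result() == []  # first call returns the empty list
    try:
        result()  # second call raises OSError
    except OSError:
        print("raised, as the test expects")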
