Skip to content

SNOW-1983343 add timeout for ocsp root certs #2338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- v3.16(TBD)
- Bumped numpy dependency from <2.1.0 to <=2.2.4
- Added Windows support for Python 3.13.
- Add `ocsp_root_certs_dict_lock_timeout` connection parameter to set the timeout (in seconds) for acquiring the lock on the OCSP root certs dictionary. Default value for this parameter is -1 which indicates no timeout.

- v3.15.1(May 20, 2025)
- Added basic arrow support for Interval types.
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def _get_private_bytes_from_file(
True,
bool,
), # SNOW-XXXXX: remove the check_arrow_conversion_error_on_every_column flag
"ocsp_root_certs_dict_lock_timeout": (
-1,
int,
),
}

APPLICATION_RE = re.compile(r"[\w\d_]+")
Expand Down Expand Up @@ -445,6 +449,7 @@ class SnowflakeConnection:
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
check_arrow_conversion_error_on_every_column: When true, the error check after the conversion from arrow to python types will happen for every column in the row. This is a new behaviour which fixes the bug that caused the type errors to trigger silently when occurring at any place other than last column in a row. To revert the previous (faulty) behaviour, please set this flag to false.
ocsp_root_certs_dict_lock_timeout: Timeout for the OCSP root certs dict lock in seconds. Default value is -1, which means no timeout.
"""

OCSP_ENV_LOCK = Lock()
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ def __init__(
ssl_wrap_socket.FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME = (
self._connection._ocsp_response_cache_filename if self._connection else None
)
# OCSP root timeout
ssl_wrap_socket.FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT = (
self._connection._ocsp_root_certs_dict_lock_timeout
if self._connection
else -1
)

# This is to address the issue where requests hangs
_ = "dummy".encode("idna").decode("utf-8")
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/connector/ocsp_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ def __init__(
use_ocsp_cache_server=None,
use_post_method: bool = True,
use_fail_open: bool = True,
root_certs_dict_lock_timeout: int = -1,
**kwargs,
) -> None:
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
Expand All @@ -1040,6 +1041,7 @@ def __init__(
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")

self._use_post_method = use_post_method
self._root_certs_dict_lock_timeout = root_certs_dict_lock_timeout
self.OCSP_CACHE_SERVER = OCSPServer(
top_level_domain=extract_top_level_domain_from_hostname(
kwargs.pop("hostname", None)
Expand Down Expand Up @@ -1410,7 +1412,10 @@ def _check_ocsp_response_cache_server(

def _lazy_read_ca_bundle(self) -> None:
"""Reads the local cabundle file and cache it in memory."""
with SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
timeout=self._root_certs_dict_lock_timeout
)
try:
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return
Expand Down Expand Up @@ -1471,6 +1476,8 @@ def _lazy_read_ca_bundle(self) -> None:
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
finally:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()

@staticmethod
def _calculate_tolerable_validity(this_update: float, next_update: float) -> int:
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/connector/ssl_wrap_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN
FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE
FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT: int = -1

"""
OCSP Response cache file name
Expand Down Expand Up @@ -84,6 +85,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
use_fail_open=FEATURE_OCSP_MODE == OCSPMode.FAIL_OPEN,
hostname=server_hostname,
root_certs_dict_lock_timeout=FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT,
).validate(server_hostname, ret.connection)
if not v:
raise OperationalError(
Expand Down
90 changes: 47 additions & 43 deletions test/unit/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ class OCSPMode(Enum):
url_3 = f"https://{hostname_2}/rgm1-s-sfctst0/stages/another-url"


mock_conn = mock.Mock()
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE


def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None:
"""Helper function to call SnowflakeRestful.close(). Asserts close was called on all SessionPools."""
with mock.patch("snowflake.connector.network.SessionPool.close") as close_mock:
Expand All @@ -50,59 +45,68 @@ def create_session(

@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_no_url_multiple_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)

create_session(rest, 2)
create_session(rest, 2)

assert make_session_mock.call_count == 2
assert make_session_mock.call_count == 2

assert list(rest._sessions_map.keys()) == [None]
assert list(rest._sessions_map.keys()) == [None]

session_pool = rest._sessions_map[None]
assert len(session_pool._idle_sessions) == 2
assert len(session_pool._active_sessions) == 0
session_pool = rest._sessions_map[None]
assert len(session_pool._idle_sessions) == 2
assert len(session_pool._active_sessions) == 0

close_sessions(rest, 1)
close_sessions(rest, 1)


@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_multiple_urls_multiple_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)

for url in [url_1, url_2, None]:
create_session(rest, num_sessions=2, url=url)
for url in [url_1, url_2, None]:
create_session(rest, num_sessions=2, url=url)

assert make_session_mock.call_count == 6
assert make_session_mock.call_count == 6

hostnames = list(rest._sessions_map.keys())
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames
hostnames = list(rest._sessions_map.keys())
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 2
assert len(pool._active_sessions) == 0
for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 2
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
close_sessions(rest, 3)


@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_multiple_urls_reuse_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
for url in [url_1, url_2, url_3, None]:
# create 10 sessions, one after another
for _ in range(10):
create_session(rest, url=url)

# only one session is created and reused thereafter
assert make_session_mock.call_count == 3

hostnames = list(rest._sessions_map.keys())
assert len(hostnames) == 3
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 1
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)
for url in [url_1, url_2, url_3, None]:
# create 10 sessions, one after another
for _ in range(10):
create_session(rest, url=url)

# only one session is created and reused thereafter
assert make_session_mock.call_count == 3

hostnames = list(rest._sessions_map.keys())
assert len(hostnames) == 3
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 1
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
1 change: 0 additions & 1 deletion test/unit/test_wiremock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
except ImportError:
import requests


from ..wiremock.wiremock_utils import WiremockClient


Expand Down
Loading