Skip to content

Commit ca70697

Browse files
authored
Automatically attempt to refresh auth tokens before they expire (#317)
* Proof of concept + clean impl, for discussion * Minor tweaks * Update to force auth token refresh on the server * Clean up update process, add error checking and exception handling * Add oidc auto-refresh * Add refresh unit test, fake data generator and init imports * Formatting updates Co-authored-by: Devin Robison <[email protected]>
1 parent 7c6b8aa commit ca70697

File tree

4 files changed

+416
-9
lines changed

4 files changed

+416
-9
lines changed

dask_kubernetes/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from . import config
2-
from .auth import ClusterAuth, KubeAuth, KubeConfig, InCluster
2+
from .auth import (
3+
ClusterAuth,
4+
KubeAuth,
5+
KubeConfig,
6+
InCluster,
7+
AutoRefreshKubeConfigLoader,
8+
AutoRefreshConfiguration,
9+
)
310
from .core import KubeCluster
411
from .helm import HelmCluster
512
from .objects import make_pod_spec, make_pod_from_dict, clean_pod_template

dask_kubernetes/auth.py

+311-7
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,299 @@
11
"""
22
Defines different methods to configure a connection to a Kubernetes cluster.
33
"""
4+
import asyncio
5+
import base64
46
import contextlib
7+
import copy
8+
import datetime
9+
import json
510
import logging
611
import os
712

813
import kubernetes
914
import kubernetes_asyncio
1015

16+
from kubernetes_asyncio.client import Configuration
17+
from kubernetes_asyncio.config.kube_config import KubeConfigLoader, KubeConfigMerger
18+
from kubernetes_asyncio.config.google_auth import google_auth_credentials
19+
from kubernetes_asyncio.config.dateutil import parse_rfc3339
20+
1121
logger = logging.getLogger(__name__)
1222

23+
tzUTC = datetime.timezone.utc
24+
25+
26+
class AutoRefreshKubeConfigLoader(KubeConfigLoader):
27+
"""
28+
Extends KubeConfigLoader, automatically attempts to refresh authentication
29+
credentials before they expire.
30+
"""
31+
32+
def __init__(self, *args, **kwargs):
33+
super(AutoRefreshKubeConfigLoader, self).__init__(*args, **kwargs)
34+
35+
self._retry_count = 0
36+
self._max_retries = float("Inf")
37+
self.auto_refresh = True
38+
self.refresh_task = None
39+
self.last_refreshed = None
40+
self.token_expire_ts = None
41+
42+
def __del__(self):
43+
self.auto_refresh = False
44+
45+
def extract_oid_expiration_from_provider(self, provider):
46+
"""
47+
Extracts the expiration datestamp for the provider token
48+
Parameters
49+
----------
50+
provider : authentication provider dictionary.
51+
52+
Returns
53+
-------
54+
expires : expiration timestamp
55+
"""
56+
parts = provider["config"]["id-token"].split(".")
57+
58+
if len(parts) != 3:
59+
raise ValueError("oidc: JWT tokens should contain 3 period-delimited parts")
60+
61+
id_token = parts[1]
62+
# Re-pad the unpadded JWT token
63+
id_token += (4 - len(id_token) % 4) * "="
64+
jwt_attributes = json.loads(base64.b64decode(id_token).decode("utf8"))
65+
expires = jwt_attributes.get("exp")
66+
67+
return expires
68+
69+
async def create_refresh_task_from_expiration_timestamp(self, expiration_timestamp):
70+
"""
71+
Takes an expiration timestamp, and creates a refresh task to ensure that the token
72+
does not expire.
73+
74+
Parameters
75+
----------
76+
expiration_timestamp : time at which the current authentication token will expire
77+
78+
Returns
79+
-------
80+
N/A
81+
"""
82+
# Set our token expiry to be actual expiry - 20%
83+
expiry = parse_rfc3339(expiration_timestamp)
84+
expiry_delta = datetime.timedelta(
85+
seconds=(expiry - datetime.datetime.now(tz=tzUTC)).total_seconds()
86+
)
87+
scaled_expiry_delta = datetime.timedelta(
88+
seconds=0.8 * expiry_delta.total_seconds()
89+
)
90+
91+
self.refresh_task = asyncio.create_task(
92+
self.refresh_after(
93+
when=scaled_expiry_delta.total_seconds(), reschedule_on_failure=True
94+
),
95+
name="dask_auth_auto_refresh",
96+
)
97+
98+
self.last_refreshed = datetime.datetime.now(tz=tzUTC)
99+
self.token_expire_ts = self.last_refreshed + scaled_expiry_delta
100+
101+
async def refresh_after(self, when, reschedule_on_failure=False):
102+
"""
103+
Refresh kuberenetes authentication
104+
Parameters
105+
----------
106+
when : Seconds before we should refresh. This should be set to some delta before
107+
the actual token expiration time, or you will likely see authentication race
108+
/ failure conditions.
109+
110+
reschedule_on_failure : If the refresh task fails, re-try in 30 seconds, until
111+
_max_retries is exceeded, then raise an exception.
112+
"""
113+
114+
if not self.auto_refresh:
115+
return
116+
117+
logger.debug(
118+
msg=f"Refresh_at coroutine sleeping for "
119+
f"{int(when // 60)} minutes {(when % 60):0.2f} seconds."
120+
)
121+
try:
122+
await asyncio.sleep(when)
123+
if self.provider == "gcp":
124+
await self.refresh_gcp_token()
125+
elif self.provider == "oidc":
126+
await self.refresh_oid_token()
127+
return
128+
elif "exec" in self._user:
129+
logger.warning(msg="Auto-refresh doesn't support generic ExecProvider")
130+
return
131+
132+
except Exception as e:
133+
logger.warning(
134+
msg=f"Authentication refresh failed for provider '{self.provider}.'",
135+
exc_info=e,
136+
)
137+
if not reschedule_on_failure or self._retry_count > self._max_retries:
138+
raise
139+
140+
logger.warning(msg=f"Retrying '{self.provider}' in 30 seconds.")
141+
self._retry_count += 1
142+
self.refresh_task = asyncio.create_task(self.refresh_after(30))
143+
144+
async def refresh_oid_token(self):
145+
"""
146+
Adapted from kubernetes_asyncio/config/kube_config:_load_oid_token
147+
148+
Refreshes the existing oid token, if necessary, and creates a refresh task
149+
that will keep the token from expiring.
150+
151+
Returns
152+
-------
153+
"""
154+
provider = self._user["auth-provider"]
155+
156+
logger.debug("Refreshing OID token.")
157+
158+
if "config" not in provider:
159+
raise ValueError("oidc: missing configuration")
160+
161+
if (not self.token_expire_ts) or (
162+
self.token_expire_ts <= datetime.datetime.now(tz=tzUTC)
163+
):
164+
await self._refresh_oidc(provider)
165+
expires = self.extract_oid_expiration_from_provider(provider=provider)
166+
167+
await self.create_refresh_task_from_expiration_timestamp(
168+
expiration_timestamp=expires
169+
)
170+
171+
self.token = "Bearer {}".format(provider["config"]["id-token"])
172+
173+
async def refresh_gcp_token(self):
174+
"""
175+
Adapted from kubernetes_asyncio/config/kube_config:load_gcp_token
176+
177+
Refreshes the existing gcp token, if necessary, and creates a refresh task
178+
that will keep the token from expiring.
179+
180+
Returns
181+
-------
182+
"""
183+
if "config" not in self._user["auth-provider"]:
184+
self._user["auth-provider"].value["config"] = {}
185+
186+
config = self._user["auth-provider"]["config"]
187+
188+
if (not self.token_expire_ts) or (
189+
self.token_expire_ts <= datetime.datetime.now(tz=tzUTC)
190+
):
191+
192+
logger.debug("Refreshing GCP token.")
193+
if self._get_google_credentials is not None:
194+
if asyncio.iscoroutinefunction(self._get_google_credentials):
195+
credentials = await self._get_google_credentials()
196+
else:
197+
credentials = self._get_google_credentials()
198+
else:
199+
# config is read-only.
200+
extra_args = " --force-auth-refresh"
201+
_config = {
202+
"cmd-args": config["cmd-args"] + extra_args,
203+
"cmd-path": config["cmd-path"],
204+
}
205+
credentials = await google_auth_credentials(_config)
206+
207+
config.value["access-token"] = credentials.token
208+
config.value["expiry"] = credentials.expiry
209+
210+
# Set our token expiry to be actual expiry - 20%
211+
await self.create_refresh_task_from_expiration_timestamp(
212+
expiration_timestamp=config.value["expiry"]
213+
)
214+
215+
if self._config_persister:
216+
self._config_persister(self._config.value)
217+
218+
self.token = "Bearer %s" % config["access-token"]
219+
220+
async def _load_oid_token(self):
221+
"""
222+
Overrides KubeConfigLoader implementation.
223+
Returns
224+
-------
225+
Auth token
226+
"""
227+
await self.refresh_oid_token()
228+
229+
return self.token
230+
231+
async def load_gcp_token(self):
232+
"""
233+
Override KubeConfigLoader implementation so that we can keep track of the expiration timestamp
234+
and automatically refresh auth tokens.
235+
236+
Returns
237+
-------
238+
GCP access token
239+
"""
240+
await self.refresh_gcp_token()
241+
242+
return self.token
243+
244+
245+
class AutoRefreshConfiguration(Configuration):
246+
"""
247+
Extends kubernetes_async Configuration to support automatic token refresh.
248+
Lets us keep track of the original loader object, which can be used
249+
to regenerate the authentication token.
250+
"""
251+
252+
def __init__(self, loader, refresh_frequency=None, *args, **kwargs):
253+
super(AutoRefreshConfiguration, self).__init__(*args, **kwargs)
254+
255+
# Set refresh api callback
256+
self.refresh_api_key_hook = self.refresh_api_key
257+
self.last_refreshed = datetime.datetime.now(tz=tzUTC)
258+
self.loader = loader
259+
260+
# Adapted from kubernetes_asyncio/client/configuration.py:__deepcopy__
261+
def __deepcopy__(self, memo):
262+
"""
263+
Modified so that we don't try to deep copy the loader off the config
264+
"""
265+
cls = self.__class__
266+
result = cls.__new__(cls)
267+
memo[id(self)] = result
268+
for k, v in self.__dict__.items():
269+
if k not in ("logger", "logger_file_handler", "loader"):
270+
setattr(result, k, copy.deepcopy(v, memo))
271+
272+
# shallow copy loader object
273+
result.loader = self.loader
274+
# shallow copy of loggers
275+
result.logger = copy.copy(self.logger)
276+
# use setters to configure loggers
277+
result.logger_file = self.logger_file
278+
result.debug = self.debug
279+
280+
return result
281+
282+
def refresh_api_key(self, client_configuration):
283+
"""
284+
Checks to see if the loader has updated the authentication token. If it
285+
has, the token is copied from the loader into the current configuration.
286+
287+
This function is assigned to Configuration.refresh_api_key_hook, and will
288+
fire when entering get_api_key_with_prefix, before the api_key is retrieved.
289+
"""
290+
if self.last_refreshed < self.loader.last_refreshed:
291+
logger.debug("Entering refresh_api_key_hook")
292+
client_configuration.api_key[
293+
"authorization"
294+
] = client_configuration.loader.token
295+
self.last_refreshed = datetime.datetime.now(tz=tzUTC)
296+
13297

14298
class ClusterAuth(object):
15299
"""
@@ -45,7 +329,6 @@ async def load_first(auth=None):
45329
46330
Parameters
47331
----------
48-
49332
auth: List[ClusterAuth] (optional)
50333
Configuration methods to attempt in order. Defaults to
51334
``[InCluster(), KubeConfig()]``.
@@ -127,15 +410,35 @@ async def load(self):
127410
with contextlib.suppress(KeyError):
128411
if self.config_file is None:
129412
self.config_file = os.path.abspath(
130-
os.path.expanduser(os.environ["KUBECONFIG"])
413+
os.path.expanduser(os.environ.get("KUBECONFIG", "~/.kube/config"))
131414
)
132-
kubernetes.config.load_kube_config(
133-
self.config_file, self.context, None, self.persist_config
134-
)
135-
await kubernetes_asyncio.config.load_kube_config(
136-
self.config_file, self.context, None, self.persist_config
415+
416+
await self.load_kube_config()
417+
418+
# Adapted from from kubernetes_asyncio/config/kube_config.py:get_kube_config_loader_for_yaml_file
419+
def get_kube_config_loader_for_yaml_file(self):
420+
kcfg = KubeConfigMerger(self.config_file)
421+
config_persister = None
422+
if self.persist_config:
423+
config_persister = kcfg.save_changes()
424+
425+
return AutoRefreshKubeConfigLoader(
426+
config_dict=kcfg.config,
427+
config_base_path=None,
428+
config_persister=config_persister,
137429
)
138430

431+
# Adapted from kubernetes_asyncio/config/kube_config.py:load_kube_config
432+
async def load_kube_config(self):
433+
# Create a config loader, this will automatically refresh our credentials before they expire
434+
loader = self.get_kube_config_loader_for_yaml_file()
435+
436+
# Grab our async + callback aware configuration
437+
config = AutoRefreshConfiguration(loader)
438+
439+
await loader.load_and_set(config)
440+
Configuration.set_default(config)
441+
139442

140443
class KubeAuth(ClusterAuth):
141444
"""Configure the Kubernetes connection explicitly.
@@ -172,6 +475,7 @@ def __init__(self, host, **kwargs):
172475
# values.
173476
config = type.__call__(kubernetes.client.Configuration)
174477
config.host = host
478+
175479
for key, value in kwargs.items():
176480
setattr(config, key, value)
177481
self.config = config
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import datetime
2+
import json
3+
4+
expiry = datetime.datetime.now() + datetime.timedelta(seconds=5)
5+
expiry.replace(tzinfo=datetime.timezone.utc)
6+
expiry_str = expiry.isoformat("T") + "Z"
7+
8+
fake_token = "0" * 137
9+
fake_id = "abcdefghijklmnopqrstuvwxyz.1234567890" * 37 + "." * 32
10+
11+
data = """
12+
{
13+
"credential": {
14+
"access_token": "%s",
15+
"id_token": "%s",
16+
"token_expiry": "%s"
17+
}
18+
}
19+
""" % (
20+
fake_token,
21+
fake_id,
22+
expiry_str,
23+
)
24+
25+
print(json.dumps(json.loads(data), indent=4))

0 commit comments

Comments
 (0)