Skip to content

Fix for authentication expiry - Automatically attempt to refresh auth tokens before they expire #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dask_kubernetes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from . import config
from .auth import ClusterAuth, KubeAuth, KubeConfig, InCluster
from .auth import (
ClusterAuth,
KubeAuth,
KubeConfig,
InCluster,
AutoRefreshKubeConfigLoader,
AutoRefreshConfiguration,
)
from .core import KubeCluster
from .helm import HelmCluster
from .objects import make_pod_spec, make_pod_from_dict, clean_pod_template
Expand Down
318 changes: 311 additions & 7 deletions dask_kubernetes/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,299 @@
"""
Defines different methods to configure a connection to a Kubernetes cluster.
"""
import asyncio
import base64
import contextlib
import copy
import datetime
import json
import logging
import os

import kubernetes
import kubernetes_asyncio

from kubernetes_asyncio.client import Configuration
from kubernetes_asyncio.config.kube_config import KubeConfigLoader, KubeConfigMerger
from kubernetes_asyncio.config.google_auth import google_auth_credentials
from kubernetes_asyncio.config.dateutil import parse_rfc3339

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shorthand for the UTC timezone; used below to build timezone-aware "now" values.
tzUTC = datetime.timezone.utc


class AutoRefreshKubeConfigLoader(KubeConfigLoader):
    """
    Extends KubeConfigLoader, automatically attempts to refresh authentication
    credentials before they expire.

    Each successful refresh schedules an asyncio task that sleeps for 80% of the
    remaining token lifetime and then refreshes again. Only the ``gcp`` and
    ``oidc`` auth providers are refreshed; other mechanisms are left untouched.

    Attributes
    ----------
    auto_refresh : when False, pending refresh coroutines return without refreshing.
    refresh_task : the most recently scheduled refresh ``asyncio.Task`` (or None).
    last_refreshed : aware UTC datetime of the last successful refresh (or None).
    token_expire_ts : aware UTC datetime after which the token is considered due
        for refresh (or None when unknown / a refresh is being forced).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._retry_count = 0
        self._max_retries = float("Inf")
        self.auto_refresh = True
        self.refresh_task = None
        self.last_refreshed = None
        self.token_expire_ts = None

    def __del__(self):
        # NOTE(review): a pending refresh task holds a reference to this object, so
        # this presumably only runs once no refresh is outstanding -- confirm.
        self.auto_refresh = False

    def extract_oid_expiration_from_provider(self, provider):
        """
        Extracts the expiration datestamp for the provider token

        Parameters
        ----------
        provider : authentication provider dictionary.

        Returns
        -------
        expires : value of the JWT ``exp`` claim (seconds since the epoch), or
            None when the token carries no ``exp`` claim.

        Raises
        ------
        ValueError : if the id-token is not made of 3 period-delimited parts.
        """
        parts = provider["config"]["id-token"].split(".")

        if len(parts) != 3:
            raise ValueError("oidc: JWT tokens should contain 3 period-delimited parts")

        payload = parts[1]
        # JWT segments are base64url encoded with the padding stripped. Re-pad to a
        # multiple of four (adding nothing when already aligned) and decode with the
        # URL-safe alphabet so '-' and '_' are not silently discarded.
        payload += "=" * (-len(payload) % 4)
        jwt_attributes = json.loads(base64.urlsafe_b64decode(payload).decode("utf8"))

        return jwt_attributes.get("exp")

    async def create_refresh_task_from_expiration_timestamp(self, expiration_timestamp):
        """
        Takes an expiration timestamp, and creates a refresh task to ensure that the token
        does not expire.

        Parameters
        ----------
        expiration_timestamp : time at which the current authentication token will
            expire. Accepts an RFC 3339 string or datetime (as handled by
            ``parse_rfc3339``), a number of seconds since the epoch (as found in a
            JWT ``exp`` claim), or None when the expiry is unknown.

        Returns
        -------
        N/A
        """
        if expiration_timestamp is None:
            # Without an expiry there is nothing to schedule against; record the
            # refresh and leave token_expire_ts unset.
            logger.warning(
                msg="Token has no expiration timestamp; auto-refresh not scheduled."
            )
            self.last_refreshed = datetime.datetime.now(tz=tzUTC)
            return

        if isinstance(expiration_timestamp, (int, float)):
            # JWT 'exp' claims are numeric epoch seconds, not RFC 3339 strings.
            expiry = datetime.datetime.fromtimestamp(expiration_timestamp, tz=tzUTC)
        else:
            expiry = parse_rfc3339(expiration_timestamp)

        # Set our token expiry to be actual expiry - 20%. An already expired token
        # is clamped to "refresh immediately".
        remaining = (expiry - datetime.datetime.now(tz=tzUTC)).total_seconds()
        scaled_expiry_delta = datetime.timedelta(seconds=0.8 * max(remaining, 0.0))

        self.refresh_task = asyncio.create_task(
            self.refresh_after(
                when=scaled_expiry_delta.total_seconds(), reschedule_on_failure=True
            ),
            name="dask_auth_auto_refresh",
        )

        self.last_refreshed = datetime.datetime.now(tz=tzUTC)
        self.token_expire_ts = self.last_refreshed + scaled_expiry_delta

    async def refresh_after(self, when, reschedule_on_failure=False):
        """
        Refresh kubernetes authentication

        Parameters
        ----------
        when : Seconds before we should refresh. This should be set to some delta before
            the actual token expiration time, or you will likely see authentication race
            / failure conditions.

        reschedule_on_failure : If the refresh task fails, re-try in 30 seconds, until
            _max_retries is exceeded, then raise an exception.
        """

        if not self.auto_refresh:
            return

        logger.debug(
            msg=f"Refresh_at coroutine sleeping for "
            f"{int(when // 60)} minutes {(when % 60):0.2f} seconds."
        )
        try:
            await asyncio.sleep(when)

            # Auto-refresh may have been switched off while we were sleeping.
            if not self.auto_refresh:
                return

            # The sleep runs on the event loop clock while the refresh methods
            # compare wall-clock datetimes. Clear the deadline so the refresh below
            # is unconditional; otherwise a small skew could skip it and no
            # follow-up task would ever be scheduled.
            self.token_expire_ts = None

            if self.provider == "gcp":
                await self.refresh_gcp_token()
            elif self.provider == "oidc":
                await self.refresh_oid_token()
            elif "exec" in self._user:
                logger.warning(msg="Auto-refresh doesn't support generic ExecProvider")
                return

        except Exception as e:
            logger.warning(
                msg=f"Authentication refresh failed for provider '{self.provider}.'",
                exc_info=e,
            )
            if not reschedule_on_failure or self._retry_count >= self._max_retries:
                raise

            logger.warning(msg=f"Retrying '{self.provider}' in 30 seconds.")
            self._retry_count += 1
            # Propagate reschedule_on_failure so that retries keep going until
            # _max_retries is reached rather than stopping after the first retry.
            self.refresh_task = asyncio.create_task(
                self.refresh_after(30, reschedule_on_failure=True)
            )
        else:
            # A successful refresh starts a fresh retry budget.
            self._retry_count = 0

    async def refresh_oid_token(self):
        """
        Adapted from kubernetes_asyncio/config/kube_config:_load_oid_token

        Refreshes the existing oid token, if necessary, and creates a refresh task
        that will keep the token from expiring.

        Raises
        ------
        ValueError : if the auth provider has no "config" section.
        """
        provider = self._user["auth-provider"]

        logger.debug("Refreshing OID token.")

        if "config" not in provider:
            raise ValueError("oidc: missing configuration")

        if (not self.token_expire_ts) or (
            self.token_expire_ts <= datetime.datetime.now(tz=tzUTC)
        ):
            await self._refresh_oidc(provider)
            expires = self.extract_oid_expiration_from_provider(provider=provider)

            await self.create_refresh_task_from_expiration_timestamp(
                expiration_timestamp=expires
            )

        self.token = "Bearer {}".format(provider["config"]["id-token"])

    async def refresh_gcp_token(self):
        """
        Adapted from kubernetes_asyncio/config/kube_config:load_gcp_token

        Refreshes the existing gcp token, if necessary, and creates a refresh task
        that will keep the token from expiring.
        """
        if "config" not in self._user["auth-provider"]:
            self._user["auth-provider"].value["config"] = {}

        config = self._user["auth-provider"]["config"]

        if (not self.token_expire_ts) or (
            self.token_expire_ts <= datetime.datetime.now(tz=tzUTC)
        ):
            logger.debug("Refreshing GCP token.")
            if self._get_google_credentials is not None:
                # A user supplied credential getter may be sync or async.
                if asyncio.iscoroutinefunction(self._get_google_credentials):
                    credentials = await self._get_google_credentials()
                else:
                    credentials = self._get_google_credentials()
            else:
                # config is read-only, so build a modified copy of the command.
                extra_args = " --force-auth-refresh"
                _config = {
                    "cmd-args": config["cmd-args"] + extra_args,
                    "cmd-path": config["cmd-path"],
                }
                credentials = await google_auth_credentials(_config)

            config.value["access-token"] = credentials.token
            config.value["expiry"] = credentials.expiry

            # Set our token expiry to be actual expiry - 20%
            await self.create_refresh_task_from_expiration_timestamp(
                expiration_timestamp=config.value["expiry"]
            )

            if self._config_persister:
                self._config_persister(self._config.value)

        self.token = "Bearer %s" % config["access-token"]

    async def _load_oid_token(self):
        """
        Overrides KubeConfigLoader implementation.

        Returns
        -------
        Auth token
        """
        await self.refresh_oid_token()

        return self.token

    async def load_gcp_token(self):
        """
        Override KubeConfigLoader implementation so that we can keep track of the
        expiration timestamp and automatically refresh auth tokens.

        Returns
        -------
        GCP access token
        """
        await self.refresh_gcp_token()

        return self.token


class AutoRefreshConfiguration(Configuration):
    """
    Extends kubernetes_async Configuration to support automatic token refresh.
    Lets us keep track of the original loader object, which can be used
    to regenerate the authentication token.

    Parameters
    ----------
    loader : loader object exposing ``last_refreshed`` and ``token`` attributes.
    refresh_frequency : accepted for compatibility; not used by this class.
    *args, **kwargs : forwarded to ``Configuration.__init__``.
    """

    def __init__(self, loader, refresh_frequency=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Set refresh api callback
        self.refresh_api_key_hook = self.refresh_api_key
        self.last_refreshed = datetime.datetime.now(tz=tzUTC)
        self.loader = loader

    # Adapted from kubernetes_asyncio/client/configuration.py:__deepcopy__
    def __deepcopy__(self, memo):
        """
        Modified so that we don't try to deep copy the loader off the config
        """
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k not in ("logger", "logger_file_handler", "loader"):
                setattr(result, k, copy.deepcopy(v, memo))

        # shallow copy loader object
        result.loader = self.loader
        # shallow copy of loggers
        result.logger = copy.copy(self.logger)
        # use setters to configure loggers
        result.logger_file = self.logger_file
        result.debug = self.debug

        return result

    def refresh_api_key(self, client_configuration):
        """
        Checks to see if the loader has updated the authentication token. If it
        has, the token is copied from the loader into the current configuration.

        This function is assigned to Configuration.refresh_api_key_hook, and will
        fire when entering get_api_key_with_prefix, before the api_key is retrieved.

        Parameters
        ----------
        client_configuration : configuration whose ``api_key["authorization"]``
            entry is updated from ``client_configuration.loader.token``.
        """
        loader_refreshed = self.loader.last_refreshed
        if loader_refreshed is None:
            # The loader has never refreshed a token (e.g. an auth mechanism that
            # is not auto-refreshed), so there is nothing newer to copy. Comparing
            # a datetime against None would raise TypeError.
            return

        if self.last_refreshed is None or self.last_refreshed < loader_refreshed:
            logger.debug("Entering refresh_api_key_hook")
            client_configuration.api_key[
                "authorization"
            ] = client_configuration.loader.token
            self.last_refreshed = datetime.datetime.now(tz=tzUTC)


class ClusterAuth(object):
"""
Expand Down Expand Up @@ -45,7 +329,6 @@ async def load_first(auth=None):

Parameters
----------

auth: List[ClusterAuth] (optional)
Configuration methods to attempt in order. Defaults to
``[InCluster(), KubeConfig()]``.
Expand Down Expand Up @@ -127,15 +410,35 @@ async def load(self):
with contextlib.suppress(KeyError):
if self.config_file is None:
self.config_file = os.path.abspath(
os.path.expanduser(os.environ["KUBECONFIG"])
os.path.expanduser(os.environ.get("KUBECONFIG", "~/.kube/config"))
)
kubernetes.config.load_kube_config(
self.config_file, self.context, None, self.persist_config
)
await kubernetes_asyncio.config.load_kube_config(
self.config_file, self.context, None, self.persist_config

await self.load_kube_config()

# Adapted from kubernetes_asyncio/config/kube_config.py:get_kube_config_loader_for_yaml_file
def get_kube_config_loader_for_yaml_file(self):
    """
    Build an AutoRefreshKubeConfigLoader from ``self.config_file``.

    Returns
    -------
    AutoRefreshKubeConfigLoader constructed from the merged kube config. When
    ``self.persist_config`` is truthy, the result of ``save_changes()`` is
    passed as the loader's ``config_persister``; otherwise None is passed.
    """
    merged = KubeConfigMerger(self.config_file)
    # NOTE(review): save_changes is *called* here, so its return value (not the
    # method itself) becomes the persister -- confirm this is intended.
    persister = merged.save_changes() if self.persist_config else None

    return AutoRefreshKubeConfigLoader(
        config_dict=merged.config,
        config_base_path=None,
        config_persister=persister,
    )

# Adapted from kubernetes_asyncio/config/kube_config.py:load_kube_config
async def load_kube_config(self):
    """
    Load the kube config into an auto-refreshing configuration and install it
    as the default ``Configuration``.
    """
    # The loader is what re-generates credentials before they expire.
    refreshing_loader = self.get_kube_config_loader_for_yaml_file()

    # The configuration keeps a handle on the loader so that it can pick up
    # refreshed tokens through its refresh hook.
    refreshing_config = AutoRefreshConfiguration(refreshing_loader)

    await refreshing_loader.load_and_set(refreshing_config)
    Configuration.set_default(refreshing_config)


class KubeAuth(ClusterAuth):
"""Configure the Kubernetes connection explicitly.
Expand Down Expand Up @@ -172,6 +475,7 @@ def __init__(self, host, **kwargs):
# values.
config = type.__call__(kubernetes.client.Configuration)
config.host = host

for key, value in kwargs.items():
setattr(config, key, value)
self.config = config
Expand Down
25 changes: 25 additions & 0 deletions dask_kubernetes/tests/fake_gcp_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Prints a fake GCP credential document (JSON) whose token expires 5 seconds
from now. Intended to stand in for a credential helper command in tests.
"""
import datetime
import json

# Start from an aware UTC "now": the timestamp below is suffixed with "Z", so it
# must really be UTC rather than the host's local time.
expiry = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(
    seconds=5
)
# isoformat() on an aware datetime would append "+00:00"; drop the tzinfo for
# formatting only and add the "Z" designator instead.
expiry_str = expiry.replace(tzinfo=None).isoformat("T") + "Z"

fake_token = "0" * 137
fake_id = "abcdefghijklmnopqrstuvwxyz.1234567890" * 37 + "." * 32

# Build the document as a dict so that values never need escaping.
data = {
    "credential": {
        "access_token": fake_token,
        "id_token": fake_id,
        "token_expiry": expiry_str,
    }
}

print(json.dumps(data, indent=4))
Loading