Skip to content

Commit 9d0b2e1

Browse files
committed
Added token_is_long_lived to BaseTokenBackend
Changed token_expiration_datetime and token_is_expired to methods First part redoing the refresh token operation. The refresh token has been moved to the oauth_request method by catching a custom exception raised from internal_request
1 parent af98c88 commit 9d0b2e1

File tree

3 files changed

+81
-50
lines changed

3 files changed

+81
-50
lines changed

O365/account.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def is_authenticated(self) -> bool:
7777
if token is None:
7878
self.con.token_backend.load_token()
7979

80-
return not self.con.token_backend.token_is_expired
80+
return not self.con.token_backend.token_is_expired(refresh_token=True)
8181

8282
def authenticate(self, *, scopes: Optional[list] = None,
8383
handle_consent: Callable = consent_input_token, **kwargs) -> bool:

O365/connection.py

+69-43
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from urllib.parse import urlparse, parse_qs
77

88
from msal import ConfidentialClientApplication, PublicClientApplication
9-
from oauthlib.oauth2 import TokenExpiredError, WebApplicationClient, BackendApplicationClient, LegacyApplicationClient
109
from requests import Session
1110
from requests.adapters import HTTPAdapter
1211
from requests.exceptions import HTTPError, RequestException, ProxyError
@@ -62,6 +61,10 @@
6261
}
6362

6463

64+
class TokenExpiredError(HTTPError):
65+
pass
66+
67+
6568
class Protocol:
6669
""" Base class for all protocols """
6770

@@ -359,6 +362,7 @@ def __init__(self, credentials, *, scopes=None,
359362
timeout=None, json_encoder=None,
360363
verify_ssl=True,
361364
default_headers: dict = None,
365+
store_token_after_refresh: bool = True,
362366
**kwargs):
363367
""" Creates an API connection object
364368
@@ -398,6 +402,7 @@ def __init__(self, credentials, *, scopes=None,
398402
data before giving up, as a float, or a tuple (connect timeout, read timeout)
399403
:param JSONEncoder json_encoder: The JSONEncoder to use during the JSON serialization on the request.
400404
:param bool verify_ssl: set the verify flag on the requests library
405+
:param bool store_token_after_refresh: if after a token refresh the token backend should call save_token
401406
:param dict kwargs: any extra params passed to Connection
402407
:raises ValueError: if credentials is not tuple of
403408
(client_id, client_secret)
@@ -422,12 +427,13 @@ def __init__(self, credentials, *, scopes=None,
422427
self.password = password
423428
self.scopes = scopes
424429
self.default_headers = default_headers or dict()
425-
self.store_token = True
430+
self.store_token_after_refresh: bool = store_token_after_refresh
431+
426432
token_backend = token_backend or FileSystemTokenBackend(**kwargs)
427433
if not isinstance(token_backend, BaseTokenBackend):
428434
raise ValueError('"token_backend" must be an instance of a subclass of BaseTokenBackend')
429435
self.token_backend = token_backend
430-
self.session = None # requests Oauth2Session object
436+
self.session = None # requests Session object
431437

432438
self.proxy = {}
433439
self.set_proxy(proxy_server, proxy_port, proxy_username, proxy_password, proxy_http_only)
@@ -436,8 +442,8 @@ def __init__(self, credentials, *, scopes=None,
436442
self.raise_http_errors = raise_http_errors
437443
self.request_retries = request_retries
438444
self.timeout = timeout
439-
self.json_encoder = json_encoder
440445
self.verify_ssl = verify_ssl
446+
self.json_encoder = json_encoder
441447

442448
self.naive_session = None # lazy loaded: holds a requests Session object
443449

@@ -449,6 +455,14 @@ def __init__(self, credentials, *, scopes=None,
449455
'{}/oauth2/v2.0/token'.format(tenant_id)
450456
self.oauth_redirect_url = 'https://login.microsoftonline.com/common/oauth2/nativeclient'
451457

458+
# In the event of a response that returned 401 unauthorised this will flag between requests
459+
# that this 401 can be a token expired error. MsGraph is returning 401 when the access token
460+
# has expired. We can not distinguish between a real 401 or token expired 401. So in the event
461+
# of a 401 http error we will first try to refresh the token, set this flag to True and then
462+
# re-run the request. If the 401 goes away we will then set this flag to false. If it keeps the
463+
# 401 then we will raise the error.
464+
self._token_expired_flag = False
465+
452466
@property
453467
def auth_flow_type(self):
454468
return self._auth_flow_type
@@ -658,7 +672,7 @@ def get_naive_session(self):
658672

659673
return naive_session
660674

661-
def refresh_token(self):
675+
def refresh_token(self) -> bool:
662676
"""
663677
Refresh the OAuth authorization token.
664678
This will be called automatically when the access token
@@ -669,40 +683,42 @@ def refresh_token(self):
669683
if self.session is None:
670684
self.session = self.get_session(load_token=True)
671685

672-
token = self.token_backend.token
673-
if not token:
674-
raise RuntimeError('Token not found.')
675-
676-
if token.is_long_lived or self.auth_flow_type == 'credentials':
677-
log.debug('Refreshing token')
678-
if self.auth_flow_type == 'authorization':
679-
client_id, client_secret = self.auth
680-
self.token_backend.token = Token(
681-
self.session.refresh_token(
682-
self._oauth2_token_url,
683-
client_id=client_id,
684-
client_secret=client_secret,
685-
verify=self.verify_ssl)
686-
)
687-
elif self.auth_flow_type in ('public', 'password'):
688-
client_id = self.auth[0]
689-
self.token_backend.token = Token(
690-
self.session.refresh_token(
691-
self._oauth2_token_url,
692-
client_id=client_id,
693-
verify=self.verify_ssl)
686+
if self.token_backend.access_token is None:
687+
raise RuntimeError('Access Token not found.')
688+
689+
token_refreshed = False
690+
691+
if self.token_backend.token_is_long_lived or self.auth_flow_type == 'credentials':
692+
should_rt = self.token_backend.should_refresh_token(self)
693+
if should_rt is True:
694+
# The backend has checked that we can refresh the token
695+
log.debug('Refreshing access token')
696+
result = self.msal_client.acquire_token_silent_with_error(
697+
scopes=self.scopes,
698+
account=self.msal_client.get_accounts()[0]
694699
)
695-
elif self.auth_flow_type == 'credentials':
696-
if self.request_token(None, store_token=False) is False:
697-
log.error('Refresh for Client Credentials Grant Flow failed.')
698-
return False
699-
log.debug('New oauth token fetched by refresh method')
700+
token_refreshed = True
701+
if result is None:
702+
raise RuntimeError('There is no access token to refresh')
703+
elif 'error' in result:
704+
raise RuntimeError(f'Refresh token operation failed: {result["error"]}')
705+
elif 'access_token' in result:
706+
# refresh done, update authorization header
707+
self.session.headers.update({'Authorization': 'Bearer {}'.format(result['access_token'])})
708+
log.debug('New oauth token fetched by refresh method')
709+
elif should_rt is False:
710+
# the token was refreshed by another instance and updated into this instance,
711+
# so: update the session token and retry the request again
712+
self.session.headers.update({'Authorization': f'Bearer {self.token_backend.access_token["secret"]}'})
713+
else:
714+
# the refresh was performed by the token backend.
715+
pass
700716
else:
701717
log.error('You can not refresh an access token that has no "refresh_token" available.'
702718
'Include "offline_access" scope when authenticating to get a "refresh_token"')
703719
return False
704720

705-
if self.store_token:
721+
if token_refreshed and self.store_token_after_refresh:
706722
self.token_backend.save_token()
707723
return True
708724

@@ -753,8 +769,6 @@ def _internal_request(self, request_obj, url, method, **kwargs):
753769
if self.timeout is not None:
754770
kwargs['timeout'] = self.timeout
755771

756-
kwargs.setdefault("verify", self.verify_ssl)
757-
758772
request_done = False
759773
token_refreshed = False
760774

@@ -770,6 +784,12 @@ def _internal_request(self, request_obj, url, method, **kwargs):
770784
response.status_code, response.url))
771785
request_done = True
772786
return response
787+
except (ConnectionError, ProxyError, SSLError, Timeout) as e:
788+
# We couldn't connect to the target url, raise error
789+
log.debug('Connection Error calling: {}.{}'
790+
''.format(url, ('Using proxy: {}'.format(self.proxy)
791+
if self.proxy else '')))
792+
raise e # re-raise exception
773793
except TokenExpiredError as e:
774794
# Token has expired, try to refresh the token and try again on the next loop
775795
log.debug('Oauth Token is expired')
@@ -792,16 +812,17 @@ def _internal_request(self, request_obj, url, method, **kwargs):
792812
else:
793813
# the refresh was performed by the tokend backend.
794814
token_refreshed = True
795-
796-
except (ConnectionError, ProxyError, SSLError, Timeout) as e:
797-
# We couldn't connect to the target url, raise error
798-
log.debug('Connection Error calling: {}.{}'
799-
''.format(url, ('Using proxy: {}'.format(self.proxy)
800-
if self.proxy else '')))
801-
raise e # re-raise exception
802815
except HTTPError as e:
803816
# Server response with 4XX or 5XX error status codes
804817

818+
if e.response.status_code == 401 and self._token_expired_flag is False:
819+
# This could be a token expired error.
820+
if self.token_backend.token_is_expired():
821+
log.debug('Oauth Token is expired')
822+
# Token has expired, try to refresh the token and try again on the next loop
823+
if self.token_backend.token_is_long_lived is False and self.auth_flow_type == 'authorization':
824+
raise e
825+
805826
# try to extract the error message:
806827
try:
807828
error = response.json()
@@ -866,7 +887,12 @@ def oauth_request(self, url, method, **kwargs):
866887
if self.session is None:
867888
self.session = self.get_session(load_token=True)
868889

869-
return self._internal_request(self.session, url, method, **kwargs)
890+
try:
891+
return self._internal_request(self.session, url, method, **kwargs)
892+
except TokenExpiredError:
893+
# refresh and try again the request!
894+
self.refresh_token()
895+
return self._internal_request(self.session, url, method, **kwargs)
870896

871897
def get(self, url, params=None, **kwargs):
872898
""" Shorthand for self.oauth_request(url, 'get')

O365/utils/token.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ def __init__(self):
2121
self._has_state_changed = False
2222
self.cryptography_manager = None
2323

24-
@property
25-
def token_expiration_datetime(self):
24+
def token_expiration_datetime(self, refresh_token=False):
2625
"""
2726
Returns the current token expiration datetime
2827
If the refresh token is present, then the expiration datetime is extended by 3 months
28+
:param bool refresh_token: if true will check for the refresh token and return its expiration datetime
2929
:return dt.datetime or None: The expiration datetime
3030
"""
3131
access_token = self.access_token
@@ -40,19 +40,24 @@ def token_expiration_datetime(self):
4040
expires_on = int(expires_on)
4141

4242
expiration_datetime = dt.datetime.fromtimestamp(expires_on)
43-
if self.refresh_token is not None:
43+
if refresh_token is True and self.token_is_long_lived:
4444
# current token is long-lived, add 3 months to the token expiration date
4545
expiration_datetime = expiration_datetime + dt.timedelta(days=90)
4646

4747
return expiration_datetime
4848

49-
@property
50-
def token_is_expired(self):
49+
def token_is_expired(self, refresh_token=False):
5150
"""
5251
Checks whether the current token is expired
52+
:param bool refresh_token: if true will check for the refresh token and return its expiration datetime
5353
:return bool: True if the token is expired, False otherwise
5454
"""
55-
return dt.datetime.now() > self.token_expiration_datetime
55+
return dt.datetime.now() > self.token_expiration_datetime(refresh_token=refresh_token)
56+
57+
@property
58+
def token_is_long_lived(self):
59+
""" Returns if the token has a refresh token """
60+
return self.refresh_token is not None
5661

5762
def add(self, event, **kwargs):
5863
super().add(event, **kwargs)

0 commit comments

Comments
 (0)