6
6
from urllib .parse import urlparse , parse_qs
7
7
8
8
from msal import ConfidentialClientApplication , PublicClientApplication
9
- from oauthlib .oauth2 import TokenExpiredError , WebApplicationClient , BackendApplicationClient , LegacyApplicationClient
10
9
from requests import Session
11
10
from requests .adapters import HTTPAdapter
12
11
from requests .exceptions import HTTPError , RequestException , ProxyError
62
61
}
63
62
64
63
64
+ class TokenExpiredError (HTTPError ):
65
+ pass
66
+
67
+
65
68
class Protocol :
66
69
""" Base class for all protocols """
67
70
@@ -359,6 +362,7 @@ def __init__(self, credentials, *, scopes=None,
359
362
timeout = None , json_encoder = None ,
360
363
verify_ssl = True ,
361
364
default_headers : dict = None ,
365
+ store_token_after_refresh : bool = True ,
362
366
** kwargs ):
363
367
""" Creates an API connection object
364
368
@@ -398,6 +402,7 @@ def __init__(self, credentials, *, scopes=None,
398
402
data before giving up, as a float, or a tuple (connect timeout, read timeout)
399
403
:param JSONEncoder json_encoder: The JSONEncoder to use during the JSON serialization on the request.
400
404
:param bool verify_ssl: set the verify flag on the requests library
405
+ :param bool store_token_after_refresh: if after a token refresh the token backend should call save_token
401
406
:param dict kwargs: any extra params passed to Connection
402
407
:raises ValueError: if credentials is not tuple of
403
408
(client_id, client_secret)
@@ -422,12 +427,13 @@ def __init__(self, credentials, *, scopes=None,
422
427
self .password = password
423
428
self .scopes = scopes
424
429
self .default_headers = default_headers or dict ()
425
- self .store_token = True
430
+ self .store_token_after_refresh : bool = store_token_after_refresh
431
+
426
432
token_backend = token_backend or FileSystemTokenBackend (** kwargs )
427
433
if not isinstance (token_backend , BaseTokenBackend ):
428
434
raise ValueError ('"token_backend" must be an instance of a subclass of BaseTokenBackend' )
429
435
self .token_backend = token_backend
430
- self .session = None # requests Oauth2Session object
436
+ self .session = None # requests Session object
431
437
432
438
self .proxy = {}
433
439
self .set_proxy (proxy_server , proxy_port , proxy_username , proxy_password , proxy_http_only )
@@ -436,8 +442,8 @@ def __init__(self, credentials, *, scopes=None,
436
442
self .raise_http_errors = raise_http_errors
437
443
self .request_retries = request_retries
438
444
self .timeout = timeout
439
- self .json_encoder = json_encoder
440
445
self .verify_ssl = verify_ssl
446
+ self .json_encoder = json_encoder
441
447
442
448
self .naive_session = None # lazy loaded: holds a requests Session object
443
449
@@ -449,6 +455,14 @@ def __init__(self, credentials, *, scopes=None,
449
455
'{}/oauth2/v2.0/token' .format (tenant_id )
450
456
self .oauth_redirect_url = 'https://login.microsoftonline.com/common/oauth2/nativeclient'
451
457
458
+ # In the event of a response that returned 401 unauthorised this will flag between requests
459
+ # that this 401 can be a token expired error. MsGraph is returning 401 when the access token
460
+ # has expired. We can not distinguish between a real 401 or token expired 401. So in the event
461
+ # of a 401 http error we will first try to refresh the token, set this flag to True and then
462
+ # re-run the request. If the 401 goes away we will then set this flag to false. If it keeps the
463
+ # 401 then we will raise the error.
464
+ self ._token_expired_flag = False
465
+
452
466
@property
453
467
def auth_flow_type (self ):
454
468
return self ._auth_flow_type
@@ -658,7 +672,7 @@ def get_naive_session(self):
658
672
659
673
return naive_session
660
674
661
- def refresh_token (self ):
675
+ def refresh_token (self ) -> bool :
662
676
"""
663
677
Refresh the OAuth authorization token.
664
678
This will be called automatically when the access token
@@ -669,40 +683,42 @@ def refresh_token(self):
669
683
if self .session is None :
670
684
self .session = self .get_session (load_token = True )
671
685
672
- token = self .token_backend .token
673
- if not token :
674
- raise RuntimeError ('Token not found.' )
675
-
676
- if token .is_long_lived or self .auth_flow_type == 'credentials' :
677
- log .debug ('Refreshing token' )
678
- if self .auth_flow_type == 'authorization' :
679
- client_id , client_secret = self .auth
680
- self .token_backend .token = Token (
681
- self .session .refresh_token (
682
- self ._oauth2_token_url ,
683
- client_id = client_id ,
684
- client_secret = client_secret ,
685
- verify = self .verify_ssl )
686
- )
687
- elif self .auth_flow_type in ('public' , 'password' ):
688
- client_id = self .auth [0 ]
689
- self .token_backend .token = Token (
690
- self .session .refresh_token (
691
- self ._oauth2_token_url ,
692
- client_id = client_id ,
693
- verify = self .verify_ssl )
686
+ if self .token_backend .access_token is None :
687
+ raise RuntimeError ('Access Token not found.' )
688
+
689
+ token_refreshed = False
690
+
691
+ if self .token_backend .token_is_long_lived or self .auth_flow_type == 'credentials' :
692
+ should_rt = self .token_backend .should_refresh_token (self )
693
+ if should_rt is True :
694
+ # The backend has checked that we can refresh the token
695
+ log .debug ('Refreshing access token' )
696
+ result = self .msal_client .acquire_token_silent_with_error (
697
+ scopes = self .scopes ,
698
+ account = self .msal_client .get_accounts ()[0 ]
694
699
)
695
- elif self .auth_flow_type == 'credentials' :
696
- if self .request_token (None , store_token = False ) is False :
697
- log .error ('Refresh for Client Credentials Grant Flow failed.' )
698
- return False
699
- log .debug ('New oauth token fetched by refresh method' )
700
+ token_refreshed = True
701
+ if result is None :
702
+ raise RuntimeError ('There is no access token to refresh' )
703
+ elif 'error' in result :
704
+ raise RuntimeError (f'Refresh token operation failed: { result ["error" ]} ' )
705
+ elif 'access_token' in result :
706
+ # refresh done, update authorization header
707
+ self .session .headers .update ({'Authorization' : 'Bearer {}' .format (result ['access_token' ])})
708
+ log .debug ('New oauth token fetched by refresh method' )
709
+ elif should_rt is False :
710
+ # the token was refreshed by another instance and updated into this instance,
711
+ # so: update the session token and retry the request again
712
+ self .session .headers .update ({'Authorization' : f'Bearer { self .token_backend .access_token ["secret" ]} ' })
713
+ else :
714
+ # the refresh was performed by the token backend.
715
+ pass
700
716
else :
701
717
log .error ('You can not refresh an access token that has no "refresh_token" available.'
702
718
'Include "offline_access" scope when authenticating to get a "refresh_token"' )
703
719
return False
704
720
705
- if self .store_token :
721
+ if token_refreshed and self .store_token_after_refresh :
706
722
self .token_backend .save_token ()
707
723
return True
708
724
@@ -753,8 +769,6 @@ def _internal_request(self, request_obj, url, method, **kwargs):
753
769
if self .timeout is not None :
754
770
kwargs ['timeout' ] = self .timeout
755
771
756
- kwargs .setdefault ("verify" , self .verify_ssl )
757
-
758
772
request_done = False
759
773
token_refreshed = False
760
774
@@ -770,6 +784,12 @@ def _internal_request(self, request_obj, url, method, **kwargs):
770
784
response .status_code , response .url ))
771
785
request_done = True
772
786
return response
787
+ except (ConnectionError , ProxyError , SSLError , Timeout ) as e :
788
+ # We couldn't connect to the target url, raise error
789
+ log .debug ('Connection Error calling: {}.{}'
790
+ '' .format (url , ('Using proxy: {}' .format (self .proxy )
791
+ if self .proxy else '' )))
792
+ raise e # re-raise exception
773
793
except TokenExpiredError as e :
774
794
# Token has expired, try to refresh the token and try again on the next loop
775
795
log .debug ('Oauth Token is expired' )
@@ -792,16 +812,17 @@ def _internal_request(self, request_obj, url, method, **kwargs):
792
812
else :
793
813
# the refresh was performed by the tokend backend.
794
814
token_refreshed = True
795
-
796
- except (ConnectionError , ProxyError , SSLError , Timeout ) as e :
797
- # We couldn't connect to the target url, raise error
798
- log .debug ('Connection Error calling: {}.{}'
799
- '' .format (url , ('Using proxy: {}' .format (self .proxy )
800
- if self .proxy else '' )))
801
- raise e # re-raise exception
802
815
except HTTPError as e :
803
816
# Server response with 4XX or 5XX error status codes
804
817
818
+ if e .response .status_code == 401 and self ._token_expired_flag is False :
819
+ # This could be a token expired error.
820
+ if self .token_backend .token_is_expired ():
821
+ log .debug ('Oauth Token is expired' )
822
+ # Token has expired, try to refresh the token and try again on the next loop
823
+ if self .token_backend .token_is_long_lived is False and self .auth_flow_type == 'authorization' :
824
+ raise e
825
+
805
826
# try to extract the error message:
806
827
try :
807
828
error = response .json ()
@@ -866,7 +887,12 @@ def oauth_request(self, url, method, **kwargs):
866
887
if self .session is None :
867
888
self .session = self .get_session (load_token = True )
868
889
869
- return self ._internal_request (self .session , url , method , ** kwargs )
890
+ try :
891
+ return self ._internal_request (self .session , url , method , ** kwargs )
892
+ except TokenExpiredError :
893
+ # refresh and try again the request!
894
+ self .refresh_token ()
895
+ return self ._internal_request (self .session , url , method , ** kwargs )
870
896
871
897
def get (self , url , params = None , ** kwargs ):
872
898
""" Shorthand for self.oauth_request(url, 'get')
0 commit comments