Skip to content

Commit 5b11ced

Browse files
authored
Merge pull request #47 from djpugh/fix/user-klass-reload
2 parents 3e6d059 + 7d8256f commit 5b11ced

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

src/fastapi_aad_auth/_base/state.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Authentication State Handler."""
22
from enum import Enum
3+
import importlib
34
import json
45
from typing import List, Optional
56
import uuid
@@ -43,6 +44,11 @@ def permissions(self):
4344
permissions.append(scope)
4445
return permissions[:]
4546

47+
@property
48+
def klass(self):
49+
"""Return the user klass information for loading from a session."""
50+
return f'{self.__class__.__module__}:{self.__class__.__name__}'
51+
4652
@validator('scopes', always=True, pre=True)
4753
def _validate_scopes(cls, value):
4854
if isinstance(value, str):
@@ -52,14 +58,27 @@ def _validate_scopes(cls, value):
5258

5359
class AuthenticationState(LoggingMixin, InheritableBaseModel):
5460
"""Authentication State."""
61+
_logger = None
5562
session_state: str = str(uuid.uuid4())
5663
state: AuthenticationOptions = AuthenticationOptions.unauthenticated
5764
user: Optional[User] = None
58-
_logger = None
5965

6066
class Config: # noqa: D106
6167
underscore_attrs_are_private = True
6268

69+
@validator('user', always=True, pre=True)
70+
def _validate_user_klass(cls, value):
71+
if isinstance(value, dict):
72+
klass = value.get('klass', None)
73+
if klass:
74+
module, name = klass.split(':')
75+
mod = importlib.import_module(module)
76+
klass = getattr(mod, name)
77+
else:
78+
klass = User
79+
value = klass(**value)
80+
return value
81+
6382
@root_validator(pre=True)
6483
def _validate_user(cls, values):
6584
if values.get('user', None) is None:

tests/unit/test_auth_state.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import List
2+
import unittest
3+
import uuid
4+
5+
from fastapi_aad_auth._base.state import AuthenticationState, User, AuthenticationOptions
6+
from fastapi_aad_auth._base.validators import SessionValidator
7+
8+
9+
class User2(User):
10+
b: int = 2
11+
12+
13+
class User3(User2):
14+
15+
@property
16+
def permissions(self):
17+
return [self.name, 'a']
18+
19+
20+
class AuthenticationStateTestCase(unittest.TestCase):
21+
22+
def setUp(self):
23+
self.serializer = SessionValidator.get_session_serializer(str(uuid.uuid4()), str(uuid.uuid4()))
24+
25+
def test_create(self):
26+
user = User(name='Joe Bloggs', email='[email protected]', username='[email protected]')
27+
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
28+
self.assertIsInstance(state.user, User)
29+
self.assertEqual(state.user.name, user.name)
30+
31+
def test_create_custom_user(self):
32+
user = User2(name='Joe Bloggs', email='[email protected]', username='[email protected]')
33+
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
34+
self.assertIsInstance(state.user, User2)
35+
self.assertEqual(state.user.name, user.name)
36+
self.assertEqual(state.user.b, user.b)
37+
self.assertEqual(state.user.b, 2)
38+
39+
def test_load(self):
40+
user = User(name='Joe Bloggs', email='[email protected]', username='[email protected]')
41+
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
42+
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
43+
self.assertIsInstance(state.user, User)
44+
self.assertEqual(state.user.name, user.name)
45+
46+
def test_load_custom_user(self):
47+
user = User2(name='Joe Bloggs', email='[email protected]', username='[email protected]')
48+
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
49+
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
50+
self.assertIsInstance(state.user, User2)
51+
self.assertEqual(state.user.name, user.name)
52+
self.assertEqual(state.user.b, user.b)
53+
self.assertEqual(state.user.b, 2)
54+
55+
def test_load_custom_user_permissions(self):
56+
user = User3(name='Joe Bloggs', email='[email protected]', username='[email protected]', b=4)
57+
state = AuthenticationState(user=user, state=AuthenticationOptions.authenticated)
58+
loaded_state = AuthenticationState.load(self.serializer, state.store(self.serializer))
59+
self.assertIsInstance(state.user, User3)
60+
self.assertEqual(state.user.name, user.name)
61+
self.assertEqual(state.user.b, user.b)
62+
self.assertEqual(state.user.b, 4)
63+
self.assertEqual(state.user.permissions, ['Joe Bloggs', 'a'])
64+
65+
66+

0 commit comments

Comments
 (0)