Skip to content

Commit 3e6d059

Browse files
authored
Merge pull request #45 from djpugh/feature/inheritable-classes
2 parents 67d480e + 0196f0d commit 3e6d059

File tree

6 files changed

+122
-8
lines changed

6 files changed

+122
-8
lines changed

src/fastapi_aad_auth/_base/provider.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import List, Optional
22

3+
from pydantic import PrivateAttr
34
from starlette.requests import Request
45
from starlette.responses import RedirectResponse
56
from starlette.routing import Route
67

78
from fastapi_aad_auth._base.authenticators import SessionAuthenticator
89
from fastapi_aad_auth._base.validators import Validator
910
from fastapi_aad_auth.mixins import LoggingMixin
10-
from fastapi_aad_auth.utilities import urls
11+
from fastapi_aad_auth.utilities import InheritableBaseSettings, urls
1112

1213

1314
class Provider(LoggingMixin):
@@ -71,3 +72,9 @@ def redirect_url(self):
7172
if self._redirect_url is None:
7273
self._redirect_url = self._build_oauth_url(self.oauth_base_route, 'redirect')
7374
return self._redirect_url
75+
76+
77+
class ProviderConfig(InheritableBaseSettings):
78+
"""Configuration for a provider."""
79+
80+
_provider_klass: type = PrivateAttr(Provider)

src/fastapi_aad_auth/_base/state.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
from itsdangerous import URLSafeSerializer
88
from itsdangerous.exc import BadSignature
9-
from pydantic import BaseModel, Field, root_validator, validator
9+
from pydantic import Field, root_validator, validator
1010
from starlette.authentication import AuthCredentials, SimpleUser, UnauthenticatedUser
1111

1212
from fastapi_aad_auth.errors import AuthenticationError
1313
from fastapi_aad_auth.mixins import LoggingMixin
14+
from fastapi_aad_auth.utilities import InheritableBaseModel, InheritablePropertyBaseModel
1415

1516

1617
SESSION_STORE_KEY = 'auth'
@@ -23,7 +24,7 @@ class AuthenticationOptions(Enum):
2324
authenticated = 1
2425

2526

26-
class User(BaseModel):
27+
class User(InheritablePropertyBaseModel):
2728
"""User Model."""
2829
name: str = Field(..., description='Full name')
2930
email: str = Field(..., description='User email')
@@ -49,7 +50,7 @@ def _validate_scopes(cls, value):
4950
return value
5051

5152

52-
class AuthenticationState(LoggingMixin, BaseModel):
53+
class AuthenticationState(LoggingMixin, InheritableBaseModel):
5354
"""Authentication State."""
5455
session_state: str = str(uuid.uuid4())
5556
state: AuthenticationOptions = AuthenticationOptions.unauthenticated

src/fastapi_aad_auth/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""fastapi_aad_auth configuration options."""
2-
from typing import Dict, List, Optional, Union
2+
from typing import Dict, List, Optional
33
import uuid
44

55
from pkg_resources import resource_filename
66
from pydantic import BaseSettings as _BaseSettings, DirectoryPath, Field, FilePath, SecretStr, validator
77

8+
from fastapi_aad_auth._base.provider import ProviderConfig
89
from fastapi_aad_auth.providers.aad import AADConfig
910
from fastapi_aad_auth.utilities import bool_from_env, DeprecatableFieldsMixin, DeprecatedField, expand_doc, klass_from_str
1011

@@ -138,7 +139,7 @@ class Config(BaseSettings):
138139
"""
139140

140141
enabled: bool = Field(True, description="Enable authentication", env='FASTAPI_AUTH_ENABLED')
141-
providers: List[Union[AADConfig]] = Field(None, description="The provider configurations to use")
142+
providers: List[ProviderConfig] = Field(None, description="The provider configurations to use")
142143
aad: Optional[AADConfig] = DeprecatedField(None, description='AAD Configuration information', deprecated_in='0.2.0', replaced_by='Config.providers')
143144
auth_session: AuthSessionConfig = Field(None, description="The configuration for encoding the authentication information in the session")
144145
routing: RoutingConfig = Field(None, description="Configuration for routing")
@@ -158,6 +159,7 @@ def _validate_providers(cls, value, values):
158159
value = []
159160
if enabled:
160161
value.append(AADConfig(_env_file=cls.Config.env_file))
162+
161163
return value
162164

163165
@validator('aad', always=True, pre=True)

src/fastapi_aad_auth/providers/aad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from starlette.requests import Request
1616

1717
from fastapi_aad_auth._base.authenticators import SessionAuthenticator
18-
from fastapi_aad_auth._base.provider import Provider
18+
from fastapi_aad_auth._base.provider import Provider, ProviderConfig
1919
from fastapi_aad_auth._base.state import User
2020
from fastapi_aad_auth._base.validators import SessionValidator, TokenValidator
2121
from fastapi_aad_auth.errors import ConfigurationError
@@ -340,7 +340,7 @@ def get_login_button(self, post_redirect='/'):
340340

341341

342342
@expand_doc
343-
class AADConfig(BaseSettings):
343+
class AADConfig(ProviderConfig):
344344
"""Configuration for the AAD application.
345345
346346
Includes expected claims, application registration, etc.

src/fastapi_aad_auth/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from fastapi_aad_auth.utilities import logging # noqa: F401
99
from fastapi_aad_auth.utilities import urls # noqa: F401
10+
from fastapi_aad_auth.utilities.basemodel import InheritableBaseModel, InheritableBaseSettings, InheritablePropertyBaseModel, InheritablePropertyBaseSettings, PropertyBaseModel, PropertyBaseSettings # noqa: F401
1011
from fastapi_aad_auth.utilities.deprecate import DeprecatableFieldsMixin, deprecate, deprecate_module, DeprecatedField, is_deprecated # noqa: F401
1112

1213

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Provide inheritable property dictable basemodel.
2+
3+
Implements pydantic work arounds for:
4+
* https://github.com/samuelcolvin/pydantic/issues/265
5+
* https://github.com/samuelcolvin/pydantic/issues/935
6+
7+
"""
8+
from functools import wraps
9+
10+
from pydantic import BaseModel, BaseSettings
11+
from pydantic.validators import dict_validator
12+
13+
14+
class InheritableMixin:
15+
"""BaseModel that will Validate with inheritance rather than the original Class."""
16+
17+
@classmethod
18+
def get_validators(cls):
19+
"""Get the validator for the object."""
20+
yield cls.validate
21+
22+
@classmethod
23+
def validate(cls, value):
24+
"""Validate the class as itself."""
25+
if isinstance(value, cls):
26+
return value
27+
else:
28+
return cls(**dict_validator(value))
29+
30+
31+
class PropertyMixin:
32+
"""BaseModel with Properties in dict.
33+
34+
A Pydantic BaseModel that includes properties in it's dict() result
35+
enabling a mix of both fields and properties
36+
"""
37+
38+
@classmethod
39+
def get_properties(cls):
40+
"""Get the properties."""
41+
return [prop for prop in dir(cls) if cls._is_property(prop)]
42+
43+
@classmethod
44+
def _is_property(cls, prop):
45+
return isinstance(getattr(cls, prop), property) \
46+
and prop not in ("__values__", "fields")
47+
48+
@wraps(BaseModel.dict)
49+
def dict(self,
50+
*,
51+
include=None,
52+
exclude=None,
53+
by_alias: bool = False,
54+
skip_defaults: bool = None,
55+
exclude_unset: bool = False,
56+
exclude_defaults: bool = False,
57+
exclude_none: bool = False,):
58+
"""Return the object as a dictionary."""
59+
attribs = super().dict( # type: ignore
60+
include=include,
61+
exclude=exclude,
62+
by_alias=by_alias,
63+
skip_defaults=skip_defaults,
64+
exclude_unset=exclude_unset,
65+
exclude_defaults=exclude_defaults,
66+
exclude_none=exclude_none
67+
)
68+
props = self.get_properties()
69+
# Include and exclude properties
70+
if include:
71+
props = [prop for prop in props if prop in include]
72+
if exclude:
73+
props = [prop for prop in props if prop not in exclude]
74+
75+
# Update the attribute dict with the properties
76+
if props:
77+
attribs.update({prop: getattr(self, prop) for prop in props})
78+
79+
return attribs
80+
81+
82+
class InheritableBaseSettings(InheritableMixin, BaseSettings):
83+
"""A Pydantic BaseSettings that allows inheritance."""
84+
85+
86+
class PropertyBaseSettings(PropertyMixin, BaseSettings):
87+
"""A Pydantic BaseSettings that allows roperties in the dict."""
88+
89+
90+
class InheritablePropertyBaseSettings(InheritableMixin, PropertyBaseSettings):
91+
"""A Pydantic BaseSettings that allows inheritance and properties in the dict."""
92+
93+
94+
class InheritableBaseModel(InheritableMixin, BaseModel):
95+
"""A Pydantic BaseModel that allows inheritance."""
96+
97+
98+
class PropertyBaseModel(PropertyMixin, BaseModel):
99+
"""A Pydantic BaseModel that allows roperties in the dict."""
100+
101+
102+
class InheritablePropertyBaseModel(InheritableMixin, PropertyBaseModel):
103+
"""A Pydantic BaseModel that allows inheritance and properties in the dict."""

0 commit comments

Comments
 (0)