Skip to content

Commit ff85aa5

Browse files
committed
Refactor code
1 parent c563df7 commit ff85aa5

File tree

4 files changed

+31
-59
lines changed

4 files changed

+31
-59
lines changed

requirements/base.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ boto3>=1.19.5,==1.*
22
jsonschema<5,>=3.2 # TODO: evaluate risk of removing jsonschema 3.x support
33
typing_extensions>=4.4,<5 # 3.7 doesn't have Literal
44

5-
# resource validation & schema generation, requiring features in >=1.10
6-
pydantic~=1.10
5+
# resource validation & schema generation
6+
pydantic~=1.8.0

samtranslator/internal/schema_source/common.py

+9
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ class BaseModel(LenientBaseModel):
5656
class Config:
5757
extra = Extra.forbid
5858

59+
def __getattribute__(self, __name: str) -> Any:
60+
"""Overloading get attribute operation"""
61+
attr_value = super().__getattribute__(__name)
62+
if isinstance(attr_value, PassThroughProp):
63+
# Access __root__ attribute to get actual value from PassThroughProp
64+
# See https://github.com/aws/serverless-application-model/blob/develop/samtranslator/internal/schema_source/common.py#L19
65+
return attr_value.__root__
66+
return attr_value
67+
5968

6069
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/intrinsic-function-reference-ref.html
6170
class Ref(BaseModel):
+4-33
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,18 @@
11
"""A resource validator to help validate resource properties and raise exception when some value is unexpected."""
2-
from typing import Any, Dict, Type
2+
from typing import Any, Dict, Type, TypeVar
33

44
from pydantic import BaseModel
55

6-
7-
class Model:
8-
"""
9-
Wrapper class around a SAM schema BaseModel to with a new functional "get" method
10-
"""
11-
12-
def __init__(self, model: BaseModel) -> None:
13-
self.model = model
14-
15-
def _process_attr_value(self, attr_value: Any) -> Any:
16-
if isinstance(attr_value, BaseModel):
17-
if "__root__" in attr_value.__dict__:
18-
return attr_value.__dict__["__root__"]
19-
return Model(attr_value)
20-
return attr_value
21-
22-
def _get_item(self, attr_value: Any) -> Any:
23-
if isinstance(attr_value, list):
24-
return [self._process_attr_value(attr) for attr in attr_value]
25-
return self._process_attr_value(attr_value)
26-
27-
def get(self, attr_key: str, default_value: Any = None) -> Any:
28-
"""Return the value for key if key is in Model properties else default."""
29-
attr_value = self.model.__dict__.get(attr_key, default_value)
30-
return self._get_item(attr_value)
31-
32-
def __getitem__(self, attr_key: str) -> Any:
33-
"""Return the value for key if key is in Model properties else raise KeyError exception."""
34-
attr_value = self.model.__dict__[attr_key]
35-
return self._get_item(attr_value)
6+
T = TypeVar("T", bound=BaseModel)
367

378

389
# Note: For compabitliy issue, we should ONLY use this with new abstraction/resources.
39-
def to_model(resource_properties: Dict[Any, Any], cls: Type[BaseModel]) -> Model:
10+
def to_model(resource_properties: Dict[Any, Any], cls: Type[T]) -> T:
4011
"""
4112
Given properties of a SAM resource return a typed object from the definitions of SAM schema model
4213
4314
param:
4415
resource_properties: properties from input template
4516
cls: SAM schema models
4617
"""
47-
return Model(cls(**resource_properties))
18+
return cls.parse_obj(resource_properties)

tests/validator/test_resource_validator.py

+16-24
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from pydantic import BaseModel
66
from pydantic.error_wrappers import ValidationError
7+
from samtranslator.internal.schema_source.aws_serverless_connector import Properties as ConnectorProperties
78
from samtranslator.validator.resource_validator import to_model
89
from samtranslator.yaml_helper import yaml_parse
9-
from schema_source.aws_serverless_connector import Properties as ConnectorProperties
1010

1111
BASE_PATH = os.path.dirname(__file__)
1212
CONNECTOR_INPUT_FOLDER = os.path.join(BASE_PATH, "input", "connector")
@@ -65,32 +65,24 @@ def test_connector_model_get_operation(self):
6565
self.connector_template,
6666
ConnectorProperties,
6767
)
68-
self.assertEqual(connector_model.get("Source").get("Arn"), "random-arn")
69-
self.assertEqual(connector_model.get("Source").get("Type"), "random-type")
70-
self.assertEqual(connector_model.get("Source").get("Id"), None)
71-
self.assertEqual(connector_model.get("Destination").get("Id"), "MyTable")
72-
self.assertEqual(connector_model.get("Permissions"), ["Read"])
73-
self.assertEqual(connector_model.get("FakeProperty"), None)
74-
75-
self.assertEqual(connector_model["Source"]["Arn"], "random-arn")
76-
self.assertEqual(connector_model["Source"]["Type"], "random-type")
77-
self.assertEqual(connector_model["Destination"]["Id"], "MyTable")
78-
self.assertEqual(connector_model["Permissions"], ["Read"])
68+
self.assertEqual(connector_model.Source.Arn, "random-arn")
69+
self.assertEqual(connector_model.Source.Type, "random-type")
70+
self.assertEqual(connector_model.Source.Id, None)
71+
self.assertEqual(connector_model.Destination.Id, "MyTable")
72+
self.assertEqual(connector_model.Permissions, ["Read"])
7973

8074
def test_model_get_operation(self):
8175
model = to_model(self.model_template, ValidatiorBaseModel)
82-
self.assertEqual(model["Properties"]["Key"], {"A": {"value": 10}})
83-
self.assertEqual(model["Properties"]["Key"]["A"], {"value": 10})
84-
self.assertEqual(model["Properties"]["Hello"], ["1", "2", "3"])
85-
self.assertEqual(model.get("Properties").get("Random")["value"], 5)
86-
self.assertEqual(model.get("DoNotExist"), None)
87-
88-
self.assertEqual(len(model["Contents"]), 3)
89-
self.assertEqual(model["Contents"][0]["Content"]["Tags"]["A"], "hello")
90-
self.assertEqual(model["Contents"][0]["Content"]["Tags"]["B"], 5)
91-
self.assertEqual(model["Contents"][1]["Content"]["Tags"]["A"], "wow")
92-
self.assertEqual(model["Contents"][0]["Content"]["Tags"].get("C"), None)
93-
self.assertEqual(model["Contents"][2]["Content"]["Tags"]["B"], -5)
76+
self.assertEqual(model.Properties.Key, {"A": {"value": 10}})
77+
self.assertEqual(model.Properties.Key["A"], {"value": 10})
78+
self.assertEqual(model.Properties.Hello, ["1", "2", "3"])
79+
self.assertEqual(model.Properties.Random.value, 5)
80+
81+
self.assertEqual(len(model.Contents), 3)
82+
self.assertEqual(model.Contents[0].Content.Tags.A, "hello")
83+
self.assertEqual(model.Contents[0].Content.Tags.B, 5)
84+
self.assertEqual(model.Contents[1].Content.Tags.A, "wow")
85+
self.assertEqual(model.Contents[2].Content.Tags.B, -5)
9486

9587

9688
class TestModelValidatorFailure(TestCase):

0 commit comments

Comments
 (0)