diff --git a/requirements/base.txt b/requirements/base.txt index 38e289b9a..4a1af3f75 100755 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,3 +1,6 @@ boto3>=1.19.5,==1.* jsonschema<5,>=3.2 # TODO: evaluate risk of removing jsonschema 3.x support typing_extensions>=4.4,<5 # 3.7 doesn't have Literal + +# resource validation & schema generation +pydantic~=1.8 diff --git a/requirements/dev.txt b/requirements/dev.txt index bcdd33903..e652d392e 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -29,6 +29,3 @@ mypy~=1.0.0 boto3-stubs[appconfig,serverlessrepo]>=1.19.5,==1.* types-PyYAML~=5.4 types-jsonschema~=3.2 - -# schema generation, requiring features in >=1.10 -pydantic~=1.10 diff --git a/samtranslator/internal/schema_source/common.py b/samtranslator/internal/schema_source/common.py index b24a3c175..c86c6775a 100644 --- a/samtranslator/internal/schema_source/common.py +++ b/samtranslator/internal/schema_source/common.py @@ -56,6 +56,14 @@ class BaseModel(LenientBaseModel): class Config: extra = Extra.forbid + def __getattribute__(self, __name: str) -> Any: + """Overloading get attribute operation to allow access PassThroughProp without using __root__""" + attr_value = super().__getattribute__(__name) + if isinstance(attr_value, PassThroughProp): + # See docstring of PassThroughProp + return attr_value.__root__ + return attr_value + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/intrinsic-function-reference-ref.html class Ref(BaseModel): diff --git a/samtranslator/model/__init__.py b/samtranslator/model/__init__.py index cdf8ebc56..6d4ad47ab 100644 --- a/samtranslator/model/__init__.py +++ b/samtranslator/model/__init__.py @@ -2,7 +2,11 @@ import inspect import re from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from contextlib import suppress +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union + +from pydantic import BaseModel +from pydantic.error_wrappers import ValidationError from samtranslator.intrinsics.resolver import IntrinsicsResolver from samtranslator.model.exceptions import ExpectedType, InvalidResourceException, InvalidResourcePropertyTypeException @@ -10,6 +14,8 @@ from samtranslator.model.types import IS_DICT, IS_STR, Validator, any_type, is_type from samtranslator.plugins import LifeCycleEvents +RT = TypeVar("RT", bound=BaseModel) # return type + class PropertyType: """Stores validation information for a CloudFormation resource property. @@ -312,6 +318,25 @@ def __setattr__(self, name, value): # type: ignore[no-untyped-def] ), ) + # Note: For compabitliy issue, we should ONLY use this with new abstraction/resources. + def validate_properties_and_return_model(self, cls: Type[RT]) -> RT: + """ + Given a resource properties, return a typed object from the definitions of SAM schema model + + param: + resource_properties: properties from input template + cls: schema models + """ + try: + return cls.parse_obj(self._generate_resource_dict()["Properties"]) + except ValidationError as e: + error_properties: str = "" + with suppress(KeyError): + error_properties = ", ".join([str(error["loc"][0]) for error in e.errors()]) + raise InvalidResourceException( + self.logical_id, f"Given resource property '{error_properties}' is invalid" + ) from e + def validate_properties(self) -> None: """Validates that the required properties for this Resource have been populated, and that all properties have valid values. diff --git a/tests/model/test_resource_validator.py b/tests/model/test_resource_validator.py new file mode 100644 index 000000000..580746833 --- /dev/null +++ b/tests/model/test_resource_validator.py @@ -0,0 +1,100 @@ +from unittest import TestCase + +from samtranslator.internal.schema_source.aws_serverless_connector import Properties as ConnectorProperties +from samtranslator.internal.schema_source.aws_serverless_function import Properties as FunctionProperties +from samtranslator.model.exceptions import InvalidResourceException +from samtranslator.model.sam_resources import ( + SamConnector, + SamFunction, +) + + +class TestResourceValidator(TestCase): + def setUp(self) -> None: + self.connector = SamConnector("foo") + self.connector.Source = { + "Arn": "random-arn", + "Type": "random-type", + } + self.connector.Destination = {"Id": "MyTable"} + self.connector.Permissions = ["Read"] + + self.function = SamFunction("function") + self.function.CodeUri = "s3://foobar/foo.zip" + self.function.Runtime = "foo" + self.function.Handler = "bar" + self.function.FunctionUrlConfig = {"Cors": {"AllowOrigins": ["example1.com"]}, "AuthType": "123"} + self.function.Events = { + "MyMqEvent": { + "Type": "MQ", + "Properties": { + "Broker": {"Fn::GetAtt": "MyMqBroker.Arn"}, + "Queues": ["TestQueue"], + "SourceAccessConfigurations": [{"Type": "BASIC_AUTH"}], + }, + } + } + + def test_connector_model(self): + connector_model = self.connector.validate_properties_and_return_model( + ConnectorProperties, + ) + self.assertEqual(connector_model.Source.Arn, "random-arn") + self.assertEqual(connector_model.Source.Type, "random-type") + self.assertEqual(connector_model.Source.Id, None) + self.assertEqual(connector_model.Destination.Id, "MyTable") + self.assertEqual(connector_model.Permissions, ["Read"]) + + def test_lambda_model(self): + model = self.function.validate_properties_and_return_model(FunctionProperties) + self.assertEqual(model.CodeUri, "s3://foobar/foo.zip") + self.assertEqual(model.Runtime, "foo") + self.assertEqual(model.Handler, "bar") + self.assertEqual(model.FunctionUrlConfig.Cors, {"AllowOrigins": ["example1.com"]}) + self.assertEqual(model.FunctionUrlConfig.AuthType, "123") + self.assertEqual(model.Events["MyMqEvent"].Type, "MQ") + self.assertEqual(model.Events["MyMqEvent"].Properties.Broker, {"Fn::GetAtt": "MyMqBroker.Arn"}) + self.assertEqual(model.Events["MyMqEvent"].Properties.Queues, ["TestQueue"]) + self.assertEqual(model.Events["MyMqEvent"].Properties.SourceAccessConfigurations, [{"Type": "BASIC_AUTH"}]) + + +class TestResourceValidatorFailure(TestCase): + def test_connector_with_empty_properties(self): + invalid_connector = SamConnector("foo") + with self.assertRaises( + InvalidResourceException, + ): + invalid_connector.validate_properties_and_return_model(ConnectorProperties) + self.assertRegex(".+Given resource property '(Source|Destination|Permissions)'.+ is invalid.") + + def test_connector_without_source(self): + invalid_connector = SamConnector("foo") + invalid_connector.Destination = {"Id": "MyTable"} + invalid_connector.Permissions = ["Read"] + with self.assertRaises( + InvalidResourceException, + ): + invalid_connector.validate_properties_and_return_model(ConnectorProperties) + self.assertRegex(".+Given resource property 'Source'.+ is invalid.") + + def test_connector_with_invalid_permission(self): + invalid_connector = SamConnector("foo") + invalid_connector.Source = {"Id": "MyTable"} + invalid_connector.Destination = {"Id": "MyTable"} + invalid_connector.Permissions = ["Invoke"] + with self.assertRaises( + InvalidResourceException, + ): + invalid_connector.validate_properties_and_return_model(ConnectorProperties) + self.assertRegex(".+Given resource property 'Permissions'.+ is invalid.") + + def test_connector_with_invalid_permission_type(self): + invalid_connector = SamConnector("foo") + invalid_connector.Source = {"Id": "MyTable"} + invalid_connector.Destination = {"Id": "MyTable"} + invalid_connector.Permissions = {"Hello": "World"} + with self.assertRaises( + InvalidResourceException, + ): + invalid_connector.validate_properties_and_return_model(ConnectorProperties) + self.assertRegex(".+Given resource property 'Permissions'.+ is invalid.")