From 5a54700a6c51deb810187783851cce7bdfbbf577 Mon Sep 17 00:00:00 2001 From: Kevin DeJong Date: Tue, 24 Sep 2024 09:33:22 -0700 Subject: [PATCH] Add equal vars for Rules to cnf building --- src/cfnlint/conditions/_equals.py | 11 +-- src/cfnlint/conditions/conditions.py | 13 ++- test/unit/module/conditions/test_rules.py | 111 +++++++++++++++++++--- 3 files changed, 114 insertions(+), 21 deletions(-) diff --git a/src/cfnlint/conditions/_equals.py b/src/cfnlint/conditions/_equals.py index 06872c4223..401e48f17e 100644 --- a/src/cfnlint/conditions/_equals.py +++ b/src/cfnlint/conditions/_equals.py @@ -10,7 +10,7 @@ from typing import Any, Mapping, Tuple from sympy import Symbol -from sympy.logic.boolalg import BooleanFalse, BooleanFunction, BooleanTrue +from sympy.logic.boolalg import BooleanFalse, BooleanTrue from cfnlint.conditions._utils import get_hash from cfnlint.helpers import is_function @@ -145,7 +145,9 @@ def left(self): def right(self): return self._right - def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction: + def build_cnf( + self, params: dict[str, Symbol] + ) -> BooleanTrue | BooleanFalse | Symbol: """Build a SymPy CNF solver based on the provided params Args: params dict[str, Symbol]: params is a dict that represents @@ -158,10 +160,7 @@ def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction: return BooleanTrue() return BooleanFalse() - if self.hash in params: - return params.get(self.hash) - - return Symbol(self.hash) + return params.get(self.hash, Symbol(self.hash)) def test(self, scenarios: Mapping[str, str]) -> bool: """Do an equals based on the provided scenario""" diff --git a/src/cfnlint/conditions/conditions.py b/src/cfnlint/conditions/conditions.py index 235eac7236..6d612ad5cc 100644 --- a/src/cfnlint/conditions/conditions.py +++ b/src/cfnlint/conditions/conditions.py @@ -110,13 +110,12 @@ def _build_cnf( cnf = EncodedCNF() # build parameters and equals into solver - equal_vars: dict[str, Symbol] = {} + equal_vars: dict[str, Symbol | BooleanFalse | BooleanTrue] = {} equals: dict[str, Equal] = {} - for condition_name in condition_names: - c_equals = self._conditions[condition_name].equals + + def _build_equal_vars(c_equals: list[Equal]): for c_equal in c_equals: - # check to see if equals already matches another one if c_equal.hash in equal_vars: continue @@ -139,6 +138,12 @@ def _build_cnf( ) equals[c_equal.hash] = c_equal + for rule in self._rules: + _build_equal_vars(rule.equals) + + for condition_name in condition_names: + _build_equal_vars(self._conditions[condition_name].equals) + # Determine if a set of conditions can never be all false allowed_values = self._parameters.copy() if allowed_values: diff --git a/test/unit/module/conditions/test_rules.py b/test/unit/module/conditions/test_rules.py index f9871bc44c..49f09be655 100644 --- a/test/unit/module/conditions/test_rules.py +++ b/test/unit/module/conditions/test_rules.py @@ -23,14 +23,14 @@ def test_conditions_with_rules(self): Assertions: - Assert: Fn::And: - - !Condition IsProd - - !Condition IsUsEast1 + - !Equals [!Ref Environment, "prod"] + - !Equals [!Ref "AWS::Region", "us-east-1"] Rule2: Assertions: - Assert: Fn::Or: - - !Condition IsProd - - !Condition IsUsEast1 + - !Equals [!Ref Environment, "prod"] + - !Equals [!Ref "AWS::Region", "us-east-1"] """ )[0] @@ -79,9 +79,9 @@ def test_conditions_with_rules_implies(self): IsUsEast1: !Equals [!Ref "AWS::Region", "us-east-1"] Rules: Rule: - RuleCondition: !Condition IsProd + RuleCondition: !Equals [!Ref Environment, "prod"] Assertions: - - Assert: !Condition IsUsEast1 + - Assert: !Equals [!Ref "AWS::Region", "us-east-1"] """ )[0] @@ -143,11 +143,11 @@ def test_conditions_with_multiple_rules(self): Rule1: RuleCondition: !Equals [!Ref Environment, "prod"] Assertions: - - Assert: !Condition IsUsEast1 + - Assert: !Equals [!Ref "AWS::Region", "us-east-1"] Rule2: RuleCondition: !Equals [!Ref Environment, "dev"] Assertions: - - Assert: !Not [!Condition IsUsEast1] + - Assert: !Not [!Equals [!Ref "AWS::Region", "us-east-1"]] """ )[0] @@ -366,6 +366,95 @@ def test_fn_equals_assertions_ref_never_satisfiable(self): ) ) + def test_conditions_with_rules_and_parameters(self): + template = decode_str( + """ + Conditions: + DeployGateway: !Equals + - !Ref 'DeployGateway' + - 'true' + DeployVpc: !Equals + - !Ref 'DeployVpc' + - 'true' + Parameters: + DeployAnything: + AllowedValues: + - 'false' + - 'true' + Type: 'String' + DeployGateway: + AllowedValues: + - 'false' + - 'true' + Type: 'String' + DeployVpc: + AllowedValues: + - 'false' + - 'true' + Type: 'String' + Rules: + DeployGateway: + Assertions: + - Assert: !Or + - !Equals + - !Ref 'DeployAnything' + - 'true' + - !Equals + - !Ref 'DeployGateway' + - 'false' + DeployVpc: + Assertions: + - Assert: !Or + - !Equals + - !Ref 'DeployGateway' + - 'true' + - !Equals + - !Ref 'DeployVpc' + - 'false' + Resources: + InternetGateway: + Condition: 'DeployGateway' + Type: 'AWS::EC2::InternetGateway' + InternetGatewayAttachment: + Condition: 'DeployVpc' + Type: 'AWS::EC2::VPCGatewayAttachment' + Properties: + InternetGatewayId: !Ref 'InternetGateway' + VpcId: !Ref 'Vpc' + """ + )[0] + + cfn = Template("", template) + self.assertEqual(len(cfn.conditions._conditions), 2) + self.assertEqual(len(cfn.conditions._rules), 2) + + self.assertListEqual( + [equal.hash for equal in cfn.conditions._rules[0].equals], + [ + "d0d70a1e66dc83d7a0fce24c2eca396af1f34e53", + "bbf5c94c1a4b5a79c7a7863fe9463884cb422450", + ], + ) + + self.assertTrue( + cfn.conditions.satisfiable( + {}, + {}, + ) + ) + + self.assertTrue( + cfn.conditions.check_implies({"DeployVpc": True}, "DeployGateway") + ) + + self.assertFalse( + cfn.conditions.check_implies({"DeployVpc": False}, "DeployGateway") + ) + + self.assertFalse( + cfn.conditions.check_implies({"DeployGateway": False}, "DeployVpc") + ) + class TestAssertion(TestCase): def test_assertion_errors(self): @@ -405,7 +494,7 @@ def test_init_rules_with_wrong_assertions_type(self): Assertions: {"Foo": "Bar"} Rule2: Assertions: - - Assert: !Condition IsUsEast1 + - Assert: !Equals [!Ref "AWS::Region", "us-east-1"] """ )[0] @@ -425,8 +514,8 @@ def test_init_rules_with_no_keys(self): Assertions: - Assert: Fn::Or: - - !Condition IsNotUsEast1 - - !Condition IsUsEast1 + - !Not [!Equals [!Ref "AWS::Region", "us-east-1"]] + - !Equals [!Ref "AWS::Region", "us-east-1"] Rule3: [] """ )[0]