Skip to content

Add equal vars for Rules to cnf building #3714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/cfnlint/conditions/_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Mapping, Tuple

from sympy import Symbol
from sympy.logic.boolalg import BooleanFalse, BooleanFunction, BooleanTrue
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

from cfnlint.conditions._utils import get_hash
from cfnlint.helpers import is_function
Expand Down Expand Up @@ -145,7 +145,9 @@ def left(self):
def right(self):
return self._right

def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction:
def build_cnf(
self, params: dict[str, Symbol]
) -> BooleanTrue | BooleanFalse | Symbol:
"""Build a SymPy CNF solver based on the provided params
Args:
params dict[str, Symbol]: params is a dict that represents
Expand All @@ -158,10 +160,7 @@ def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction:
return BooleanTrue()
return BooleanFalse()

if self.hash in params:
return params.get(self.hash)

return Symbol(self.hash)
return params.get(self.hash, Symbol(self.hash))

def test(self, scenarios: Mapping[str, str]) -> bool:
"""Do an equals based on the provided scenario"""
Expand Down
13 changes: 9 additions & 4 deletions src/cfnlint/conditions/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,12 @@ def _build_cnf(
cnf = EncodedCNF()

# build parameters and equals into solver
equal_vars: dict[str, Symbol] = {}
equal_vars: dict[str, Symbol | BooleanFalse | BooleanTrue] = {}

equals: dict[str, Equal] = {}
for condition_name in condition_names:
c_equals = self._conditions[condition_name].equals

def _build_equal_vars(c_equals: list[Equal]):
for c_equal in c_equals:
# check to see if equals already matches another one
if c_equal.hash in equal_vars:
continue

Expand All @@ -139,6 +138,12 @@ def _build_cnf(
)
equals[c_equal.hash] = c_equal

for rule in self._rules:
_build_equal_vars(rule.equals)

for condition_name in condition_names:
_build_equal_vars(self._conditions[condition_name].equals)

# Determine if a set of conditions can never be all false
allowed_values = self._parameters.copy()
if allowed_values:
Expand Down
111 changes: 100 additions & 11 deletions test/unit/module/conditions/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_conditions_with_rules(self):
Assertions:
- Assert:
Fn::And:
- !Condition IsProd
- !Condition IsUsEast1
- !Equals [!Ref Environment, "prod"]
- !Equals [!Ref "AWS::Region", "us-east-1"]
Rule2:
Assertions:
- Assert:
Fn::Or:
- !Condition IsProd
- !Condition IsUsEast1
- !Equals [!Ref Environment, "prod"]
- !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]

Expand Down Expand Up @@ -79,9 +79,9 @@ def test_conditions_with_rules_implies(self):
IsUsEast1: !Equals [!Ref "AWS::Region", "us-east-1"]
Rules:
Rule:
RuleCondition: !Condition IsProd
RuleCondition: !Equals [!Ref Environment, "prod"]
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]
Expand Down Expand Up @@ -143,11 +143,11 @@ def test_conditions_with_multiple_rules(self):
Rule1:
RuleCondition: !Equals [!Ref Environment, "prod"]
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
Rule2:
RuleCondition: !Equals [!Ref Environment, "dev"]
Assertions:
- Assert: !Not [!Condition IsUsEast1]
- Assert: !Not [!Equals [!Ref "AWS::Region", "us-east-1"]]
"""
)[0]

Expand Down Expand Up @@ -366,6 +366,95 @@ def test_fn_equals_assertions_ref_never_satisfiable(self):
)
)

def test_conditions_with_rules_and_parameters(self):
template = decode_str(
"""
Conditions:
DeployGateway: !Equals
- !Ref 'DeployGateway'
- 'true'
DeployVpc: !Equals
- !Ref 'DeployVpc'
- 'true'
Parameters:
DeployAnything:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
DeployGateway:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
DeployVpc:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
Rules:
DeployGateway:
Assertions:
- Assert: !Or
- !Equals
- !Ref 'DeployAnything'
- 'true'
- !Equals
- !Ref 'DeployGateway'
- 'false'
DeployVpc:
Assertions:
- Assert: !Or
- !Equals
- !Ref 'DeployGateway'
- 'true'
- !Equals
- !Ref 'DeployVpc'
- 'false'
Resources:
InternetGateway:
Condition: 'DeployGateway'
Type: 'AWS::EC2::InternetGateway'
InternetGatewayAttachment:
Condition: 'DeployVpc'
Type: 'AWS::EC2::VPCGatewayAttachment'
Properties:
InternetGatewayId: !Ref 'InternetGateway'
VpcId: !Ref 'Vpc'
"""
)[0]

cfn = Template("", template)
self.assertEqual(len(cfn.conditions._conditions), 2)
self.assertEqual(len(cfn.conditions._rules), 2)

self.assertListEqual(
[equal.hash for equal in cfn.conditions._rules[0].equals],
[
"d0d70a1e66dc83d7a0fce24c2eca396af1f34e53",
"bbf5c94c1a4b5a79c7a7863fe9463884cb422450",
],
)

self.assertTrue(
cfn.conditions.satisfiable(
{},
{},
)
)

self.assertTrue(
cfn.conditions.check_implies({"DeployVpc": True}, "DeployGateway")
)

self.assertFalse(
cfn.conditions.check_implies({"DeployVpc": False}, "DeployGateway")
)

self.assertFalse(
cfn.conditions.check_implies({"DeployGateway": False}, "DeployVpc")
)


class TestAssertion(TestCase):
def test_assertion_errors(self):
Expand Down Expand Up @@ -405,7 +494,7 @@ def test_init_rules_with_wrong_assertions_type(self):
Assertions: {"Foo": "Bar"}
Rule2:
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]

Expand All @@ -425,8 +514,8 @@ def test_init_rules_with_no_keys(self):
Assertions:
- Assert:
Fn::Or:
- !Condition IsNotUsEast1
- !Condition IsUsEast1
- !Not [!Equals [!Ref "AWS::Region", "us-east-1"]]
- !Equals [!Ref "AWS::Region", "us-east-1"]
Rule3: []
"""
)[0]
Expand Down
Loading