Skip to content

Refactor IPTable Rules #2697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/ipamd/ipamd.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ func (c *IPAMContext) nodeInit() error {
if err != nil {
return errors.Wrap(err, "ipamd init: failed to set up host network")
}
err = c.networkClient.CleanUpStaleAWSChains(c.enableIPv4, c.enableIPv6)
if err != nil {
// We should not error if clean up fails since these chains don't affect the rules
log.Debugf("Failed to clean up stale AWS chains: %v", err)
}

metadataResult, err := c.awsClient.DescribeAllENIs()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/ipamd/ipamd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func TestNodeInit(t *testing.T) {
m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil)
m.awsutils.EXPECT().GetPrimaryENImac().Return("")
m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(true, false).Return(nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil)

Expand Down Expand Up @@ -234,6 +235,7 @@ func TestNodeInitwithPDenabledIPv4Mode(t *testing.T) {
m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil)
m.awsutils.EXPECT().GetPrimaryENImac().Return("")
m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP, false, true, false).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(true, false).Return(nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().RefreshSGIDs(gomock.Any()).AnyTimes().Return(nil)

Expand Down Expand Up @@ -308,6 +310,7 @@ func TestNodeInitwithPDenabledIPv6Mode(t *testing.T) {

primaryIP := net.ParseIP(ipaddr01)
m.network.EXPECT().SetupHostNetwork(cidrs, eni1.MAC, &primaryIP, false, false, true).Return(nil)
m.network.EXPECT().CleanUpStaleAWSChains(false, true).Return(nil)
m.awsutils.EXPECT().GetIPv6PrefixesFromEC2(eni1.ENIID).AnyTimes().Return(eni1.IPv6Prefixes, nil)
m.awsutils.EXPECT().GetPrimaryENI().AnyTimes().Return(primaryENIid)
m.awsutils.EXPECT().GetPrimaryENImac().Return(eni1.MAC)
Expand Down
6 changes: 6 additions & 0 deletions pkg/iptableswrapper/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type IPTablesIface interface {
ClearChain(table, chain string) error
DeleteChain(table, chain string) error
ListChains(table string) ([]string, error)
ChainExists(table, chain string) (bool, error)
HasRandomFully() bool
}

Expand Down Expand Up @@ -98,6 +99,11 @@ func (i ipTables) ListChains(table string) ([]string, error) {
return i.ipt.ListChains(table)
}

// ChainExists implements IPTablesIface interface by calling iptables package
func (i ipTables) ChainExists(table, chain string) (bool, error) {
return i.ipt.ChainExists(table, chain)
}

// HasRandomFully implements IPTablesIface interface by calling iptables package
func (i ipTables) HasRandomFully() bool {
return i.ipt.HasRandomFully()
Expand Down
8 changes: 8 additions & 0 deletions pkg/iptableswrapper/mocks/iptables_maps.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ func (ipt *MockIptables) ListChains(table string) ([]string, error) {
return chains, nil
}

func (ipt *MockIptables) ChainExists(table, chain string) (bool, error) {
_, ok := ipt.DataplaneState[table][chain]
if ok {
return true, nil
}
return false, nil
}

func (ipt *MockIptables) HasRandomFully() bool {
// TODO: Work out how to write a test case for this
return true
Expand Down
15 changes: 15 additions & 0 deletions pkg/iptableswrapper/mocks/iptables_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions pkg/networkutils/mocks/network_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

160 changes: 102 additions & 58 deletions pkg/networkutils/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type NetworkAPIs interface {
SetupENINetwork(eniIP string, mac string, deviceNumber int, subnetCIDR string) error
// UpdateHostIptablesRules updates the nat table iptables rules on the host
UpdateHostIptablesRules(vpcCIDRs []string, primaryMAC string, primaryAddr *net.IP, v4Enabled bool, v6Enabled bool) error
CleanUpStaleAWSChains(v4Enabled, v6Enabled bool) error
UseExternalSNAT() bool
GetExcludeSNATCIDRs() []string
GetExternalServiceCIDRs() []string
Expand Down Expand Up @@ -375,6 +376,51 @@ func (n *linuxNetwork) UpdateHostIptablesRules(vpcCIDRs []string, primaryMAC str
return n.updateHostIptablesRules(vpcCIDRs, primaryMAC, primaryAddr, v4Enabled, v6Enabled)
}

func (n *linuxNetwork) CleanUpStaleAWSChains(v4Enabled, v6Enabled bool) error {
ipProtocol := iptables.ProtocolIPv4
if v6Enabled {
ipProtocol = iptables.ProtocolIPv6
}

ipt, err := n.newIptables(ipProtocol)
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to create iptables")
}

exists, err := ipt.ChainExists("nat", "AWS-SNAT-CHAIN-1")
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to check if AWS-SNAT-CHAIN-1 exists")
}

if exists {
existingChains, err := ipt.ListChains("nat")
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to list iptables nat chains")
}

for _, chain := range existingChains {
if !strings.HasPrefix(chain, "AWS-CONNMARK-CHAIN") && !strings.HasPrefix(chain, "AWS-SNAT-CHAIN") {
continue
}
parsedChain := strings.Split(chain, "-")
chainNum, err := strconv.Atoi(parsedChain[len(parsedChain)-1])
if err != nil {
return errors.Wrap(err, "stale chain cleanup: failed to convert string to int")
}
// Chains 1 --> x (0 indexed) will be stale
if chainNum > 0 {
// No need to clear the chain since computeStaleIptablesRules cleans up all rules already
log.Infof("Deleting stale chain: %s", chain)
err := ipt.DeleteChain("nat", chain)
if err != nil {
return errors.Wrapf(err, "stale chain cleanup: failed to delete chain %s", chain)
}
}
}
}
return nil
}

func (n *linuxNetwork) updateHostIptablesRules(vpcCIDRs []string, primaryMAC string, primaryAddr *net.IP, v4Enabled bool,
v6Enabled bool) error {
primaryIntf, err := findPrimaryInterfaceName(primaryMAC)
Expand Down Expand Up @@ -434,15 +480,13 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
log.Debugf("Total CIDRs to program - %d", len(allCIDRs))
// build IPTABLES chain for SNAT of non-VPC outbound traffic and excluded CIDRs
var chains []string
for i := 0; i <= len(allCIDRs); i++ {
chain := fmt.Sprintf("AWS-SNAT-CHAIN-%d", i)
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)
chain := "AWS-SNAT-CHAIN-0"
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)

// build SNAT rules for outbound non-VPC traffic
var iptableRules []iptablesRule
Expand All @@ -456,23 +500,20 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
"-m", "comment", "--comment", "AWS SNAT CHAIN", "-j", "AWS-SNAT-CHAIN-0",
}})

for i, cidr := range allCIDRs {
curChain := chains[i]
curName := fmt.Sprintf("[%d] AWS-SNAT-CHAIN", i)
nextChain := chains[i+1]
for _, cidr := range allCIDRs {
comment := "AWS SNAT CHAIN"
if cidr.isExclusion {
comment += " EXCLUSION"
}
log.Debugf("Setup Host Network: iptables -A %s ! -d %s -t nat -j %s", curChain, cidr, nextChain)
log.Debugf("Setup Host Network: iptables -A %s -d %s -t nat -j %s", chain, cidr, "RETURN")

iptableRules = append(iptableRules, iptablesRule{
name: curName,
name: chain,
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: curChain,
chain: chain,
rule: []string{
"!", "-d", cidr.cidr, "-m", "comment", "--comment", comment, "-j", nextChain,
"-d", cidr.cidr, "-m", "comment", "--comment", comment, "-j", "RETURN",
}})
}

Expand All @@ -494,22 +535,21 @@ func (n *linuxNetwork) buildIptablesSNATRules(vpcCIDRs []string, primaryAddr *ne
}
}

lastChain := chains[len(chains)-1]
iptableRules = append(iptableRules, iptablesRule{
name: "last SNAT rule for non-VPC outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: lastChain,
rule: snatRule,
})

snatStaleRules, err := computeStaleIptablesRules(ipt, "nat", "AWS-SNAT-CHAIN", iptableRules, chains)
if err != nil {
return []iptablesRule{}, err
}

iptableRules = append(iptableRules, snatStaleRules...)

iptableRules = append(iptableRules, iptablesRule{
name: "last SNAT rule for non-VPC outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chain,
rule: snatRule,
})

iptableRules = append(iptableRules, iptablesRule{
name: "connmark for primary ENI",
shouldExist: n.nodePortSupportEnabled,
Expand Down Expand Up @@ -556,16 +596,15 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
excludeCIDRs := sets.NewString(n.excludeSNATCIDRs...)

log.Debugf("Total CIDRs to exempt from connmark rules - %d", len(allCIDRs))

var chains []string
for i := 0; i <= len(allCIDRs); i++ {
chain := fmt.Sprintf("AWS-CONNMARK-CHAIN-%d", i)
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)
chain := "AWS-CONNMARK-CHAIN-0"
log.Debugf("Setup Host Network: iptables -N %s -t nat", chain)
if err := ipt.NewChain("nat", chain); err != nil && !containChainExistErr(err) {
log.Errorf("ipt.NewChain error for chain [%s]: %v", chain, err)
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to add chain")
}
chains = append(chains, chain)

var iptableRules []iptablesRule
log.Debugf("Setup Host Network: iptables -t nat -A PREROUTING -i %s+ -m comment --comment \"AWS, outbound connections\" -j AWS-CONNMARK-CHAIN-0", n.vethPrefix)
Expand All @@ -590,37 +629,23 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
"-j", "AWS-CONNMARK-CHAIN-0",
}})

for i, cidr := range allCIDRs {
curChain := chains[i]
curName := fmt.Sprintf("[%d] AWS-SNAT-CHAIN", i)
nextChain := chains[i+1]
for _, cidr := range allCIDRs {
comment := "AWS CONNMARK CHAIN, VPC CIDR"
if excludeCIDRs.Has(cidr) {
comment = "AWS CONNMARK CHAIN, EXCLUDED CIDR"
}
log.Debugf("Setup Host Network: iptables -A %s ! -d %s -t nat -j %s", curChain, cidr, nextChain)
log.Debugf("Setup Host Network: iptables -A %s -d %s -t nat -j %s", chain, cidr, "RETURN")

iptableRules = append(iptableRules, iptablesRule{
name: curName,
name: chain,
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: curChain,
chain: chain,
rule: []string{
"!", "-d", cidr, "-m", "comment", "--comment", comment, "-j", nextChain,
"-d", cidr, "-m", "comment", "--comment", comment, "-j", "RETURN",
}})
}

iptableRules = append(iptableRules, iptablesRule{
name: "connmark rule for external outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chains[len(chains)-1],
rule: []string{
"-m", "comment", "--comment", "AWS, CONNMARK", "-j", "CONNMARK",
"--set-xmark", fmt.Sprintf("%#x/%#x", n.mainENIMark, n.mainENIMark),
},
})

// Force delete existing restore mark rule so that the subsequent rule gets added to the end
iptableRules = append(iptableRules, iptablesRule{
name: "connmark to fwmark copy",
Expand Down Expand Up @@ -652,14 +677,24 @@ func (n *linuxNetwork) buildIptablesConnmarkRules(vpcCIDRs []string, ipt iptable
}
iptableRules = append(iptableRules, connmarkStaleRules...)

iptableRules = append(iptableRules, iptablesRule{
name: "connmark rule for external outbound traffic",
shouldExist: !n.useExternalSNAT,
table: "nat",
chain: chain,
rule: []string{
"-m", "comment", "--comment", "AWS, CONNMARK", "-j", "CONNMARK",
"--set-xmark", fmt.Sprintf("%#x/%#x", n.mainENIMark, n.mainENIMark),
},
})

log.Debugf("iptableRules: %v", iptableRules)
return iptableRules, nil
}

func (n *linuxNetwork) updateIptablesRules(iptableRules []iptablesRule, ipt iptableswrapper.IPTablesIface) error {
for _, rule := range iptableRules {
log.Debugf("execute iptable rule : %s", rule.name)

exists, err := ipt.Exists(rule.table, rule.chain, rule.rule...)
log.Debugf("rule %v exists %v, err %v", rule, exists, err)
if err != nil {
Expand All @@ -668,10 +703,19 @@ func (n *linuxNetwork) updateIptablesRules(iptableRules []iptablesRule, ipt ipta
}

if !exists && rule.shouldExist {
err = ipt.Append(rule.table, rule.chain, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to add %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
if rule.name == "AWS-CONNMARK-CHAIN-0" || rule.name == "AWS-SNAT-CHAIN-0" {
// All CIDR rules must go before the SNAT/Mark rule
err = ipt.Insert(rule.table, rule.chain, 1, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to insert %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
}
} else {
err = ipt.Append(rule.table, rule.chain, rule.rule...)
if err != nil {
log.Errorf("host network setup: failed to add %v, %v", rule, err)
return errors.Wrapf(err, "host network setup: failed to add %v", rule)
}
}
} else if exists && !rule.shouldExist {
err = ipt.Delete(rule.table, rule.chain, rule.rule...)
Expand Down Expand Up @@ -726,7 +770,7 @@ func computeStaleIptablesRules(ipt iptableswrapper.IPTablesIface, table, chainPr
return []iptablesRule{}, errors.Wrapf(err, "host network setup: failed to list rules from table %s with chain prefix %s", table, chainPrefix)
}
activeChains := sets.NewString(chains...)
log.Debugf("Setup Host Network: computing stale iptables rules for %s table with chain prefix %s")
log.Debugf("Setup Host Network: computing stale iptables rules for %s table with chain prefix %s", table, chainPrefix)
for _, staleRule := range existingRules {
if len(staleRule.rule) == 0 && activeChains.Has(staleRule.chain) {
log.Debugf("Setup Host Network: active chain found: %s", staleRule.chain)
Expand Down
Loading