Skip to content

Commit 4638620

Browse files
authored
fix: validate iptable rule exists after calling insert or append iptable rule (#3602)
* add validation as part of the iptables insert command * add logic to append * add iptables package unit tests moves platform exec client to new client from RunCmd which shouldn't change anything as NewExecClient just instantiates a struct enables passing in to iptables client custom/mock functionality for running an os command for testing iptables.Client is never created without NewClient() outside of testing so adding the platform exec field should not cause nil pointers-- Client.pl is auto populated when calling NewClient
1 parent a848065 commit 4638620

File tree

2 files changed

+288
-7
lines changed

2 files changed

+288
-7
lines changed

iptables/iptables.go

+28-7
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@ package iptables
33
// This package contains wrapper functions to program iptables rules
44

55
import (
6+
"errors"
67
"fmt"
78

89
"github.com/Azure/azure-container-networking/cni/log"
910
"github.com/Azure/azure-container-networking/platform"
1011
"go.uber.org/zap"
1112
)
1213

13-
var logger = log.CNILogger.With(zap.String("component", "cni-iptables"))
14+
var (
15+
logger = log.CNILogger.With(zap.String("component", "cni-iptables"))
16+
errCouldNotValidateRuleExists = errors.New("could not validate iptable rule exists after insertion")
17+
)
1418

1519
// cni iptable chains
1620
const (
@@ -87,17 +91,20 @@ type IPTableEntry struct {
8791
Params string
8892
}
8993

90-
type Client struct{}
94+
type Client struct {
95+
pl platform.ExecClient
96+
}
9197

9298
func NewClient() *Client {
93-
return &Client{}
99+
return &Client{
100+
pl: platform.NewExecClient(logger),
101+
}
94102
}
95103

96104
// Run iptables command
97105
func (c *Client) RunCmd(version, params string) error {
98106
var cmd string
99107

100-
p := platform.NewExecClient(logger)
101108
iptCmd := iptables
102109
if version == V6 {
103110
iptCmd = ip6tables
@@ -109,7 +116,7 @@ func (c *Client) RunCmd(version, params string) error {
109116
cmd = fmt.Sprintf("%s -w %d %s", iptCmd, lockTimeout, params)
110117
}
111118

112-
if _, err := p.ExecuteRawCommand(cmd); err != nil {
119+
if _, err := c.pl.ExecuteRawCommand(cmd); err != nil {
113120
return err
114121
}
115122

@@ -171,7 +178,14 @@ func (c *Client) InsertIptableRule(version, tableName, chainName, match, target
171178
}
172179

173180
cmd := c.GetInsertIptableRuleCmd(version, tableName, chainName, match, target)
174-
return c.RunCmd(version, cmd.Params)
181+
err := c.RunCmd(version, cmd.Params)
182+
if err != nil {
183+
return err
184+
}
185+
if !c.RuleExists(version, tableName, chainName, match, target) {
186+
return errCouldNotValidateRuleExists
187+
}
188+
return nil
175189
}
176190

177191
func (c *Client) GetAppendIptableRuleCmd(version, tableName, chainName, match, target string) IPTableEntry {
@@ -189,7 +203,14 @@ func (c *Client) AppendIptableRule(version, tableName, chainName, match, target
189203
}
190204

191205
cmd := c.GetAppendIptableRuleCmd(version, tableName, chainName, match, target)
192-
return c.RunCmd(version, cmd.Params)
206+
err := c.RunCmd(version, cmd.Params)
207+
if err != nil {
208+
return err
209+
}
210+
if !c.RuleExists(version, tableName, chainName, match, target) {
211+
return errCouldNotValidateRuleExists
212+
}
213+
return nil
193214
}
194215

195216
// Delete matched iptable rule

iptables/iptables_test.go

+260
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package iptables
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/Azure/azure-container-networking/platform"
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
type validationCase struct {
13+
cmd string
14+
doErr bool
15+
}
16+
17+
var (
18+
errMockPlatform = errors.New("mock pl error")
19+
errExtraneousCalls = errors.New("function called too many times")
20+
)
21+
22+
// GenerateValidateFunc takes in a slice of expected calls and intended responses for each time the returned function is called
23+
// For example, if expectedCmds has one validationCase, the first call of the func returned will check the command
24+
// passed in matches the first validationCase's command (fails test if not), and return an error if the first validationCase has doErr as true
25+
// The second call will use the second validation case in the slice to check against the cmd passed in and so on
26+
// If we call this function more times than the number of elements in expectedCmds, errExtraneousCalls is returned
27+
func GenerateValidationFunc(t *testing.T, expectedCmds []validationCase) func(cmd string) (string, error) {
28+
curr := 0
29+
30+
ret := func(cmd string) (string, error) {
31+
if curr >= len(expectedCmds) {
32+
return "", errExtraneousCalls
33+
}
34+
expected := expectedCmds[curr]
35+
curr++
36+
37+
require.Equal(t, expected.cmd, cmd, "command run does not match expected")
38+
39+
if expected.doErr {
40+
return "", errMockPlatform
41+
}
42+
return "", nil
43+
}
44+
45+
return ret
46+
}
47+
48+
func TestGenerateValidationFunc(t *testing.T) {
49+
mockPL := platform.NewMockExecClient(false)
50+
fn := GenerateValidationFunc(t, []validationCase{
51+
{
52+
cmd: "echo hello",
53+
doErr: true,
54+
},
55+
})
56+
mockPL.SetExecRawCommand(fn)
57+
58+
_, err := mockPL.ExecuteRawCommand("echo hello")
59+
require.Error(t, err)
60+
61+
_, err = mockPL.ExecuteRawCommand("echo hello")
62+
require.ErrorIs(t, err, errExtraneousCalls)
63+
}
64+
65+
func TestRunCmd(t *testing.T) {
66+
mockPL := platform.NewMockExecClient(false)
67+
client := &Client{
68+
pl: mockPL,
69+
}
70+
mockPL.SetExecRawCommand(
71+
GenerateValidationFunc(t, []validationCase{
72+
{
73+
cmd: "iptables -w 60 -L",
74+
doErr: false,
75+
},
76+
}),
77+
)
78+
79+
err := client.RunCmd(V4, "-L")
80+
require.NoError(t, err)
81+
}
82+
83+
func TestCreateChain(t *testing.T) {
84+
mockPL := platform.NewMockExecClient(false)
85+
client := &Client{
86+
pl: mockPL,
87+
}
88+
mockPL.SetExecRawCommand(
89+
GenerateValidationFunc(t, []validationCase{
90+
{
91+
cmd: "iptables -w 60 -t filter -nL AZURECNIINPUT",
92+
doErr: true,
93+
},
94+
{
95+
cmd: "iptables -w 60 -t filter -N AZURECNIINPUT",
96+
doErr: false,
97+
},
98+
}),
99+
)
100+
101+
err := client.CreateChain(V4, Filter, CNIInputChain)
102+
require.NoError(t, err)
103+
}
104+
105+
func TestInsertIptableRule(t *testing.T) {
106+
mockPL := platform.NewMockExecClient(false)
107+
client := &Client{
108+
pl: mockPL,
109+
}
110+
111+
mockPL.SetExecRawCommand(
112+
GenerateValidationFunc(t, []validationCase{
113+
// iptables succeeds
114+
{
115+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 70 -j ACCEPT",
116+
doErr: true,
117+
},
118+
{
119+
cmd: "iptables -w 60 -t filter -I AZURECNIINPUT 1 -p tcp --dport 70 -j ACCEPT",
120+
doErr: false,
121+
},
122+
{
123+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 70 -j ACCEPT",
124+
doErr: false,
125+
},
126+
// iptables fails silently
127+
{
128+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
129+
doErr: true,
130+
},
131+
{
132+
cmd: "iptables -w 60 -t filter -I AZURECNIINPUT 1 -p tcp --dport 80 -j ACCEPT",
133+
doErr: false,
134+
},
135+
{
136+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
137+
doErr: true,
138+
},
139+
// iptables finds rule already
140+
{
141+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 90 -j ACCEPT",
142+
doErr: false,
143+
},
144+
}),
145+
)
146+
// iptables succeeds
147+
err := client.InsertIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 70", Accept)
148+
require.NoError(t, err)
149+
// iptables fails silently
150+
err = client.InsertIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 80", Accept)
151+
require.ErrorIs(t, err, errCouldNotValidateRuleExists)
152+
// iptables finds rule already
153+
err = client.InsertIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 90", Accept)
154+
require.NoError(t, err)
155+
}
156+
157+
func TestAppendIptableRule(t *testing.T) {
158+
mockPL := platform.NewMockExecClient(false)
159+
client := &Client{
160+
pl: mockPL,
161+
}
162+
mockPL.SetExecRawCommand(
163+
GenerateValidationFunc(t, []validationCase{
164+
// iptables succeeds
165+
{
166+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 70 -j ACCEPT",
167+
doErr: true,
168+
},
169+
{
170+
cmd: "iptables -w 60 -t filter -A AZURECNIINPUT -p tcp --dport 70 -j ACCEPT",
171+
doErr: false,
172+
},
173+
{
174+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 70 -j ACCEPT",
175+
doErr: false,
176+
},
177+
// iptables fails silently
178+
{
179+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
180+
doErr: true,
181+
},
182+
{
183+
cmd: "iptables -w 60 -t filter -A AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
184+
doErr: false,
185+
},
186+
{
187+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
188+
doErr: true,
189+
},
190+
// iptables finds rule already
191+
{
192+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 90 -j ACCEPT",
193+
doErr: false,
194+
},
195+
}),
196+
)
197+
// iptables succeeds
198+
err := client.AppendIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 70", Accept)
199+
require.NoError(t, err)
200+
// iptables fails silently
201+
err = client.AppendIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 80", Accept)
202+
require.ErrorIs(t, errCouldNotValidateRuleExists, err)
203+
// iptables finds rule already
204+
err = client.AppendIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 90", Accept)
205+
require.NoError(t, err)
206+
}
207+
208+
func TestDeleteIptableRule(t *testing.T) {
209+
mockPL := platform.NewMockExecClient(false)
210+
client := &Client{
211+
pl: mockPL,
212+
}
213+
mockPL.SetExecRawCommand(
214+
GenerateValidationFunc(t, []validationCase{
215+
{
216+
cmd: "iptables -w 60 -t filter -D AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
217+
doErr: false,
218+
},
219+
}),
220+
)
221+
222+
err := client.DeleteIptableRule(V4, Filter, CNIInputChain, "-p tcp --dport 80", Accept)
223+
require.NoError(t, err)
224+
}
225+
226+
func TestChainExists(t *testing.T) {
227+
mockPL := platform.NewMockExecClient(false)
228+
client := &Client{
229+
pl: mockPL,
230+
}
231+
mockPL.SetExecRawCommand(
232+
GenerateValidationFunc(t, []validationCase{
233+
{
234+
cmd: "iptables -w 60 -t filter -nL AZURECNIINPUT",
235+
doErr: true,
236+
},
237+
}),
238+
)
239+
240+
result := client.ChainExists(V4, Filter, CNIInputChain)
241+
assert.False(t, result)
242+
}
243+
244+
func TestRuleExists(t *testing.T) {
245+
mockPL := platform.NewMockExecClient(false)
246+
client := &Client{
247+
pl: mockPL,
248+
}
249+
mockPL.SetExecRawCommand(
250+
GenerateValidationFunc(t, []validationCase{
251+
{
252+
cmd: "iptables -w 60 -t filter -C AZURECNIINPUT -p tcp --dport 80 -j ACCEPT",
253+
doErr: true,
254+
},
255+
}),
256+
)
257+
258+
result := client.RuleExists(V4, Filter, CNIInputChain, "-p tcp --dport 80", Accept)
259+
assert.False(t, result)
260+
}

0 commit comments

Comments
 (0)