@@ -24,22 +24,26 @@ import (
24
24
"io/ioutil"
25
25
"log"
26
26
"os"
27
+ "runtime"
27
28
"strings"
28
29
29
30
"github.com/containernetworking/cni/pkg/types"
30
31
"github.com/containernetworking/cni/pkg/utils"
31
32
"github.com/containernetworking/cni/pkg/version"
33
+ "github.com/vishvananda/netns"
34
+ "golang.org/x/sys/unix"
32
35
)
33
36
34
37
// CmdArgs captures all the arguments passed in to the plugin
35
38
// via both env vars and stdin
36
39
type CmdArgs struct {
37
- ContainerID string
38
- Netns string
39
- IfName string
40
- Args string
41
- Path string
42
- StdinData []byte
40
+ ContainerID string
41
+ Netns string
42
+ IfName string
43
+ Args string
44
+ Path string
45
+ NetnsOverride string
46
+ StdinData []byte
43
47
}
44
48
45
49
type dispatcher struct {
@@ -55,7 +59,7 @@ type dispatcher struct {
55
59
type reqForCmdEntry map [string ]bool
56
60
57
61
func (t * dispatcher ) getCmdArgsFromEnv () (string , * CmdArgs , * types.Error ) {
58
- var cmd , contID , netns , ifName , args , path string
62
+ var cmd , contID , netns , ifName , args , path , netnsOverride string
59
63
60
64
vars := []struct {
61
65
name string
@@ -116,6 +120,15 @@ func (t *dispatcher) getCmdArgsFromEnv() (string, *CmdArgs, *types.Error) {
116
120
"DEL" : true ,
117
121
},
118
122
},
123
+ {
124
+ "CNI_NETNS_OVERRIDE" ,
125
+ & netnsOverride ,
126
+ reqForCmdEntry {
127
+ "ADD" : false ,
128
+ "CHECK" : false ,
129
+ "DEL" : false ,
130
+ },
131
+ },
119
132
}
120
133
121
134
argsMissing := make ([]string , 0 )
@@ -143,12 +156,13 @@ func (t *dispatcher) getCmdArgsFromEnv() (string, *CmdArgs, *types.Error) {
143
156
}
144
157
145
158
cmdArgs := & CmdArgs {
146
- ContainerID : contID ,
147
- Netns : netns ,
148
- IfName : ifName ,
149
- Args : args ,
150
- Path : path ,
151
- StdinData : stdinData ,
159
+ ContainerID : contID ,
160
+ Netns : netns ,
161
+ IfName : ifName ,
162
+ Args : args ,
163
+ Path : path ,
164
+ StdinData : stdinData ,
165
+ NetnsOverride : netnsOverride ,
152
166
}
153
167
return cmd , cmdArgs , nil
154
168
}
@@ -190,6 +204,39 @@ func validateConfig(jsonBytes []byte) *types.Error {
190
204
return nil
191
205
}
192
206
207
+ // Returns an object representing the current OS thread's network namespace
208
+ func getCurrentNS () (netns.NsHandle , error ) {
209
+ // Lock the thread in case other goroutine executes in it and changes its
210
+ // network namespace after getCurrentThreadNetNSPath(), otherwise it might
211
+ // return an unexpected network namespace.
212
+ runtime .LockOSThread ()
213
+ defer runtime .UnlockOSThread ()
214
+ return netns .GetFromPath (getCurrentThreadNetNSPath ())
215
+ }
216
+
217
+ func getCurrentThreadNetNSPath () string {
218
+ // /proc/self/ns/net returns the namespace of the main thread, not
219
+ // of whatever thread this goroutine is running on. Make sure we
220
+ // use the thread's net namespace since the thread is switching around
221
+ return fmt .Sprintf ("/proc/%d/task/%d/ns/net" , os .Getpid (), unix .Gettid ())
222
+ }
223
+
224
+ func checkNetNS (nsPath string ) (bool , * types.Error ) {
225
+ ns , err := netns .GetFromPath (nsPath )
226
+ if err != nil {
227
+ return false , nil
228
+ }
229
+ defer ns .Close ()
230
+
231
+ pluginNS , err := getCurrentNS ()
232
+ if err != nil {
233
+ return false , types .NewError (types .ErrInvalidNetNS , "get plugin's netns failed" , "" )
234
+ }
235
+ defer pluginNS .Close ()
236
+
237
+ return pluginNS .Equal (ns ), nil
238
+ }
239
+
193
240
func (t * dispatcher ) pluginMain (cmdAdd , cmdCheck , cmdDel func (_ * CmdArgs ) error , versionInfo version.PluginInfo , about string ) * types.Error {
194
241
cmd , cmdArgs , err := t .getCmdArgsFromEnv ()
195
242
if err != nil {
@@ -217,6 +264,17 @@ func (t *dispatcher) pluginMain(cmdAdd, cmdCheck, cmdDel func(_ *CmdArgs) error,
217
264
switch cmd {
218
265
case "ADD" :
219
266
err = t .checkVersionAndCall (cmdArgs , versionInfo , cmdAdd )
267
+ if err != nil {
268
+ return err
269
+ }
270
+ if strings .ToUpper (cmdArgs .NetnsOverride ) != "TRUE" || cmdArgs .NetnsOverride != "1" {
271
+ isPluginNetNS , checkErr := checkNetNS (cmdArgs .Netns )
272
+ if checkErr != nil {
273
+ return checkErr
274
+ } else if isPluginNetNS {
275
+ return types .NewError (types .ErrInvalidNetNS , "plugin's netns and netns from CNI_NETNS should not be the same" , "" )
276
+ }
277
+ }
220
278
case "CHECK" :
221
279
configVersion , err := t .ConfVersionDecoder .Decode (cmdArgs .StdinData )
222
280
if err != nil {
@@ -241,6 +299,17 @@ func (t *dispatcher) pluginMain(cmdAdd, cmdCheck, cmdDel func(_ *CmdArgs) error,
241
299
return types .NewError (types .ErrIncompatibleCNIVersion , "plugin version does not allow CHECK" , "" )
242
300
case "DEL" :
243
301
err = t .checkVersionAndCall (cmdArgs , versionInfo , cmdDel )
302
+ if err != nil {
303
+ return err
304
+ }
305
+ if strings .ToUpper (cmdArgs .NetnsOverride ) != "TRUE" || cmdArgs .NetnsOverride != "1" {
306
+ isPluginNetNS , checkErr := checkNetNS (cmdArgs .Netns )
307
+ if checkErr != nil {
308
+ return checkErr
309
+ } else if isPluginNetNS {
310
+ return types .NewError (types .ErrInvalidNetNS , "plugin's netns and netns from CNI_NETNS should not be the same" , "" )
311
+ }
312
+ }
244
313
case "VERSION" :
245
314
if err := versionInfo .Encode (t .Stdout ); err != nil {
246
315
return types .NewError (types .ErrIOFailure , err .Error (), "" )
0 commit comments