Skip to content

Commit cecf21d

Browse files
rpothiercornfeedhobo
authored andcommitted
Add IPNetSlice and unit tests (spf13#170)
1 parent 6971c29 commit cecf21d

File tree

2 files changed

+383
-0
lines changed

2 files changed

+383
-0
lines changed

ipnet_slice.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package pflag
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net"
7+
"strings"
8+
)
9+
10+
// -- ipNetSlice Value
11+
type ipNetSliceValue struct {
12+
value *[]net.IPNet
13+
changed bool
14+
}
15+
16+
func newIPNetSliceValue(val []net.IPNet, p *[]net.IPNet) *ipNetSliceValue {
17+
ipnsv := new(ipNetSliceValue)
18+
ipnsv.value = p
19+
*ipnsv.value = val
20+
return ipnsv
21+
}
22+
23+
// Set converts, and assigns, the comma-separated IPNet argument string representation as the []net.IPNet value of this flag.
24+
// If Set is called on a flag that already has a []net.IPNet assigned, the newly converted values will be appended.
25+
func (s *ipNetSliceValue) Set(val string) error {
26+
27+
// remove all quote characters
28+
rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "")
29+
30+
// read flag arguments with CSV parser
31+
ipNetStrSlice, err := readAsCSV(rmQuote.Replace(val))
32+
if err != nil && err != io.EOF {
33+
return err
34+
}
35+
36+
// parse ip values into slice
37+
out := make([]net.IPNet, 0, len(ipNetStrSlice))
38+
for _, ipNetStr := range ipNetStrSlice {
39+
_, n, err := net.ParseCIDR(strings.TrimSpace(ipNetStr))
40+
if err != nil {
41+
return fmt.Errorf("invalid string being converted to CIDR: %s", ipNetStr)
42+
}
43+
out = append(out, *n)
44+
}
45+
46+
if !s.changed {
47+
*s.value = out
48+
} else {
49+
*s.value = append(*s.value, out...)
50+
}
51+
52+
s.changed = true
53+
54+
return nil
55+
}
56+
57+
// Type returns a string that uniquely represents this flag's type.
58+
func (s *ipNetSliceValue) Type() string {
59+
return "ipNetSlice"
60+
}
61+
62+
// String defines a "native" format for this net.IPNet slice flag value.
63+
func (s *ipNetSliceValue) String() string {
64+
65+
ipNetStrSlice := make([]string, len(*s.value))
66+
for i, n := range *s.value {
67+
ipNetStrSlice[i] = n.String()
68+
}
69+
70+
out, _ := writeAsCSV(ipNetStrSlice)
71+
return "[" + out + "]"
72+
}
73+
74+
func ipNetSliceConv(val string) (interface{}, error) {
75+
val = strings.Trim(val, "[]")
76+
// Emtpy string would cause a slice with one (empty) entry
77+
if len(val) == 0 {
78+
return []net.IPNet{}, nil
79+
}
80+
ss := strings.Split(val, ",")
81+
out := make([]net.IPNet, len(ss))
82+
for i, sval := range ss {
83+
_, n, err := net.ParseCIDR(strings.TrimSpace(sval))
84+
if err != nil {
85+
return nil, fmt.Errorf("invalid string being converted to CIDR: %s", sval)
86+
}
87+
out[i] = *n
88+
}
89+
return out, nil
90+
}
91+
92+
// GetIPNetSlice returns the []net.IPNet value of a flag with the given name
93+
func (f *FlagSet) GetIPNetSlice(name string) ([]net.IPNet, error) {
94+
val, err := f.getFlagType(name, "ipNetSlice", ipNetSliceConv)
95+
if err != nil {
96+
return []net.IPNet{}, err
97+
}
98+
return val.([]net.IPNet), nil
99+
}
100+
101+
// IPNetSliceVar defines a ipNetSlice flag with specified name, default value, and usage string.
102+
// The argument p points to a []net.IPNet variable in which to store the value of the flag.
103+
func (f *FlagSet) IPNetSliceVar(p *[]net.IPNet, name string, value []net.IPNet, usage string) {
104+
f.VarP(newIPNetSliceValue(value, p), name, "", usage)
105+
}
106+
107+
// IPNetSliceVarP is like IPNetSliceVar, but accepts a shorthand letter that can be used after a single dash.
108+
func (f *FlagSet) IPNetSliceVarP(p *[]net.IPNet, name, shorthand string, value []net.IPNet, usage string) {
109+
f.VarP(newIPNetSliceValue(value, p), name, shorthand, usage)
110+
}
111+
112+
// IPNetSliceVar defines a []net.IPNet flag with specified name, default value, and usage string.
113+
// The argument p points to a []net.IPNet variable in which to store the value of the flag.
114+
func IPNetSliceVar(p *[]net.IPNet, name string, value []net.IPNet, usage string) {
115+
CommandLine.VarP(newIPNetSliceValue(value, p), name, "", usage)
116+
}
117+
118+
// IPNetSliceVarP is like IPNetSliceVar, but accepts a shorthand letter that can be used after a single dash.
119+
func IPNetSliceVarP(p *[]net.IPNet, name, shorthand string, value []net.IPNet, usage string) {
120+
CommandLine.VarP(newIPNetSliceValue(value, p), name, shorthand, usage)
121+
}
122+
123+
// IPNetSlice defines a []net.IPNet flag with specified name, default value, and usage string.
124+
// The return value is the address of a []net.IPNet variable that stores the value of that flag.
125+
func (f *FlagSet) IPNetSlice(name string, value []net.IPNet, usage string) *[]net.IPNet {
126+
p := []net.IPNet{}
127+
f.IPNetSliceVarP(&p, name, "", value, usage)
128+
return &p
129+
}
130+
131+
// IPNetSliceP is like IPNetSlice, but accepts a shorthand letter that can be used after a single dash.
132+
func (f *FlagSet) IPNetSliceP(name, shorthand string, value []net.IPNet, usage string) *[]net.IPNet {
133+
p := []net.IPNet{}
134+
f.IPNetSliceVarP(&p, name, shorthand, value, usage)
135+
return &p
136+
}
137+
138+
// IPNetSlice defines a []net.IPNet flag with specified name, default value, and usage string.
139+
// The return value is the address of a []net.IP variable that stores the value of the flag.
140+
func IPNetSlice(name string, value []net.IPNet, usage string) *[]net.IPNet {
141+
return CommandLine.IPNetSliceP(name, "", value, usage)
142+
}
143+
144+
// IPNetSliceP is like IPNetSlice, but accepts a shorthand letter that can be used after a single dash.
145+
func IPNetSliceP(name, shorthand string, value []net.IPNet, usage string) *[]net.IPNet {
146+
return CommandLine.IPNetSliceP(name, shorthand, value, usage)
147+
}

ipnet_slice_test.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package pflag
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"strings"
7+
"testing"
8+
)
9+
10+
// Helper function to set static slices
11+
func getCIDR(ip net.IP, cidr *net.IPNet, err error) net.IPNet {
12+
return *cidr
13+
}
14+
15+
func equalCIDR(c1 net.IPNet, c2 net.IPNet) bool {
16+
return c1.String() == c2.String()
17+
}
18+
19+
func setUpIPNetFlagSet(ipsp *[]net.IPNet) *FlagSet {
20+
f := NewFlagSet("test", ContinueOnError)
21+
f.IPNetSliceVar(ipsp, "cidrs", []net.IPNet{}, "Command separated list!")
22+
return f
23+
}
24+
25+
func setUpIPNetFlagSetWithDefault(ipsp *[]net.IPNet) *FlagSet {
26+
f := NewFlagSet("test", ContinueOnError)
27+
f.IPNetSliceVar(ipsp, "cidrs",
28+
[]net.IPNet{
29+
getCIDR(net.ParseCIDR("192.168.1.1/16")),
30+
getCIDR(net.ParseCIDR("fd00::/64")),
31+
},
32+
"Command separated list!")
33+
return f
34+
}
35+
36+
func TestEmptyIPNet(t *testing.T) {
37+
var cidrs []net.IPNet
38+
f := setUpIPNetFlagSet(&cidrs)
39+
err := f.Parse([]string{})
40+
if err != nil {
41+
t.Fatal("expected no error; got", err)
42+
}
43+
44+
getIPNet, err := f.GetIPNetSlice("cidrs")
45+
if err != nil {
46+
t.Fatal("got an error from GetIPNetSlice():", err)
47+
}
48+
if len(getIPNet) != 0 {
49+
t.Fatalf("got ips %v with len=%d but expected length=0", getIPNet, len(getIPNet))
50+
}
51+
}
52+
53+
func TestIPNets(t *testing.T) {
54+
var ips []net.IPNet
55+
f := setUpIPNetFlagSet(&ips)
56+
57+
vals := []string{"192.168.1.1/24", "10.0.0.1/16", "fd00:0:0:0:0:0:0:2/64"}
58+
arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ","))
59+
err := f.Parse([]string{arg})
60+
if err != nil {
61+
t.Fatal("expected no error; got", err)
62+
}
63+
for i, v := range ips {
64+
if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil {
65+
t.Fatalf("invalid string being converted to CIDR: %s", vals[i])
66+
} else if !equalCIDR(*cidr, v) {
67+
t.Fatalf("expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v)
68+
}
69+
}
70+
}
71+
72+
func TestIPNetDefault(t *testing.T) {
73+
var cidrs []net.IPNet
74+
f := setUpIPNetFlagSetWithDefault(&cidrs)
75+
76+
vals := []string{"192.168.1.1/16", "fd00::/64"}
77+
err := f.Parse([]string{})
78+
if err != nil {
79+
t.Fatal("expected no error; got", err)
80+
}
81+
for i, v := range cidrs {
82+
if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil {
83+
t.Fatalf("invalid string being converted to CIDR: %s", vals[i])
84+
} else if !equalCIDR(*cidr, v) {
85+
t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v)
86+
}
87+
}
88+
89+
getIPNet, err := f.GetIPNetSlice("cidrs")
90+
if err != nil {
91+
t.Fatal("got an error from GetIPNetSlice")
92+
}
93+
for i, v := range getIPNet {
94+
if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil {
95+
t.Fatalf("invalid string being converted to CIDR: %s", vals[i])
96+
} else if !equalCIDR(*cidr, v) {
97+
t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v)
98+
}
99+
}
100+
}
101+
102+
func TestIPNetWithDefault(t *testing.T) {
103+
var cidrs []net.IPNet
104+
f := setUpIPNetFlagSetWithDefault(&cidrs)
105+
106+
vals := []string{"192.168.1.1/16", "fd00::/64"}
107+
arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ","))
108+
err := f.Parse([]string{arg})
109+
if err != nil {
110+
t.Fatal("expected no error; got", err)
111+
}
112+
for i, v := range cidrs {
113+
if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil {
114+
t.Fatalf("invalid string being converted to CIDR: %s", vals[i])
115+
} else if !equalCIDR(*cidr, v) {
116+
t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v)
117+
}
118+
}
119+
120+
getIPNet, err := f.GetIPNetSlice("cidrs")
121+
if err != nil {
122+
t.Fatal("got an error from GetIPNetSlice")
123+
}
124+
for i, v := range getIPNet {
125+
if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil {
126+
t.Fatalf("invalid string being converted to CIDR: %s", vals[i])
127+
} else if !equalCIDR(*cidr, v) {
128+
t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v)
129+
}
130+
}
131+
}
132+
133+
func TestIPNetCalledTwice(t *testing.T) {
134+
var cidrs []net.IPNet
135+
f := setUpIPNetFlagSet(&cidrs)
136+
137+
in := []string{"192.168.1.2/16,fd00::/64", "10.0.0.1/24"}
138+
139+
expected := []net.IPNet{
140+
getCIDR(net.ParseCIDR("192.168.1.2/16")),
141+
getCIDR(net.ParseCIDR("fd00::/64")),
142+
getCIDR(net.ParseCIDR("10.0.0.1/24")),
143+
}
144+
argfmt := "--cidrs=%s"
145+
arg1 := fmt.Sprintf(argfmt, in[0])
146+
arg2 := fmt.Sprintf(argfmt, in[1])
147+
err := f.Parse([]string{arg1, arg2})
148+
if err != nil {
149+
t.Fatal("expected no error; got", err)
150+
}
151+
for i, v := range cidrs {
152+
if !equalCIDR(expected[i], v) {
153+
t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, expected[i], v)
154+
}
155+
}
156+
}
157+
158+
func TestIPNetBadQuoting(t *testing.T) {
159+
160+
tests := []struct {
161+
Want []net.IPNet
162+
FlagArg []string
163+
}{
164+
{
165+
Want: []net.IPNet{
166+
getCIDR(net.ParseCIDR("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128")),
167+
getCIDR(net.ParseCIDR("203.107.49.208/32")),
168+
getCIDR(net.ParseCIDR("14.57.204.90/32")),
169+
},
170+
FlagArg: []string{
171+
"a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128",
172+
"203.107.49.208/32",
173+
"14.57.204.90/32",
174+
},
175+
},
176+
{
177+
Want: []net.IPNet{
178+
getCIDR(net.ParseCIDR("204.228.73.195/32")),
179+
getCIDR(net.ParseCIDR("86.141.15.94/32")),
180+
},
181+
FlagArg: []string{
182+
"204.228.73.195/32",
183+
"86.141.15.94/32",
184+
},
185+
},
186+
{
187+
Want: []net.IPNet{
188+
getCIDR(net.ParseCIDR("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128")),
189+
getCIDR(net.ParseCIDR("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128")),
190+
},
191+
FlagArg: []string{
192+
"c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128",
193+
"4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128",
194+
},
195+
},
196+
{
197+
Want: []net.IPNet{
198+
getCIDR(net.ParseCIDR("5170:f971:cfac:7be3:512a:af37:952c:bc33/128")),
199+
getCIDR(net.ParseCIDR("93.21.145.140/32")),
200+
getCIDR(net.ParseCIDR("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128")),
201+
},
202+
FlagArg: []string{
203+
" 5170:f971:cfac:7be3:512a:af37:952c:bc33/128 , 93.21.145.140/32 ",
204+
"2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128",
205+
},
206+
},
207+
{
208+
Want: []net.IPNet{
209+
getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")),
210+
getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")),
211+
getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")),
212+
getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")),
213+
},
214+
FlagArg: []string{
215+
`"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128 "`,
216+
" 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128"},
217+
},
218+
}
219+
220+
for i, test := range tests {
221+
222+
var cidrs []net.IPNet
223+
f := setUpIPNetFlagSet(&cidrs)
224+
225+
if err := f.Parse([]string{fmt.Sprintf("--cidrs=%s", strings.Join(test.FlagArg, ","))}); err != nil {
226+
t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s",
227+
err, test.FlagArg, test.Want[i])
228+
}
229+
230+
for j, b := range cidrs {
231+
if !equalCIDR(b, test.Want[j]) {
232+
t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b)
233+
}
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)