Skip to content

Commit c0a4cb2

Browse files
fix: properly format IPv6 DNS server addresses when appending port (#51)
chore: add test for EnsureDNSAddress helper function Signed-off-by: Rui Chen <[email protected]> Co-authored-by: Bruno Schaatsbergen <[email protected]>
1 parent e0643d9 commit c0a4cb2

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

cmd/root.go

+16-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmd
22

33
import (
44
"fmt"
5+
"net"
56
"os"
67
"runtime"
78
"sort"
@@ -29,6 +30,20 @@ var (
2930
qtype string
3031
)
3132

33+
// EnsureDNSAddress formats the DNS server address properly.
34+
func EnsureDNSAddress(server string) string {
35+
if strings.Contains(server, "]") || strings.Contains(server, ":") && net.ParseIP(server) == nil {
36+
return server
37+
}
38+
39+
ip := net.ParseIP(server)
40+
if ip != nil && ip.To4() == nil { // It's IPv6 (and not IPv4)
41+
return "[" + server + "]:53"
42+
}
43+
// Otherwise, assume IPv4 or hostname, so append port normally.
44+
return server + ":53"
45+
}
46+
3247
func NewRootCommand() *cobra.Command {
3348
cmd := &cobra.Command{
3449
Use: "zns",
@@ -138,11 +153,7 @@ func NewRootCommand() *cobra.Command {
138153
}
139154
}
140155

141-
// If the server address does not already include a port,
142-
// append the default DNS port (53) to it.
143-
if !strings.Contains(server, ":") {
144-
server = fmt.Sprintf("%s:53", server)
145-
}
156+
server = EnsureDNSAddress(server)
146157

147158
querier := query.NewQueryClient(server, new(dns.Client), logger)
148159

cmd/root_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,22 @@ func Test_Cmd_LogFile_Debug(t *testing.T) {
187187
assert.Contains(t, string(logFile), "A |example.com. |01m00s |93.184.216.34")
188188
assert.Contains(t, string(logFile), "CNAME |example.com. |01m00s |example.org.")
189189
}
190+
191+
func TestEnsureDNSAddress(t *testing.T) {
192+
testCases := []struct {
193+
input string
194+
expected string
195+
}{
196+
{"127.0.0.1", "127.0.0.1:53"},
197+
{"2001:558:feed::1", "[2001:558:feed::1]:53"},
198+
{"[2001:558:feed::1]:53", "[2001:558:feed::1]:53"},
199+
{"example.com", "example.com:53"},
200+
}
201+
202+
for _, tc := range testCases {
203+
t.Run(fmt.Sprintf("input=%s", tc.input), func(t *testing.T) {
204+
result := EnsureDNSAddress(tc.input)
205+
assert.Equal(t, tc.expected, result)
206+
})
207+
}
208+
}

0 commit comments

Comments
 (0)