Skip to content

Commit ff3a502

Browse files
add functional tests (#38)
* add functional tests * fix incorrect loopback address * fix assertion * lint: ignore funlen for NewRootCommand * address linter warnings
1 parent 0c7fd0a commit ff3a502

File tree

3 files changed

+214
-14
lines changed

3 files changed

+214
-14
lines changed

.golangci.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ issues:
88
linters:
99
- funlen
1010

11+
- source: "^func NewRootCommand"
12+
linters:
13+
- funlen
14+
1115
linters:
1216
disable-all: true
1317
enable:

cmd/root.go

+21-14
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ const (
2121
)
2222

2323
var (
24-
version string
25-
24+
version = "dev"
2625
debug bool
2726
json bool
2827
noColor bool
28+
server string
29+
qtype string
30+
)
2931

30-
server string
31-
qtype string
32-
33-
rootCmd = &cobra.Command{
32+
func NewRootCommand() *cobra.Command {
33+
cmd := &cobra.Command{
3434
Use: "zns",
3535
Short: "zns is a command-line utility for querying DNS records and displaying them in human- or machine-readable formats.",
3636
Long: "zns is a command-line utility for querying DNS records, displaying them in a human-readable, colored format that includes type, name, TTL, and value. It supports various DNS record types, concurrent queries for improved performance, JSON output format for machine-readable results, and options to write output to a file or query a specific DNS server.",
@@ -138,7 +138,13 @@ var (
138138
}
139139
}
140140

141-
querier := query.NewQueryClient(fmt.Sprintf("%s:53", server), new(dns.Client), logger)
141+
// If the server address does not already include a port,
142+
// append the default DNS port (53) to it.
143+
if !strings.Contains(server, ":") {
144+
server = fmt.Sprintf("%s:53", server)
145+
}
146+
147+
querier := query.NewQueryClient(server, new(dns.Client), logger)
142148

143149
logger.Debug("Creating querier", "server", server, "qtype", qtype, "domain", args[0])
144150

@@ -181,17 +187,18 @@ var (
181187
return nil
182188
},
183189
}
184-
)
185190

186-
func init() {
187-
rootCmd.CompletionOptions.DisableDefaultCmd = true
188-
rootCmd.Flags().StringVarP(&server, "server", "s", "", "DNS server to query")
189-
rootCmd.Flags().StringVarP(&qtype, "query-type", "q", "", "DNS query type")
190-
rootCmd.Flags().BoolVar(&debug, "debug", false, "If set, debug output is printed")
191-
rootCmd.Flags().BoolVar(&json, "json", false, "If set, output is printed in JSON format.")
191+
cmd.CompletionOptions.DisableDefaultCmd = true
192+
cmd.Flags().StringVarP(&server, "server", "s", "", "DNS server to query")
193+
cmd.Flags().StringVarP(&qtype, "query-type", "q", "", "DNS query type")
194+
cmd.Flags().BoolVar(&debug, "debug", false, "Enable debug output")
195+
cmd.Flags().BoolVar(&json, "json", false, "Output in JSON format")
196+
197+
return cmd
192198
}
193199

194200
func Execute() {
201+
rootCmd := NewRootCommand()
195202
if err := rootCmd.Execute(); err != nil {
196203
if merr, ok := err.(*multierror.Error); ok {
197204
for _, e := range merr.Errors {

cmd/root_test.go

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net"
7+
"os"
8+
"testing"
9+
10+
"github.com/miekg/dns"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
const (
15+
DNSServerPort = 53535
16+
)
17+
18+
func TestMain(m *testing.M) {
19+
go startDNSServer()
20+
21+
code := m.Run()
22+
os.Exit(code)
23+
}
24+
25+
func startDNSServer() {
26+
dns.HandleFunc(".", dnsHandler)
27+
28+
go func() {
29+
err := dns.ListenAndServe(fmt.Sprintf(":%d", DNSServerPort), "udp", nil)
30+
if err != nil {
31+
log.Fatalf("Failed to start DNS server: %v", err)
32+
}
33+
}()
34+
}
35+
36+
func dnsHandler(w dns.ResponseWriter, r *dns.Msg) {
37+
msg := dns.Msg{}
38+
msg.SetReply(r)
39+
40+
// Simulate an A record response for "example.com"
41+
if len(r.Question) > 0 {
42+
q := r.Question[0]
43+
if q.Name == "example.com." && q.Qtype == dns.TypeA {
44+
// Example A record response
45+
a := &dns.A{
46+
Hdr: dns.RR_Header{
47+
Name: "example.com.",
48+
Rrtype: dns.TypeA,
49+
Class: dns.ClassINET,
50+
Ttl: 60,
51+
},
52+
A: net.ParseIP("93.184.216.34"),
53+
}
54+
msg.Answer = append(msg.Answer, a)
55+
}
56+
// Simulate a CNAME record response for "example.com"
57+
if q.Name == "example.com." && q.Qtype == dns.TypeCNAME {
58+
// Example CNAME record response
59+
cname := &dns.CNAME{
60+
Hdr: dns.RR_Header{
61+
Name: "example.com.",
62+
Rrtype: dns.TypeCNAME,
63+
Class: dns.ClassINET,
64+
Ttl: 60,
65+
},
66+
Target: "example.org.",
67+
}
68+
msg.Answer = append(msg.Answer, cname)
69+
}
70+
}
71+
72+
_ = w.WriteMsg(&msg)
73+
}
74+
75+
func Test_Cmd(t *testing.T) {
76+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
77+
78+
rootCmd := NewRootCommand()
79+
rootCmd.SetArgs([]string{"example.com", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort)})
80+
81+
err := rootCmd.Execute()
82+
83+
assert.NoError(t, err)
84+
}
85+
86+
func Test_Cmd_Error(t *testing.T) {
87+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
88+
89+
rootCmd := NewRootCommand()
90+
rootCmd.SetArgs([]string{"--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort)})
91+
92+
err := rootCmd.Execute()
93+
94+
assert.Error(t, err)
95+
assert.Equal(t, "error: domain name is required", err.Error())
96+
}
97+
98+
func Test_Cmd_JSON(t *testing.T) {
99+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
100+
101+
rootCmd := NewRootCommand()
102+
rootCmd.SetArgs([]string{"example.com", "--json", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort)})
103+
104+
err := rootCmd.Execute()
105+
106+
assert.NoError(t, err)
107+
}
108+
109+
func Test_Cmd_QueryType(t *testing.T) {
110+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
111+
112+
rootCmd := NewRootCommand()
113+
rootCmd.SetArgs([]string{"example.com", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort), "--query-type", "A"})
114+
115+
err := rootCmd.Execute()
116+
117+
assert.NoError(t, err)
118+
}
119+
120+
func Test_Cmd_Debug(t *testing.T) {
121+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
122+
123+
rootCmd := NewRootCommand()
124+
rootCmd.SetArgs([]string{"example.com", "--debug", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort), "--query-type", "A"})
125+
126+
err := rootCmd.Execute()
127+
128+
assert.NoError(t, err)
129+
}
130+
131+
func Test_Cmd_LogFile(t *testing.T) {
132+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
133+
134+
file, err := os.CreateTemp(t.TempDir(), "zns")
135+
if err != nil {
136+
t.Fatal(err)
137+
}
138+
defer os.Remove(file.Name())
139+
140+
t.Setenv("ZNS_LOG_FILE", file.Name())
141+
142+
rootCmd := NewRootCommand()
143+
rootCmd.SetArgs([]string{"example.com", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort)})
144+
145+
err = rootCmd.Execute()
146+
assert.NoError(t, err)
147+
148+
assert.FileExists(t, file.Name())
149+
150+
logFile, err := os.ReadFile(file.Name())
151+
if err != nil {
152+
t.Fatal(err)
153+
}
154+
155+
assert.Contains(t, string(logFile), "A example.com. 01m00s 93.184.216.34")
156+
assert.Contains(t, string(logFile), "CNAME example.com. 01m00s example.org.")
157+
}
158+
159+
func Test_Cmd_LogFile_Debug(t *testing.T) {
160+
t.Setenv("NO_COLOR", "1") // Disable color codes for easier testing
161+
162+
file, err := os.CreateTemp(t.TempDir(), "zns")
163+
if err != nil {
164+
t.Fatal(err)
165+
}
166+
defer os.Remove(file.Name())
167+
168+
t.Setenv("ZNS_LOG_FILE", file.Name())
169+
170+
rootCmd := NewRootCommand()
171+
rootCmd.SetArgs([]string{"example.com", "--debug", "--server", fmt.Sprintf("127.0.0.1:%d", DNSServerPort)})
172+
173+
err = rootCmd.Execute()
174+
assert.NoError(t, err)
175+
176+
assert.FileExists(t, file.Name())
177+
178+
logFile, err := os.ReadFile(file.Name())
179+
if err != nil {
180+
t.Fatal(err)
181+
}
182+
183+
assert.Contains(t, string(logFile), "Querying DNS server: @domain=example.com server=127.0.0.1:53535 domain=example.com qtype=A")
184+
assert.Contains(t, string(logFile), "Querying DNS server: @domain=example.com server=127.0.0.1:53535 domain=example.com qtype=CNAME")
185+
assert.Contains(t, string(logFile), "Received DNS response: @domain=example.com server=127.0.0.1:53535 domain=example.com qtype=A rcode=NOERROR")
186+
assert.Contains(t, string(logFile), "Received DNS response: @domain=example.com server=127.0.0.1:53535 domain=example.com qtype=CNAME rcode=NOERROR")
187+
assert.Contains(t, string(logFile), "A |example.com. |01m00s |93.184.216.34")
188+
assert.Contains(t, string(logFile), "CNAME |example.com. |01m00s |example.org.")
189+
}

0 commit comments

Comments
 (0)