Skip to content

Commit 2960f30

Browse files
query: add tests (#36)
* improve comments * query: add tests * query: type assert the multierror
1 parent 2cf5d85 commit 2960f30

File tree

3 files changed

+177
-11
lines changed

3 files changed

+177
-11
lines changed

cmd/root.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,9 @@ var (
106106
JSONFormat: json,
107107
}).With("@domain", args[0])
108108

109-
// Log the debug state and current log level.
110109
logger.Debug("Debug logging enabled", "debug", debug)
111110
logger.Debug("Log level", "level", logger.GetLevel())
112111

113-
// Log the arguments and flags
114112
logger.Debug("Args", "args", args)
115113
logger.Debug("Flags", "server", server, "qtype", qtype, "debug", debug)
116114

@@ -140,18 +138,17 @@ var (
140138
}
141139
}
142140

143-
// Create a new querier.
144141
querier := query.NewQueryClient(fmt.Sprintf("%s:53", server), logger)
145142

146143
logger.Debug("Creating querier", "server", server, "qtype", qtype, "domain", args[0])
147144

148-
// Prepare query types.
145+
// Create a slice of supported query types to query.
149146
qtypes := make([]uint16, 0, len(query.QueryTypes))
150147
for _, qtype := range query.QueryTypes {
151148
qtypes = append(qtypes, qtype)
152149
}
153150

154-
// Set specific query type if provided.
151+
// Filter down to the specified query type, if provided.
155152
if qtype != "" {
156153
qtypeInt, ok := query.QueryTypes[strings.ToUpper(qtype)]
157154
if !ok {
@@ -160,7 +157,6 @@ var (
160157
qtypes = []uint16{qtypeInt}
161158
}
162159

163-
// Execute the queries.
164160
messages, err := querier.MultiQuery(args[0], qtypes)
165161
if err != nil {
166162
if merr, ok := err.(*multierror.Error); ok {

internal/query/query.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package query
22

33
import (
44
"sync"
5+
"time"
56

67
"github.com/hashicorp/go-hclog"
78
"github.com/hashicorp/go-multierror"
@@ -22,8 +23,13 @@ var (
2223
}
2324
)
2425

26+
type DNSClient interface {
27+
Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error)
28+
}
29+
2530
type QueryClient struct {
2631
Server string
32+
Client DNSClient
2733
hclog.Logger
2834
}
2935

@@ -43,7 +49,7 @@ func (q *QueryClient) MultiQuery(domain string, qtypes []uint16) ([]*dns.Msg, er
4349
wg.Add(1)
4450
go func(i int, qtype uint16) {
4551
defer wg.Done()
46-
msg, err := q.Query(domain, qtype)
52+
msg, err := q.query(domain, qtype)
4753
mu.Lock()
4854
messages[i] = msg
4955
errors = multierror.Append(errors, err)
@@ -56,15 +62,14 @@ func (q *QueryClient) MultiQuery(domain string, qtypes []uint16) ([]*dns.Msg, er
5662
return messages, errors.ErrorOrNil()
5763
}
5864

59-
// Query performs the DNS query and returns the response and any error encountered.
60-
func (q *QueryClient) Query(domain string, qtype uint16) (*dns.Msg, error) {
65+
// query performs the DNS query and returns the response and any error encountered.
66+
func (q *QueryClient) query(domain string, qtype uint16) (*dns.Msg, error) {
6167
msg := new(dns.Msg)
6268
msg.SetQuestion(dns.Fqdn(domain), qtype)
6369

6470
q.Logger.Debug("Querying DNS server", "server", q.Server, "domain", domain, "qtype", dns.TypeToString[qtype])
6571

66-
client := new(dns.Client)
67-
resp, rtt, err := client.Exchange(msg, q.Server)
72+
resp, rtt, err := q.Client.Exchange(msg, q.Server)
6873
if err != nil {
6974
return nil, err
7075
}

internal/query/query_test.go

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package query
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
"time"
7+
8+
"github.com/hashicorp/go-hclog"
9+
"github.com/hashicorp/go-multierror"
10+
"github.com/miekg/dns"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
// MockDNSClient is a mock DNS client used for testing purposes.
15+
// It is used to override the Exchange method to capture and introspect the DNS query.
16+
type MockDNSClient struct {
17+
// ReceivedDomain stores the domain name extracted from the request.
18+
// This is used to verify that the correct domain name is passed to the underlying DNS client.
19+
ReceivedDomain string
20+
21+
// QueryType stores the DNS query type (e.g., A, MX) extracted from the request.
22+
// This is used to verify that the correct query type is passed to the underlying DNS client.
23+
QueryType uint16
24+
}
25+
26+
func (m *MockDNSClient) Exchange(req *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
27+
// If the request contains a question, capture the domain name and query type.
28+
if len(req.Question) > 0 {
29+
m.ReceivedDomain = req.Question[0].Name
30+
m.QueryType = req.Question[0].Qtype
31+
}
32+
33+
// Return a mock response with a fixed round-trip time and no error.
34+
return &dns.Msg{}, time.Microsecond * 42, nil
35+
}
36+
37+
// MockDNSClientWithError is a mock DNS client used for testing purposes.
38+
// It is used to override the Exchange method to return an error.
39+
type MockDNSClientWithError struct{}
40+
41+
func (m *MockDNSClientWithError) Exchange(req *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
42+
return &dns.Msg{}, time.Microsecond * 42, fmt.Errorf("it's always DNS")
43+
}
44+
45+
func TestQueryClient_Query(t *testing.T) {
46+
// Use a null logger to suppress log output during testing.
47+
client := NewQueryClient("8.8.8.8", hclog.NewNullLogger())
48+
49+
mockDNSClient := &MockDNSClient{}
50+
client.Client = mockDNSClient
51+
52+
_, err := client.query("example.com", dns.TypeA)
53+
54+
assert.Nil(t, err)
55+
56+
assert.Equal(t, "example.com.", mockDNSClient.ReceivedDomain)
57+
assert.Equal(t, dns.TypeA, mockDNSClient.QueryType)
58+
}
59+
60+
func TestQueryClient_Query_Domain(t *testing.T) {
61+
// Use a null logger to suppress log output during testing.
62+
client := NewQueryClient("1.1.1.1", hclog.NewNullLogger())
63+
64+
mockDNSClient := &MockDNSClient{}
65+
client.Client = mockDNSClient
66+
67+
_, err := client.query("abc.xyz", dns.TypeA)
68+
69+
assert.Nil(t, err)
70+
71+
assert.Equal(t, "abc.xyz.", mockDNSClient.ReceivedDomain)
72+
}
73+
74+
func TestQueryClient_Query_QueryType(t *testing.T) {
75+
// Use a null logger to suppress log output during testing.
76+
client := NewQueryClient("1.1.1.1", hclog.NewNullLogger())
77+
78+
mockDNSClient := &MockDNSClient{}
79+
client.Client = mockDNSClient
80+
81+
_, err := client.query("abc.xyz", dns.TypeCNAME)
82+
83+
assert.Nil(t, err)
84+
85+
assert.Equal(t, dns.TypeCNAME, mockDNSClient.QueryType)
86+
}
87+
88+
func TestQueryClient_Query_Error(t *testing.T) {
89+
// Use a null logger to suppress log output during testing.
90+
client := NewQueryClient("8.8.8.8", hclog.NewNullLogger())
91+
92+
mockDNSClientWithError := &MockDNSClientWithError{}
93+
client.Client = mockDNSClientWithError
94+
95+
_, err := client.query("example.com", dns.TypeA)
96+
97+
assert.NotNil(t, err)
98+
assert.Equal(t, "it's always DNS", err.Error())
99+
}
100+
101+
func TestQueryClient_MultiQuery(t *testing.T) {
102+
// Use a null logger to suppress log output during testing.
103+
client := NewQueryClient("8.8.8.8", hclog.NewNullLogger())
104+
105+
mockDNSClient := &MockDNSClient{}
106+
client.Client = mockDNSClient
107+
108+
resp, err := client.MultiQuery("example.com", []uint16{dns.TypeA, dns.TypeMX})
109+
110+
assert.Nil(t, err)
111+
112+
assert.Equal(t, "example.com.", mockDNSClient.ReceivedDomain)
113+
assert.Equal(t, 2, len(resp))
114+
}
115+
116+
func TestQueryClient_MultiQuery_Domain(t *testing.T) {
117+
// Use a null logger to suppress log output during testing.
118+
client := NewQueryClient("1.1.1.1", hclog.NewNullLogger())
119+
120+
mockDNSClient := &MockDNSClient{}
121+
client.Client = mockDNSClient
122+
123+
_, err := client.MultiQuery("abc.xyz", []uint16{dns.TypeA, dns.TypeMX})
124+
125+
assert.Nil(t, err)
126+
127+
assert.Equal(t, "abc.xyz.", mockDNSClient.ReceivedDomain)
128+
}
129+
130+
func TestQueryClient_MultiQuery_Error(t *testing.T) {
131+
// Use a null logger to suppress log output during testing.
132+
client := NewQueryClient("1.1.1.1", hclog.NewNullLogger())
133+
134+
mockDNSClientWithError := &MockDNSClientWithError{}
135+
client.Client = mockDNSClientWithError
136+
137+
_, err := client.MultiQuery("1", []uint16{dns.TypeA, dns.TypeMX})
138+
139+
assert.NotNil(t, err)
140+
}
141+
142+
func TestQueryClient_MultiQuery_TypeAssert_MultiError(t *testing.T) {
143+
// Use a null logger to suppress log output during testing.
144+
client := NewQueryClient("1.1.1.1", hclog.NewNullLogger())
145+
146+
mockDNSClientWithError := &MockDNSClientWithError{}
147+
client.Client = mockDNSClientWithError
148+
149+
_, err := client.MultiQuery("1", []uint16{dns.TypeA, dns.TypeMX})
150+
151+
assert.NotNil(t, err)
152+
153+
// Because MultiQuery returns a multierror.Error, we assert that the error is of that type.
154+
assert.IsType(t, &multierror.Error{}, err)
155+
156+
// We can then type assert the error to a *multierror.Error and introspect the individual errors.
157+
if err, ok := err.(*multierror.Error); ok {
158+
// Assert that two errors are returned (one for each query type).
159+
assert.Equal(t, 2, len(err.Errors))
160+
161+
for _, e := range err.Errors {
162+
assert.Equal(t, "it's always DNS", e.Error())
163+
}
164+
}
165+
}

0 commit comments

Comments
 (0)