Skip to content

Commit cf27920

Browse files
GODRIVER-3284 Allow valid SRV hostnames with less than 3 parts. (#1898) [release/2.0] (#1949)
Co-authored-by: Qingyang Hu <[email protected]>
1 parent 71e025b commit cf27920

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright (C) MongoDB, Inc. 2024-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package connstring
8+
9+
import (
10+
"fmt"
11+
"net"
12+
"testing"
13+
14+
"go.mongodb.org/mongo-driver/v2/internal/assert"
15+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/dns"
16+
)
17+
18+
func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) {
19+
newTestParser := func(record string) *parser {
20+
return &parser{&dns.Resolver{
21+
LookupSRV: func(_, _, _ string) (string, []*net.SRV, error) {
22+
return "", []*net.SRV{
23+
{
24+
Target: record,
25+
Port: 27017,
26+
},
27+
}, nil
28+
},
29+
LookupTXT: func(string) ([]string, error) {
30+
return nil, nil
31+
},
32+
}}
33+
}
34+
35+
t.Run("1. Allow SRVs with fewer than 3 . separated parts", func(t *testing.T) {
36+
t.Parallel()
37+
38+
cases := []struct {
39+
record string
40+
uri string
41+
}{
42+
{"test_1.localhost", "mongodb+srv://localhost"},
43+
{"test_1.mongo.local", "mongodb+srv://mongo.local"},
44+
}
45+
for _, c := range cases {
46+
c := c
47+
t.Run(c.uri, func(t *testing.T) {
48+
t.Parallel()
49+
50+
_, err := newTestParser(c.record).parse(c.uri)
51+
assert.NoError(t, err, "expected no URI parsing error, got %v", err)
52+
})
53+
}
54+
})
55+
t.Run("2. Throw when return address does not end with SRV domain", func(t *testing.T) {
56+
t.Parallel()
57+
58+
cases := []struct {
59+
record string
60+
uri string
61+
}{
62+
{"localhost.mongodb", "mongodb+srv://localhost"},
63+
{"test_1.evil.local", "mongodb+srv://mongo.local"},
64+
{"blogs.evil.com", "mongodb+srv://blogs.mongodb.com"},
65+
}
66+
for _, c := range cases {
67+
c := c
68+
t.Run(c.uri, func(t *testing.T) {
69+
t.Parallel()
70+
71+
_, err := newTestParser(c.record).parse(c.uri)
72+
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
73+
})
74+
}
75+
})
76+
t.Run("3. Throw when return address is identical to SRV hostname", func(t *testing.T) {
77+
t.Parallel()
78+
79+
cases := []struct {
80+
record string
81+
uri string
82+
labels int
83+
}{
84+
{"localhost", "mongodb+srv://localhost", 1},
85+
{"mongo.local", "mongodb+srv://mongo.local", 2},
86+
}
87+
for _, c := range cases {
88+
c := c
89+
t.Run(c.uri, func(t *testing.T) {
90+
t.Parallel()
91+
92+
_, err := newTestParser(c.record).parse(c.uri)
93+
expected := fmt.Sprintf(
94+
"Server record (%d levels) should have more domain levels than parent URI (%d levels)",
95+
c.labels, c.labels,
96+
)
97+
assert.ErrorContains(t, err, expected)
98+
})
99+
}
100+
})
101+
t.Run("4. Throw when return address does not contain . separating shared part of domain", func(t *testing.T) {
102+
t.Parallel()
103+
104+
cases := []struct {
105+
record string
106+
uri string
107+
}{
108+
{"test_1.cluster_1localhost", "mongodb+srv://localhost"},
109+
{"test_1.my_hostmongo.local", "mongodb+srv://mongo.local"},
110+
{"cluster.testmongodb.com", "mongodb+srv://blogs.mongodb.com"},
111+
}
112+
for _, c := range cases {
113+
c := c
114+
t.Run(c.uri, func(t *testing.T) {
115+
t.Parallel()
116+
117+
_, err := newTestParser(c.record).parse(c.uri)
118+
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
119+
})
120+
}
121+
})
122+
}

x/mongo/driver/dns/dns.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,18 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr b
113113
func validateSRVResult(recordFromSRV, inputHostName string) error {
114114
separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".")
115115
separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".")
116-
if len(separatedRecord) < 2 {
117-
return errors.New("DNS name must contain at least 2 labels")
116+
if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l {
117+
return fmt.Errorf("Server record (%d levels) should have more domain levels than parent URI (%d levels)", l, len(separatedRecord))
118118
}
119119
if len(separatedRecord) < len(separatedInputDomain) {
120120
return errors.New("Domain suffix from SRV record not matched input domain")
121121
}
122122

123-
inputDomainSuffix := separatedInputDomain[1:]
124-
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
123+
inputDomainSuffix := separatedInputDomain
124+
if len(inputDomainSuffix) > 2 {
125+
inputDomainSuffix = inputDomainSuffix[1:]
126+
}
127+
domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix)
125128

126129
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
127130
for ix, label := range inputDomainSuffix {

0 commit comments

Comments
 (0)