Skip to content

Commit d0cc7fe

Browse files
authored
Merge pull request #1019 from govau/newmockbrokerlistener
Add NewMockBrokerListener() so that it's possible to test TLS connections
2 parents 08ccbbb + f933fb4 commit d0cc7fe

File tree

2 files changed

+216
-4
lines changed

2 files changed

+216
-4
lines changed

client_tls_test.go

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package sarama
2+
3+
import (
4+
"math/big"
5+
"net"
6+
"testing"
7+
"time"
8+
9+
"crypto/rand"
10+
"crypto/rsa"
11+
"crypto/tls"
12+
"crypto/x509"
13+
"crypto/x509/pkix"
14+
)
15+
16+
func TestTLS(t *testing.T) {
17+
cakey, err := rsa.GenerateKey(rand.Reader, 2048)
18+
if err != nil {
19+
t.Fatal(err)
20+
}
21+
22+
clientkey, err := rsa.GenerateKey(rand.Reader, 2048)
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
hostkey, err := rsa.GenerateKey(rand.Reader, 2048)
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
32+
nvb := time.Now().Add(-1 * time.Hour)
33+
nva := time.Now().Add(1 * time.Hour)
34+
35+
caTemplate := &x509.Certificate{
36+
Subject: pkix.Name{CommonName: "ca"},
37+
Issuer: pkix.Name{CommonName: "ca"},
38+
SerialNumber: big.NewInt(0),
39+
NotAfter: nva,
40+
NotBefore: nvb,
41+
IsCA: true,
42+
BasicConstraintsValid: true,
43+
KeyUsage: x509.KeyUsageCertSign,
44+
}
45+
caDer, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &cakey.PublicKey, cakey)
46+
if err != nil {
47+
t.Fatal(err)
48+
}
49+
caFinalCert, err := x509.ParseCertificate(caDer)
50+
if err != nil {
51+
t.Fatal(err)
52+
}
53+
54+
hostDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
55+
Subject: pkix.Name{CommonName: "host"},
56+
Issuer: pkix.Name{CommonName: "ca"},
57+
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
58+
SerialNumber: big.NewInt(0),
59+
NotAfter: nva,
60+
NotBefore: nvb,
61+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
62+
}, caFinalCert, &hostkey.PublicKey, cakey)
63+
if err != nil {
64+
t.Fatal(err)
65+
}
66+
67+
clientDer, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
68+
Subject: pkix.Name{CommonName: "client"},
69+
Issuer: pkix.Name{CommonName: "ca"},
70+
SerialNumber: big.NewInt(0),
71+
NotAfter: nva,
72+
NotBefore: nvb,
73+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
74+
}, caFinalCert, &clientkey.PublicKey, cakey)
75+
if err != nil {
76+
t.Fatal(err)
77+
}
78+
79+
pool := x509.NewCertPool()
80+
pool.AddCert(caFinalCert)
81+
82+
systemCerts, err := x509.SystemCertPool()
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
87+
// Keep server the same - it's the client that we're testing
88+
serverTLSConfig := &tls.Config{
89+
Certificates: []tls.Certificate{tls.Certificate{
90+
Certificate: [][]byte{hostDer},
91+
PrivateKey: hostkey,
92+
}},
93+
ClientAuth: tls.RequireAndVerifyClientCert,
94+
ClientCAs: pool,
95+
}
96+
97+
for _, tc := range []struct {
98+
Succeed bool
99+
Server, Client *tls.Config
100+
}{
101+
{ // Verify client fails if wrong CA cert pool is specified
102+
Succeed: false,
103+
Server: serverTLSConfig,
104+
Client: &tls.Config{
105+
RootCAs: systemCerts,
106+
Certificates: []tls.Certificate{tls.Certificate{
107+
Certificate: [][]byte{clientDer},
108+
PrivateKey: clientkey,
109+
}},
110+
},
111+
},
112+
{ // Verify client fails if wrong key is specified
113+
Succeed: false,
114+
Server: serverTLSConfig,
115+
Client: &tls.Config{
116+
RootCAs: pool,
117+
Certificates: []tls.Certificate{tls.Certificate{
118+
Certificate: [][]byte{clientDer},
119+
PrivateKey: hostkey,
120+
}},
121+
},
122+
},
123+
{ // Verify client fails if wrong cert is specified
124+
Succeed: false,
125+
Server: serverTLSConfig,
126+
Client: &tls.Config{
127+
RootCAs: pool,
128+
Certificates: []tls.Certificate{tls.Certificate{
129+
Certificate: [][]byte{hostDer},
130+
PrivateKey: clientkey,
131+
}},
132+
},
133+
},
134+
{ // Verify client fails if no CAs are specified
135+
Succeed: false,
136+
Server: serverTLSConfig,
137+
Client: &tls.Config{
138+
Certificates: []tls.Certificate{tls.Certificate{
139+
Certificate: [][]byte{clientDer},
140+
PrivateKey: clientkey,
141+
}},
142+
},
143+
},
144+
{ // Verify client fails if no keys are specified
145+
Succeed: false,
146+
Server: serverTLSConfig,
147+
Client: &tls.Config{
148+
RootCAs: pool,
149+
},
150+
},
151+
{ // Finally, verify it all works happily with client and server cert in place
152+
Succeed: true,
153+
Server: serverTLSConfig,
154+
Client: &tls.Config{
155+
RootCAs: pool,
156+
Certificates: []tls.Certificate{tls.Certificate{
157+
Certificate: [][]byte{clientDer},
158+
PrivateKey: clientkey,
159+
}},
160+
},
161+
},
162+
} {
163+
doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client)
164+
}
165+
}
166+
167+
func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientConfig *tls.Config) {
168+
serverConfig.BuildNameToCertificate()
169+
clientConfig.BuildNameToCertificate()
170+
171+
seedListener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig)
172+
if err != nil {
173+
t.Fatal("cannot open listener", err)
174+
}
175+
176+
var childT *testing.T
177+
if expectSuccess {
178+
childT = t
179+
} else {
180+
childT = &testing.T{} // we want to swallow errors
181+
}
182+
183+
seedBroker := NewMockBrokerListener(childT, 1, seedListener)
184+
defer seedBroker.Close()
185+
186+
seedBroker.Returns(new(MetadataResponse))
187+
188+
config := NewConfig()
189+
config.Net.TLS.Enable = true
190+
config.Net.TLS.Config = clientConfig
191+
192+
client, err := NewClient([]string{seedBroker.Addr()}, config)
193+
if err == nil {
194+
safeClose(t, client)
195+
}
196+
197+
if expectSuccess {
198+
if err != nil {
199+
t.Fatal(err)
200+
}
201+
} else {
202+
if err == nil {
203+
t.Fatal("expected failure")
204+
}
205+
}
206+
}

mockbroker.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,15 @@ func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
288288
// NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
289289
// it rather than just some ephemeral port.
290290
func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
291+
listener, err := net.Listen("tcp", addr)
292+
if err != nil {
293+
t.Fatal(err)
294+
}
295+
return NewMockBrokerListener(t, brokerID, listener)
296+
}
297+
298+
// NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
299+
func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
291300
var err error
292301

293302
broker := &MockBroker{
@@ -296,13 +305,10 @@ func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker
296305
t: t,
297306
brokerID: brokerID,
298307
expectations: make(chan encoder, 512),
308+
listener: listener,
299309
}
300310
broker.handler = broker.defaultRequestHandler
301311

302-
broker.listener, err = net.Listen("tcp", addr)
303-
if err != nil {
304-
t.Fatal(err)
305-
}
306312
Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
307313
_, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
308314
if err != nil {

0 commit comments

Comments
 (0)