@@ -23,6 +23,7 @@ import (
23
23
"crypto/tls"
24
24
"crypto/x509"
25
25
"fmt"
26
+ "net"
26
27
"os"
27
28
"strings"
28
29
"testing"
@@ -31,6 +32,7 @@ import (
31
32
"google.golang.org/grpc"
32
33
"google.golang.org/grpc/codes"
33
34
"google.golang.org/grpc/credentials"
35
+ "google.golang.org/grpc/internal/envconfig"
34
36
"google.golang.org/grpc/internal/grpctest"
35
37
"google.golang.org/grpc/internal/stubserver"
36
38
"google.golang.org/grpc/status"
@@ -236,3 +238,160 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
236
238
t .Fatalf ("EmptyCall err = %v; want <nil>" , err )
237
239
}
238
240
}
241
+
242
+ // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
243
+ // connecting to a server that doesn't support ALPN.
244
+ func (s ) TestTLS_DisabledALPNClient (t * testing.T ) {
245
+ initialVal := envconfig .EnforceALPNEnabled
246
+ defer func () {
247
+ envconfig .EnforceALPNEnabled = initialVal
248
+ }()
249
+
250
+ tests := []struct {
251
+ name string
252
+ alpnEnforced bool
253
+ wantErr bool
254
+ }{
255
+ {
256
+ name : "enforced" ,
257
+ alpnEnforced : true ,
258
+ wantErr : true ,
259
+ },
260
+ {
261
+ name : "not_enforced" ,
262
+ },
263
+ }
264
+
265
+ for _ , tc := range tests {
266
+ t .Run (tc .name , func (t * testing.T ) {
267
+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
268
+
269
+ listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
270
+ Certificates : []tls.Certificate {serverCert },
271
+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
272
+ })
273
+ if err != nil {
274
+ t .Fatalf ("Error starting TLS server: %v" , err )
275
+ }
276
+
277
+ errCh := make (chan error , 1 )
278
+ go func () {
279
+ conn , err := listener .Accept ()
280
+ if err != nil {
281
+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
282
+ } else {
283
+ // The first write to the TLS listener initiates the TLS handshake.
284
+ conn .Write ([]byte ("Hello, World!" ))
285
+ conn .Close ()
286
+ }
287
+ close (errCh )
288
+ }()
289
+
290
+ serverAddr := listener .Addr ().String ()
291
+ conn , err := net .Dial ("tcp" , serverAddr )
292
+ if err != nil {
293
+ t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
294
+ }
295
+ defer conn .Close ()
296
+
297
+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
298
+ defer cancel ()
299
+
300
+ clientCfg := tls.Config {
301
+ ServerName : serverName ,
302
+ RootCAs : certPool ,
303
+ NextProtos : []string {"h2" },
304
+ }
305
+ _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
306
+
307
+ if gotErr := (err != nil ); gotErr != tc .wantErr {
308
+ t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
309
+ }
310
+
311
+ select {
312
+ case err := <- errCh :
313
+ if err != nil {
314
+ t .Fatalf ("Unexpected error received from server: %v" , err )
315
+ }
316
+ case <- ctx .Done ():
317
+ t .Fatalf ("Timeout waiting for error from server" )
318
+ }
319
+ })
320
+ }
321
+ }
322
+
323
+ // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
324
+ // accepting a request from a client that doesn't support ALPN.
325
+ func (s ) TestTLS_DisabledALPNServer (t * testing.T ) {
326
+ initialVal := envconfig .EnforceALPNEnabled
327
+ defer func () {
328
+ envconfig .EnforceALPNEnabled = initialVal
329
+ }()
330
+
331
+ tests := []struct {
332
+ name string
333
+ alpnEnforced bool
334
+ wantErr bool
335
+ }{
336
+ {
337
+ name : "enforced" ,
338
+ alpnEnforced : true ,
339
+ wantErr : true ,
340
+ },
341
+ {
342
+ name : "not_enforced" ,
343
+ },
344
+ }
345
+
346
+ for _ , tc := range tests {
347
+ t .Run (tc .name , func (t * testing.T ) {
348
+ envconfig .EnforceALPNEnabled = tc .alpnEnforced
349
+
350
+ listener , err := net .Listen ("tcp" , "localhost:0" )
351
+ if err != nil {
352
+ t .Fatalf ("Error starting server: %v" , err )
353
+ }
354
+
355
+ errCh := make (chan error , 1 )
356
+ go func () {
357
+ conn , err := listener .Accept ()
358
+ if err != nil {
359
+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
360
+ return
361
+ }
362
+ defer conn .Close ()
363
+ serverCfg := tls.Config {
364
+ Certificates : []tls.Certificate {serverCert },
365
+ NextProtos : []string {"h2" },
366
+ }
367
+ _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
368
+ if gotErr := (err != nil ); gotErr != tc .wantErr {
369
+ t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
370
+ }
371
+ close (errCh )
372
+ }()
373
+
374
+ serverAddr := listener .Addr ().String ()
375
+ clientCfg := & tls.Config {
376
+ Certificates : []tls.Certificate {serverCert },
377
+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
378
+ RootCAs : certPool ,
379
+ ServerName : serverName ,
380
+ }
381
+ conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
382
+ if err != nil {
383
+ t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
384
+ }
385
+ defer conn .Close ()
386
+
387
+ select {
388
+ case <- time .After (defaultTestTimeout ):
389
+ t .Fatal ("Timed out waiting for completion" )
390
+ case err := <- errCh :
391
+ if err != nil {
392
+ t .Fatalf ("Unexpected server error: %v" , err )
393
+ }
394
+ }
395
+ })
396
+ }
397
+ }
0 commit comments