Skip to content

Commit 48b6b11

Browse files
authored
credentials/tls: reject connections with ALPN disabled (#7184)
1 parent 0a0abfa commit 48b6b11

File tree

3 files changed

+198
-1
lines changed

3 files changed

+198
-1
lines changed

credentials/tls.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@ import (
2727
"net/url"
2828
"os"
2929

30+
"google.golang.org/grpc/grpclog"
3031
credinternal "google.golang.org/grpc/internal/credentials"
32+
"google.golang.org/grpc/internal/envconfig"
3133
)
3234

35+
var logger = grpclog.Component("credentials")
36+
3337
// TLSInfo contains the auth information for a TLS authenticated connection.
3438
// It implements the AuthInfo interface.
3539
type TLSInfo struct {
@@ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
112116
conn.Close()
113117
return nil, nil, ctx.Err()
114118
}
119+
120+
// The negotiated protocol can be either of the following:
121+
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
122+
// it is the only protocol advertised by the client during the handshake.
123+
// The tls library ensures that the server chooses a protocol advertised
124+
// by the client.
125+
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
126+
// for using HTTP/2 over TLS. We can terminate the connection immediately.
127+
np := conn.ConnectionState().NegotiatedProtocol
128+
if np == "" {
129+
if envconfig.EnforceALPNEnabled {
130+
conn.Close()
131+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
132+
}
133+
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
134+
}
115135
tlsInfo := TLSInfo{
116136
State: conn.ConnectionState(),
117137
CommonAuthInfo: CommonAuthInfo{
@@ -131,8 +151,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
131151
conn.Close()
132152
return nil, nil, err
133153
}
154+
cs := conn.ConnectionState()
155+
// The negotiated application protocol can be empty only if the client doesn't
156+
// support ALPN. In such cases, we can close the connection since ALPN is required
157+
// for using HTTP/2 over TLS.
158+
if cs.NegotiatedProtocol == "" {
159+
if envconfig.EnforceALPNEnabled {
160+
conn.Close()
161+
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
162+
} else if logger.V(2) {
163+
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
164+
}
165+
}
134166
tlsInfo := TLSInfo{
135-
State: conn.ConnectionState(),
167+
State: cs,
136168
CommonAuthInfo: CommonAuthInfo{
137169
SecurityLevel: PrivacyAndIntegrity,
138170
},

credentials/tls_ext_test.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"crypto/tls"
2424
"crypto/x509"
2525
"fmt"
26+
"net"
2627
"os"
2728
"strings"
2829
"testing"
@@ -31,6 +32,7 @@ import (
3132
"google.golang.org/grpc"
3233
"google.golang.org/grpc/codes"
3334
"google.golang.org/grpc/credentials"
35+
"google.golang.org/grpc/internal/envconfig"
3436
"google.golang.org/grpc/internal/grpctest"
3537
"google.golang.org/grpc/internal/stubserver"
3638
"google.golang.org/grpc/status"
@@ -236,3 +238,160 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
236238
t.Fatalf("EmptyCall err = %v; want <nil>", err)
237239
}
238240
}
241+
242+
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
243+
// connecting to a server that doesn't support ALPN.
244+
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
245+
initialVal := envconfig.EnforceALPNEnabled
246+
defer func() {
247+
envconfig.EnforceALPNEnabled = initialVal
248+
}()
249+
250+
tests := []struct {
251+
name string
252+
alpnEnforced bool
253+
wantErr bool
254+
}{
255+
{
256+
name: "enforced",
257+
alpnEnforced: true,
258+
wantErr: true,
259+
},
260+
{
261+
name: "not_enforced",
262+
},
263+
}
264+
265+
for _, tc := range tests {
266+
t.Run(tc.name, func(t *testing.T) {
267+
envconfig.EnforceALPNEnabled = tc.alpnEnforced
268+
269+
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
270+
Certificates: []tls.Certificate{serverCert},
271+
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
272+
})
273+
if err != nil {
274+
t.Fatalf("Error starting TLS server: %v", err)
275+
}
276+
277+
errCh := make(chan error, 1)
278+
go func() {
279+
conn, err := listener.Accept()
280+
if err != nil {
281+
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
282+
} else {
283+
// The first write to the TLS listener initiates the TLS handshake.
284+
conn.Write([]byte("Hello, World!"))
285+
conn.Close()
286+
}
287+
close(errCh)
288+
}()
289+
290+
serverAddr := listener.Addr().String()
291+
conn, err := net.Dial("tcp", serverAddr)
292+
if err != nil {
293+
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
294+
}
295+
defer conn.Close()
296+
297+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
298+
defer cancel()
299+
300+
clientCfg := tls.Config{
301+
ServerName: serverName,
302+
RootCAs: certPool,
303+
NextProtos: []string{"h2"},
304+
}
305+
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)
306+
307+
if gotErr := (err != nil); gotErr != tc.wantErr {
308+
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
309+
}
310+
311+
select {
312+
case err := <-errCh:
313+
if err != nil {
314+
t.Fatalf("Unexpected error received from server: %v", err)
315+
}
316+
case <-ctx.Done():
317+
t.Fatalf("Timeout waiting for error from server")
318+
}
319+
})
320+
}
321+
}
322+
323+
// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
324+
// accepting a request from a client that doesn't support ALPN.
325+
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
326+
initialVal := envconfig.EnforceALPNEnabled
327+
defer func() {
328+
envconfig.EnforceALPNEnabled = initialVal
329+
}()
330+
331+
tests := []struct {
332+
name string
333+
alpnEnforced bool
334+
wantErr bool
335+
}{
336+
{
337+
name: "enforced",
338+
alpnEnforced: true,
339+
wantErr: true,
340+
},
341+
{
342+
name: "not_enforced",
343+
},
344+
}
345+
346+
for _, tc := range tests {
347+
t.Run(tc.name, func(t *testing.T) {
348+
envconfig.EnforceALPNEnabled = tc.alpnEnforced
349+
350+
listener, err := net.Listen("tcp", "localhost:0")
351+
if err != nil {
352+
t.Fatalf("Error starting server: %v", err)
353+
}
354+
355+
errCh := make(chan error, 1)
356+
go func() {
357+
conn, err := listener.Accept()
358+
if err != nil {
359+
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
360+
return
361+
}
362+
defer conn.Close()
363+
serverCfg := tls.Config{
364+
Certificates: []tls.Certificate{serverCert},
365+
NextProtos: []string{"h2"},
366+
}
367+
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
368+
if gotErr := (err != nil); gotErr != tc.wantErr {
369+
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
370+
}
371+
close(errCh)
372+
}()
373+
374+
serverAddr := listener.Addr().String()
375+
clientCfg := &tls.Config{
376+
Certificates: []tls.Certificate{serverCert},
377+
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
378+
RootCAs: certPool,
379+
ServerName: serverName,
380+
}
381+
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
382+
if err != nil {
383+
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
384+
}
385+
defer conn.Close()
386+
387+
select {
388+
case <-time.After(defaultTestTimeout):
389+
t.Fatal("Timed out waiting for completion")
390+
case err := <-errCh:
391+
if err != nil {
392+
t.Fatalf("Unexpected server error: %v", err)
393+
}
394+
}
395+
})
396+
}
397+
}

internal/envconfig/envconfig.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ var (
4040
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
4141
// handshakes that can be performed.
4242
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
43+
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
44+
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
45+
// option is present for backward compatibility. This option may be overridden
46+
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
47+
// or "false".
48+
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
4349
)
4450

4551
func boolFromEnv(envVar string, def bool) bool {

0 commit comments

Comments
 (0)