Skip to content

Commit e064f32

Browse files
seans3AlexVulaj
authored andcommitted
Implements HTTPS proxy functionality
1 parent 5e00238 commit e064f32

File tree

3 files changed

+1036
-44
lines changed

3 files changed

+1036
-44
lines changed

client.go

Lines changed: 124 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,34 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
5151
//
5252
// It is safe to call Dialer's methods concurrently.
5353
type Dialer struct {
54+
// The following custom dial functions can be set to establish
55+
// connections to either the backend server or the proxy (if it
56+
// exists). The scheme of the dialed entity (either backend or
57+
// proxy) determines which custom dial function is selected:
58+
// either NetDialTLSContext for HTTPS or NetDialContext/NetDial
59+
// for HTTP. Since the "Proxy" function can determine the scheme
60+
// dynamically, it can make sense to set multiple custom dial
61+
// functions simultaneously.
62+
//
5463
// NetDial specifies the dial function for creating TCP connections. If
5564
// NetDial is nil, net.Dialer DialContext is used.
65+
// If "Proxy" field is also set, this function dials the proxy--not
66+
// the backend server.
5667
NetDial func(network, addr string) (net.Conn, error)
5768

5869
// NetDialContext specifies the dial function for creating TCP connections. If
5970
// NetDialContext is nil, NetDial is used.
71+
// If "Proxy" field is also set, this function dials the proxy--not
72+
// the backend server.
6073
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
6174

6275
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
6376
// NetDialTLSContext is nil, NetDialContext is used.
6477
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
6578
// TLSClientConfig is ignored.
79+
// If "Proxy" field is also set, this function dials the proxy (and performs
80+
// the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy
81+
// dialing case the TLSClientConfig could still be necessary for TLS to the backend server.
6682
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
6783

6884
// Proxy specifies a function to return a proxy for a given
@@ -73,7 +89,7 @@ type Dialer struct {
7389

7490
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
7591
// If nil, the default configuration is used.
76-
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
92+
// If NetDialTLSContext is set, Dial assumes the TLS handshake
7793
// is done there and TLSClientConfig is ignored.
7894
TLSClientConfig *tls.Config
7995

@@ -244,49 +260,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
244260
defer cancel()
245261
}
246262

247-
var netDial netDialerFunc
248-
switch {
249-
case u.Scheme == "https" && d.NetDialTLSContext != nil:
250-
netDial = d.NetDialTLSContext
251-
case d.NetDialContext != nil:
252-
netDial = d.NetDialContext
253-
case d.NetDial != nil:
254-
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
255-
return d.NetDial(net, addr)
256-
}
257-
default:
258-
netDial = (&net.Dialer{}).DialContext
259-
}
260-
261-
// If needed, wrap the dial function to set the connection deadline.
262-
if deadline, ok := ctx.Deadline(); ok {
263-
forwardDial := netDial
264-
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
265-
c, err := forwardDial(ctx, network, addr)
266-
if err != nil {
267-
return nil, err
268-
}
269-
err = c.SetDeadline(deadline)
270-
if err != nil {
271-
c.Close()
272-
return nil, err
273-
}
274-
return c, nil
275-
}
276-
}
277-
278-
// If needed, wrap the dial function to connect through a proxy.
263+
var proxyURL *url.URL
279264
if d.Proxy != nil {
280-
proxyURL, err := d.Proxy(req)
265+
proxyURL, err = d.Proxy(req)
281266
if err != nil {
282267
return nil, nil, err
283268
}
284-
if proxyURL != nil {
285-
netDial, err = proxyFromURL(proxyURL, netDial)
286-
if err != nil {
287-
return nil, nil, err
288-
}
289-
}
269+
}
270+
netDial, err := d.netDialFn(ctx, proxyURL, u)
271+
if err != nil {
272+
return nil, nil, err
290273
}
291274

292275
hostPort, hostNoPort := hostPortNoPort(u)
@@ -317,8 +300,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
317300
}
318301
}()
319302

320-
if u.Scheme == "https" && d.NetDialTLSContext == nil {
321-
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
303+
// Do TLS handshake over established connection if a proxy exists.
304+
if proxyURL != nil && u.Scheme == "https" {
322305

323306
cfg := cloneTLSConfig(d.TLSClientConfig)
324307
if cfg.ServerName == "" {
@@ -415,6 +398,105 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
415398
return conn, resp, nil
416399
}
417400

401+
// Returns the dial function to establish the connection to either the backend
402+
// server or the proxy (if it exists). If the dialed entity is HTTPS, then the
403+
// returned dial function *also* performs the TLS handshake to the dialed entity.
404+
// NOTE: If a proxy exists, it is possible for a second TLS handshake to be
405+
// necessary over the established connection.
406+
func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) {
407+
var netDial netDialerFunc
408+
if proxyURL != nil {
409+
netDial = d.netDialFromURL(proxyURL)
410+
} else {
411+
netDial = d.netDialFromURL(backendURL)
412+
}
413+
// If needed, wrap the dial function to set the connection deadline.
414+
if deadline, ok := ctx.Deadline(); ok {
415+
netDial = netDialWithDeadline(netDial, deadline)
416+
}
417+
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
418+
if proxyURL != nil {
419+
return proxyFromURL(proxyURL, netDial)
420+
}
421+
return netDial, nil
422+
}
423+
424+
// Returns function to create the connection depending on the Dialer's
425+
// custom dialing functions and the passed URL of entity connecting to.
426+
func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc {
427+
var netDial netDialerFunc
428+
switch {
429+
case d.NetDialContext != nil:
430+
netDial = d.NetDialContext
431+
case d.NetDial != nil:
432+
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
433+
return d.NetDial(net, addr)
434+
}
435+
default:
436+
netDial = (&net.Dialer{}).DialContext
437+
}
438+
// If dialed entity is HTTPS, then either use custom TLS dialing function (if exists)
439+
// or wrap the previously computed "netDial" to use TLS config for handshake.
440+
if u.Scheme == "https" {
441+
if d.NetDialTLSContext != nil {
442+
netDial = d.NetDialTLSContext
443+
} else {
444+
netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u)
445+
}
446+
}
447+
return netDial
448+
}
449+
450+
// Returns wrapped "netDial" function, performing TLS handshake after connecting.
451+
func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc {
452+
return func(ctx context.Context, unused, addr string) (net.Conn, error) {
453+
hostPort, hostNoPort := hostPortNoPort(u)
454+
trace := httptrace.ContextClientTrace(ctx)
455+
if trace != nil && trace.GetConn != nil {
456+
trace.GetConn(hostPort)
457+
}
458+
// Creates TCP connection to addr using passed "netDial" function.
459+
conn, err := netDial(ctx, "tcp", addr)
460+
if err != nil {
461+
return nil, err
462+
}
463+
cfg := cloneTLSConfig(tlsConfig)
464+
if cfg.ServerName == "" {
465+
cfg.ServerName = hostNoPort
466+
}
467+
tlsConn := tls.Client(conn, cfg)
468+
// Do the TLS handshake using TLSConfig over the wrapped connection.
469+
if trace != nil && trace.TLSHandshakeStart != nil {
470+
trace.TLSHandshakeStart()
471+
}
472+
err = doHandshake(ctx, tlsConn, cfg)
473+
if trace != nil && trace.TLSHandshakeDone != nil {
474+
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
475+
}
476+
if err != nil {
477+
tlsConn.Close()
478+
return nil, err
479+
}
480+
return tlsConn, nil
481+
}
482+
}
483+
484+
// Returns wrapped "netDial" function, setting passed deadline.
485+
func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc {
486+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
487+
c, err := netDial(ctx, network, addr)
488+
if err != nil {
489+
return nil, err
490+
}
491+
err = c.SetDeadline(deadline)
492+
if err != nil {
493+
c.Close()
494+
return nil, err
495+
}
496+
return c, nil
497+
}
498+
}
499+
418500
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
419501
if cfg == nil {
420502
return &tls.Config{}

0 commit comments

Comments
 (0)