Skip to content

Commit ab59d02

Browse files
phireworkbradfitz
authored andcommitted
[tailscale] net, net/http: add enforcement hooks
Updates #55 Updates tailscale/corp#8944 Updates tailscale/corp#12702 Signed-off-by: Jenny Zhang <[email protected]> Signed-off-by: Brad Fitzpatrick <[email protected]> (Cherry-picked from 13373ca) (cherry picked from commit 043e09a) (cherry picked from commit 8df9488)
1 parent 1e42045 commit ab59d02

File tree

4 files changed

+78
-0
lines changed

4 files changed

+78
-0
lines changed

api/go1.99999.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
2+
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
3+
pkg net/http, func SetRoundTripEnforcer(func(*Request) error) #55

src/net/dial.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s
277277
return "", 0, UnknownNetworkError(network)
278278
}
279279

280+
// SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to
281+
// fail based on the context and/or other arguments.
282+
//
283+
// f must be non-nil, it can only be called once, and must not be called
284+
// concurrent with any dial/resolve.
285+
func SetResolveEnforcer(f func(ctx context.Context, op, network, addr string, hint Addr) error) {
286+
if f == nil {
287+
panic("nil func")
288+
}
289+
if resolveEnforcer != nil {
290+
panic("already called")
291+
}
292+
resolveEnforcer = f
293+
}
294+
295+
// resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer.
296+
var resolveEnforcer func(ctx context.Context, op, network, addr string, hint Addr) error
297+
280298
// resolveAddrList resolves addr using hint and returns a list of
281299
// addresses. The result contains at least one address when error is
282300
// nil.
@@ -299,6 +317,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string
299317
}
300318
return addrList{addr}, nil
301319
}
320+
321+
if resolveEnforcer != nil {
322+
if err := resolveEnforcer(ctx, op, network, addr, hint); err != nil {
323+
return nil, err
324+
}
325+
}
326+
302327
addrs, err := r.internetAddrList(ctx, afnet, addr)
303328
if err != nil || op != "dial" || hint == nil {
304329
return addrs, err
@@ -603,9 +628,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr
603628
}
604629
}
605630

631+
// SetDialEnforcer set a program-global dial enforcer that can cause dials to
632+
// fail based on the context and/or Addr(s).
633+
//
634+
// f must be non-nil, it can only be called once, and must not be called
635+
// concurrent with any dial.
636+
func SetDialEnforcer(f func(context.Context, []Addr) error) {
637+
if f == nil {
638+
panic("nil func")
639+
}
640+
if dialEnforcer != nil {
641+
panic("already called")
642+
}
643+
dialEnforcer = f
644+
}
645+
646+
// dialEnforce, if non-nil, is any installed hook from SetDialEnforcer.
647+
var dialEnforcer func(context.Context, []Addr) error
648+
606649
// dialSerial connects to a list of addresses in sequence, returning
607650
// either the first successful connection, or the first error.
608651
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
652+
if dialEnforcer != nil {
653+
if err := dialEnforcer(ctx, ras); err != nil {
654+
return nil, err
655+
}
656+
}
609657
var firstErr error // The error from the first address is most relevant.
610658

611659
for i, ra := range ras {

src/net/http/tailscale.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package http
2+
3+
var roundTripEnforcer func(*Request) error
4+
5+
// SetRoundTripEnforcer set a program-global resolver enforcer that can cause
6+
// RoundTrip calls to fail based on the request and its context.
7+
//
8+
// f must be non-nil.
9+
//
10+
// SetRoundTripEnforcer can only be called once, and must not be called
11+
// concurrent with any RoundTrip call; it's expected to be registered during
12+
// init.
13+
func SetRoundTripEnforcer(f func(*Request) error) {
14+
if f == nil {
15+
panic("nil func")
16+
}
17+
if roundTripEnforcer != nil {
18+
panic("already called")
19+
}
20+
roundTripEnforcer = f
21+
}

src/net/http/transport.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,12 @@ func validateHeaders(hdrs Header) string {
528528

529529
// roundTrip implements a RoundTripper over HTTP.
530530
func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
531+
if roundTripEnforcer != nil {
532+
if err := roundTripEnforcer(req); err != nil {
533+
return nil, err
534+
}
535+
}
536+
531537
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
532538
ctx := req.Context()
533539
trace := httptrace.ContextClientTrace(ctx)

0 commit comments

Comments
 (0)