@@ -277,6 +277,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s
277
277
return "" , 0 , UnknownNetworkError (network )
278
278
}
279
279
280
+ // SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to
281
+ // fail based on the context and/or other arguments.
282
+ //
283
+ // f must be non-nil, it can only be called once, and must not be called
284
+ // concurrent with any dial/resolve.
285
+ func SetResolveEnforcer (f func (ctx context.Context , op , network , addr string , hint Addr ) error ) {
286
+ if f == nil {
287
+ panic ("nil func" )
288
+ }
289
+ if resolveEnforcer != nil {
290
+ panic ("already called" )
291
+ }
292
+ resolveEnforcer = f
293
+ }
294
+
295
+ // resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer.
296
+ var resolveEnforcer func (ctx context.Context , op , network , addr string , hint Addr ) error
297
+
280
298
// resolveAddrList resolves addr using hint and returns a list of
281
299
// addresses. The result contains at least one address when error is
282
300
// nil.
@@ -299,6 +317,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string
299
317
}
300
318
return addrList {addr }, nil
301
319
}
320
+
321
+ if resolveEnforcer != nil {
322
+ if err := resolveEnforcer (ctx , op , network , addr , hint ); err != nil {
323
+ return nil , err
324
+ }
325
+ }
326
+
302
327
addrs , err := r .internetAddrList (ctx , afnet , addr )
303
328
if err != nil || op != "dial" || hint == nil {
304
329
return addrs , err
@@ -603,9 +628,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr
603
628
}
604
629
}
605
630
631
+ // SetDialEnforcer set a program-global dial enforcer that can cause dials to
632
+ // fail based on the context and/or Addr(s).
633
+ //
634
+ // f must be non-nil, it can only be called once, and must not be called
635
+ // concurrent with any dial.
636
+ func SetDialEnforcer (f func (context.Context , []Addr ) error ) {
637
+ if f == nil {
638
+ panic ("nil func" )
639
+ }
640
+ if dialEnforcer != nil {
641
+ panic ("already called" )
642
+ }
643
+ dialEnforcer = f
644
+ }
645
+
646
+ // dialEnforce, if non-nil, is any installed hook from SetDialEnforcer.
647
+ var dialEnforcer func (context.Context , []Addr ) error
648
+
606
649
// dialSerial connects to a list of addresses in sequence, returning
607
650
// either the first successful connection, or the first error.
608
651
func (sd * sysDialer ) dialSerial (ctx context.Context , ras addrList ) (Conn , error ) {
652
+ if dialEnforcer != nil {
653
+ if err := dialEnforcer (ctx , ras ); err != nil {
654
+ return nil , err
655
+ }
656
+ }
609
657
var firstErr error // The error from the first address is most relevant.
610
658
611
659
for i , ra := range ras {
0 commit comments