diff --git a/api/go1.99999.txt b/api/go1.99999.txt new file mode 100644 index 0000000000000..dd6cb2c37d580 --- /dev/null +++ b/api/go1.99999.txt @@ -0,0 +1,2 @@ +pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55 +pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55 diff --git a/src/net/dial.go b/src/net/dial.go index a6565c3ce5d13..2e8239da0831f 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -258,6 +258,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s return "", 0, UnknownNetworkError(network) } +// SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to +// fail based on the context and/or other arguments. +// +// f must be non-nil, it can only be called once, and must not be called +// concurrent with any dial/resolve. +func SetResolveEnforcer(f func(ctx context.Context, op, network, addr string, hint Addr) error) { + if f == nil { + panic("nil func") + } + if resolveEnforcer != nil { + panic("already called") + } + resolveEnforcer = f +} + +// resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer. +var resolveEnforcer func(ctx context.Context, op, network, addr string, hint Addr) error + // resolveAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. @@ -280,6 +298,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string } return addrList{addr}, nil } + + if resolveEnforcer != nil { + if err := resolveEnforcer(ctx, op, network, addr, hint); err != nil { + return nil, err + } + } + addrs, err := r.internetAddrList(ctx, afnet, addr) if err != nil || op != "dial" || hint == nil { return addrs, err @@ -584,9 +609,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr } } +// SetDialEnforcer set a program-global dial enforcer that can cause dials to +// fail based on the context and/or Addr(s). +// +// f must be non-nil, it can only be called once, and must not be called +// concurrent with any dial. +func SetDialEnforcer(f func(context.Context, []Addr) error) { + if f == nil { + panic("nil func") + } + if dialEnforcer != nil { + panic("already called") + } + dialEnforcer = f +} + +// dialEnforce, if non-nil, is any installed hook from SetDialEnforcer. +var dialEnforcer func(context.Context, []Addr) error + // dialSerial connects to a list of addresses in sequence, returning // either the first successful connection, or the first error. func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) { + if dialEnforcer != nil { + if err := dialEnforcer(ctx, ras); err != nil { + return nil, err + } + } var firstErr error // The error from the first address is most relevant. for i, ra := range ras {