Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tailscale] net: add enforcement hooks #56

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
48 changes: 48 additions & 0 deletions src/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,24 @@ func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet s
return "", 0, UnknownNetworkError(network)
}

// SetResolveEnforcer set a program-global resolver enforcer that can cause resolvers to
// fail based on the context and/or other arguments.
//
// f must be non-nil, it can only be called once, and must not be called
bradfitz marked this conversation as resolved.
Show resolved Hide resolved
// concurrent with any dial/resolve.
func SetResolveEnforcer(f func(ctx context.Context, op, network, addr string, hint Addr) error) {
if f == nil {
panic("nil func")
}
if resolveEnforcer != nil {
panic("already called")
}
resolveEnforcer = f
}

// resolveEnforcer, if non-nil, is the installed hook from SetResolveEnforcer.
var resolveEnforcer func(ctx context.Context, op, network, addr string, hint Addr) error

// resolveAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
Expand All @@ -231,6 +249,13 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string
}
return addrList{addr}, nil
}

if resolveEnforcer != nil {
if err := resolveEnforcer(ctx, op, network, addr, hint); err != nil {
return nil, err
}
}

addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
Expand Down Expand Up @@ -516,9 +541,32 @@ func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addr
}
}

// SetDialEnforcer set a program-global dial enforcer that can cause dials to
// fail based on the context and/or Addr(s).
//
// f must be non-nil, it can only be called once, and must not be called
// concurrent with any dial.
func SetDialEnforcer(f func(context.Context, []Addr) error) {
if f == nil {
panic("nil func")
}
if dialEnforcer != nil {
panic("already called")
}
dialEnforcer = f
}

// dialEnforce, if non-nil, is any installed hook from SetDialEnforcer.
var dialEnforcer func(context.Context, []Addr) error

// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
if dialEnforcer != nil {
if err := dialEnforcer(ctx, ras); err != nil {
return nil, err
}
}
var firstErr error // The error from the first address is most relevant.

for i, ra := range ras {
Expand Down