diff --git a/src/net/http/tailscale.go b/src/net/http/tailscale.go new file mode 100644 index 0000000000000..b2a13893d3eaa --- /dev/null +++ b/src/net/http/tailscale.go @@ -0,0 +1,21 @@ +package http + +var roundTripEnforcer func(*Request) error + +// SetRoundTripEnforcer set a program-global resolver enforcer that can cause +// RoundTrip calls to fail based on the request and its context. +// +// f must be non-nil. +// +// SetRoundTripEnforcer can only be called once, and must not be called +// concurrent with any RoundTrip call; it's expected to be registered during +// init. +func SetRoundTripEnforcer(f func(*Request) error) { + if f == nil { + panic("nil func") + } + if roundTripEnforcer != nil { + panic("already called") + } + roundTripEnforcer = f +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index da9163a27ae6f..d9ec8bb157a82 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -528,6 +528,12 @@ func validateHeaders(hdrs Header) string { // roundTrip implements a RoundTripper over HTTP. func (t *Transport) roundTrip(req *Request) (_ *Response, err error) { + if roundTripEnforcer != nil { + if err := roundTripEnforcer(req); err != nil { + return nil, err + } + } + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) ctx := req.Context() trace := httptrace.ContextClientTrace(ctx)