From a96a9eddc031c85f22378ef1e37e3fd7e9c482ef Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 22 Jul 2023 20:07:20 -0700 Subject: [PATCH] net/http: add enforcement hook to Transport.RoundTrip, like our net ones Updates tailscale/go#55 Updates tailscale/corp#12702 Signed-off-by: Brad Fitzpatrick --- api/go1.99999.txt | 1 + src/net/http/tailscale.go | 21 +++++++++++++++++++++ src/net/http/transport.go | 5 +++++ 3 files changed, 27 insertions(+) create mode 100644 src/net/http/tailscale.go diff --git a/api/go1.99999.txt b/api/go1.99999.txt index a4b8591391989..3c59e814c5ca5 100644 --- a/api/go1.99999.txt +++ b/api/go1.99999.txt @@ -1,5 +1,6 @@ pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55 pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55 +pkg net/http, func SetRoundTripEnforcer(func(*Request) error) #55 pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58 pkg net, func ContextSockTrace(context.Context) *SockTrace #58 pkg net, type SockTrace struct #58 diff --git a/src/net/http/tailscale.go b/src/net/http/tailscale.go new file mode 100644 index 0000000000000..b2a13893d3eaa --- /dev/null +++ b/src/net/http/tailscale.go @@ -0,0 +1,21 @@ +package http + +var roundTripEnforcer func(*Request) error + +// SetRoundTripEnforcer set a program-global resolver enforcer that can cause +// RoundTrip calls to fail based on the request and its context. +// +// f must be non-nil. +// +// SetRoundTripEnforcer can only be called once, and must not be called +// concurrent with any RoundTrip call; it's expected to be registered during +// init. +func SetRoundTripEnforcer(f func(*Request) error) { + if f == nil { + panic("nil func") + } + if roundTripEnforcer != nil { + panic("already called") + } + roundTripEnforcer = f +} diff --git a/src/net/http/transport.go b/src/net/http/transport.go index c07352b018e1e..4d9efe270840b 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -511,6 +511,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { // roundTrip implements a RoundTripper over HTTP. func (t *Transport) roundTrip(req *Request) (*Response, error) { + if roundTripEnforcer != nil { + if err := roundTripEnforcer(req); err != nil { + return nil, err + } + } t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) ctx := req.Context() trace := httptrace.ContextClientTrace(ctx)