Skip to content

Commit

Permalink
net/http: fix cancelation of requests with a readTrackingBody wrapper
Browse files Browse the repository at this point in the history
Use the original *Request in the reqCanceler map, not the transient
wrapper created to handle body rewinding.

Change the key of reqCanceler to a struct{*Request}, to make it more
difficult to accidentally use the wrong request as the key.

Fixes #40453.

Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757
Reviewed-on: https://go-review.googlesource.com/c/go/+/245357
Run-TryBot: Damien Neil <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
Reviewed-by: Russ Cox <[email protected]>
  • Loading branch information
neild committed Aug 4, 2020
1 parent f923374 commit f235275
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 31 deletions.
73 changes: 42 additions & 31 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type Transport struct {
idleLRU connLRU

reqMu sync.Mutex
reqCanceler map[*Request]func(error)
reqCanceler map[cancelKey]func(error)

altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
Expand Down Expand Up @@ -273,6 +273,13 @@ type Transport struct {
ForceAttemptHTTP2 bool
}

// A cancelKey is the key of the reqCanceler map.
// We wrap the *Request in this type since we want to use the original request,
// not any transient one created by roundTrip.
type cancelKey struct {
req *Request
}

func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 {
return t.WriteBufferSize
Expand Down Expand Up @@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
cancelKey cancelKey

mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
Expand Down Expand Up @@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}

origReq := req
cancelKey := cancelKey{origReq}
req = setupRewindBody(req)

if altRT := t.alternateRoundTripper(req); altRT != nil {
Expand Down Expand Up @@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}

// treq gets modified by roundTrip, so we need to recreate for each retry.
treq := &transportRequest{Request: req, trace: trace}
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
req.closeBody()
Expand All @@ -559,15 +568,15 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
// to send it requests.
pconn, err := t.getConn(treq, cm)
if err != nil {
t.setReqCanceler(req, nil)
t.setReqCanceler(cancelKey, nil)
req.closeBody()
return nil, err
}

var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
Expand Down Expand Up @@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(req, errRequestCanceled)
t.cancelRequest(cancelKey{req}, errRequestCanceled)
}

// Cancel an in-flight request, recording the error value.
func (t *Transport) cancelRequest(req *Request, err error) {
func (t *Transport) cancelRequest(key cancelKey, err error) {
t.reqMu.Lock()
cancel := t.reqCanceler[req]
delete(t.reqCanceler, req)
cancel := t.reqCanceler[key]
delete(t.reqCanceler, key)
t.reqMu.Unlock()
if cancel != nil {
cancel(err)
Expand Down Expand Up @@ -1093,34 +1102,34 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
return removed
}

func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]func(error))
t.reqCanceler = make(map[cancelKey]func(error))
}
if fn != nil {
t.reqCanceler[r] = fn
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, r)
delete(t.reqCanceler, key)
}
}

// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r]
_, ok := t.reqCanceler[key]
if !ok {
return false
}
if fn != nil {
t.reqCanceler[r] = fn
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, r)
delete(t.reqCanceler, key)
}
return true
}
Expand Down Expand Up @@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
t.setReqCanceler(req, func(error) {})
t.setReqCanceler(treq.cancelKey, func(error) {})
return pc, nil
}

cancelc := make(chan error, 1)
t.setReqCanceler(req, func(err error) { cancelc <- err })
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })

// Queue for permission to dial.
t.queueForDial(w)
Expand Down Expand Up @@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() {
}

if !hasBody || bodyWritable {
pc.t.setReqCanceler(rc.req, nil)
pc.t.setReqCanceler(rc.cancelKey, nil)

// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
Expand Down Expand Up @@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
Expand All @@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
case <-pc.closech:
alive = false
}
Expand Down Expand Up @@ -2408,9 +2417,10 @@ type responseAndError struct {
}

type requestAndChan struct {
_ incomparable
req *Request
ch chan responseAndError // unbuffered; always send in select on callerGone
_ incomparable
req *Request
cancelKey cancelKey
ch chan responseAndError // unbuffered; always send in select on callerGone

// whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport
Expand Down Expand Up @@ -2472,7 +2482,7 @@ var (

func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip()
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled
}
Expand Down Expand Up @@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err

defer func() {
if err != nil {
pc.t.setReqCanceler(req.Request, nil)
pc.t.setReqCanceler(req.cancelKey, nil)
}
}()

Expand All @@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
resc := make(chan responseAndError)
pc.reqch <- requestAndChan{
req: req.Request,
cancelKey: req.cancelKey,
ch: resc,
addedGzip: requestedGzip,
continueCh: continueCh,
Expand Down Expand Up @@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
}
return re.res, nil
case <-cancelChan:
pc.t.CancelRequest(req.Request)
pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil
case <-ctxDoneChan:
pc.t.cancelRequest(req.Request, req.Context().Err())
pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
Expand Down
44 changes: 44 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) {
}
}

func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
setParallel(t)
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
}))
defer ts.Close()
defer close(unblockc)

c := ts.Client()
tr := c.Transport.(*Transport)

donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
go func() {
defer close(donec)
c.Do(req)
}()
start := time.Now()
timeout := 10 * time.Second
for time.Since(start) < timeout {
time.Sleep(100 * time.Millisecond)
tr.CancelRequest(req)
select {
case <-donec:
return
default:
}
}
t.Errorf("Do of canceled request has not returned after %v", timeout)
}

func TestTransportCancelRequestInDo(t *testing.T) {
testTransportCancelRequestInDo(t, nil)
}

func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
}

func TestTransportCancelRequestInDial(t *testing.T) {
defer afterTest(t)
if testing.Short() {
Expand Down

0 comments on commit f235275

Please sign in to comment.