diff --git a/src/net/http/transport.go b/src/net/http/transport.go index a41e732d983da4..d37b52b13d0600 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -100,7 +100,7 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[*Request]func(error) + reqCanceler map[cancelKey]func(error) altMu sync.Mutex // guards changing altProto only altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme @@ -273,6 +273,13 @@ type Transport struct { ForceAttemptHTTP2 bool } +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request +} + func (t *Transport) writeBufferSize() int { if t.WriteBufferSize > 0 { return t.WriteBufferSize @@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { // optional extra headers to write and stores any error to return // from roundTrip. type transportRequest struct { - *Request // original request, not to be mutated - extra Header // extra headers to write, or nil - trace *httptrace.ClientTrace // optional + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + trace *httptrace.ClientTrace // optional + cancelKey cancelKey mu sync.Mutex // guards err err error // first setError value for mapRoundTripError to consider @@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } origReq := req + cancelKey := cancelKey{origReq} req = setupRewindBody(req) if altRT := t.alternateRoundTripper(req); altRT != nil { @@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } // treq gets modified by roundTrip, so we need to recreate for each retry. - treq := &transportRequest{Request: req, trace: trace} + treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} cm, err := t.connectMethodForRequest(treq) if err != nil { req.closeBody() @@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { // to send it requests. pconn, err := t.getConn(treq, cm) if err != nil { - t.setReqCanceler(req, nil) + t.setReqCanceler(cancelKey, nil) req.closeBody() return nil, err } @@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(req, nil) // not cancelable with CancelRequest + t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) @@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { - t.cancelRequest(req, errRequestCanceled) + t.cancelRequest(cancelKey{req}, errRequestCanceled) } // Cancel an in-flight request, recording the error value. -func (t *Transport) cancelRequest(req *Request, err error) { +func (t *Transport) cancelRequest(key cancelKey, err error) { t.reqMu.Lock() - cancel := t.reqCanceler[req] - delete(t.reqCanceler, req) + cancel := t.reqCanceler[key] + delete(t.reqCanceler, key) t.reqMu.Unlock() if cancel != nil { cancel(err) @@ -1093,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { return removed } -func (t *Transport) setReqCanceler(r *Request, fn func(error)) { +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { t.reqMu.Lock() defer t.reqMu.Unlock() if t.reqCanceler == nil { - t.reqCanceler = make(map[*Request]func(error)) + t.reqCanceler = make(map[cancelKey]func(error)) } if fn != nil { - t.reqCanceler[r] = fn + t.reqCanceler[key] = fn } else { - delete(t.reqCanceler, r) + delete(t.reqCanceler, key) } } @@ -1110,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) { // for the request, we don't set the function and return false. // Since CancelRequest will clear the canceler, we can use the return value to detect if // the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { t.reqMu.Lock() defer t.reqMu.Unlock() - _, ok := t.reqCanceler[r] + _, ok := t.reqCanceler[key] if !ok { return false } if fn != nil { - t.reqCanceler[r] = fn + t.reqCanceler[key] = fn } else { - delete(t.reqCanceler, r) + delete(t.reqCanceler, key) } return true } @@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip - t.setReqCanceler(req, func(error) {}) + t.setReqCanceler(treq.cancelKey, func(error) {}) return pc, nil } cancelc := make(chan error, 1) - t.setReqCanceler(req, func(err error) { cancelc <- err }) + t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) // Queue for permission to dial. t.queueForDial(w) @@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() { } if !hasBody || bodyWritable { - pc.t.setReqCanceler(rc.req, nil) + pc.t.setReqCanceler(rc.cancelKey, nil) // Put the idle conn back into the pool before we send the response // so if they process it quickly and make another request, they'll @@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() { // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: - pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool + pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool alive = alive && bodyEOF && !pc.sawEOF && @@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() { pc.t.CancelRequest(rc.req) case <-rc.req.Context().Done(): alive = false - pc.t.cancelRequest(rc.req, rc.req.Context().Err()) + pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -2408,9 +2417,10 @@ type responseAndError struct { } type requestAndChan struct { - _ incomparable - req *Request - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + req *Request + cancelKey cancelKey + ch chan responseAndError // unbuffered; always send in select on callerGone // whether the Transport (as opposed to the user client code) // added the Accept-Encoding gzip header. If the Transport @@ -2472,7 +2482,7 @@ var ( func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() - if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { + if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err defer func() { if err != nil { - pc.t.setReqCanceler(req.Request, nil) + pc.t.setReqCanceler(req.cancelKey, nil) } }() @@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err resc := make(chan responseAndError) pc.reqch <- requestAndChan{ req: req.Request, + cancelKey: req.cancelKey, ch: resc, addedGzip: requestedGzip, continueCh: continueCh, @@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return re.res, nil case <-cancelChan: - pc.t.CancelRequest(req.Request) + pc.t.cancelRequest(req.cancelKey, errRequestCanceled) cancelChan = nil case <-ctxDoneChan: - pc.t.cancelRequest(req.Request, req.Context().Err()) + pc.t.cancelRequest(req.cancelKey, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 31a41f53511523..0a47687d9afbac 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) { } } +func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { + setParallel(t) + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + donec := make(chan bool) + req, _ := NewRequest("GET", ts.URL, body) + go func() { + defer close(donec) + c.Do(req) + }() + start := time.Now() + timeout := 10 * time.Second + for time.Since(start) < timeout { + time.Sleep(100 * time.Millisecond) + tr.CancelRequest(req) + select { + case <-donec: + return + default: + } + } + t.Errorf("Do of canceled request has not returned after %v", timeout) +} + +func TestTransportCancelRequestInDo(t *testing.T) { + testTransportCancelRequestInDo(t, nil) +} + +func TestTransportCancelRequestWithBodyInDo(t *testing.T) { + testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) +} + func TestTransportCancelRequestInDial(t *testing.T) { defer afterTest(t) if testing.Short() {