Skip to content

Commit

Permalink
http2: remove afterReqBodyWriteError wrapper
Browse files Browse the repository at this point in the history
There was a case where we forgot to undo this wrapper. Instead of fixing
that case, I moved the implementation of ClientConn.RoundTrip into an
unexported method that returns the same info as a bool.

Fixes golang/go#22136

Change-Id: I7e5fc467f9c26fb74b9b83f2b3b7f8882645e34c
Reviewed-on: https://go-review.googlesource.com/75252
Reviewed-by: Brad Fitzpatrick <[email protected]>
Run-TryBot: Brad Fitzpatrick <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
  • Loading branch information
tombergan committed Nov 1, 2017
1 parent c73622c commit ea0da6f
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ func (cs *clientStream) checkResetOrDone() error {
}
}

func (cs *clientStream) getStartedWrite() bool {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
return cs.startedWrite
}

func (cs *clientStream) abortRequestBodyWrite(err error) {
if err == nil {
panic("nil error")
Expand Down Expand Up @@ -349,14 +356,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
return nil, err
}
traceGotConn(req, cc)
res, err := cc.RoundTrip(req)
res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
if err != nil && retry <= 6 {
afterBodyWrite := false
if e, ok := err.(afterReqBodyWriteError); ok {
err = e
afterBodyWrite = true
}
if req, err = shouldRetryRequest(req, err, afterBodyWrite); err == nil {
if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
// After the first retry, do exponential backoff with 10% jitter.
if retry == 0 {
continue
Expand Down Expand Up @@ -394,16 +396,6 @@ var (
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
)

// afterReqBodyWriteError is a wrapper around errors returned by ClientConn.RoundTrip.
// It is used to signal that err happened after part of Request.Body was sent to the server.
type afterReqBodyWriteError struct {
err error
}

func (e afterReqBodyWriteError) Error() string {
return e.err.Error() + "; some request body already written"
}

// shouldRetryRequest is called by RoundTrip when a request fails to get
// response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a
Expand Down Expand Up @@ -752,23 +744,28 @@ func actualContentLength(req *http.Request) int64 {
}

func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
resp, _, err := cc.roundTrip(req)
return resp, err
}

func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
if err := checkConnHeaders(req); err != nil {
return nil, err
return nil, false, err
}
if cc.idleTimer != nil {
cc.idleTimer.Stop()
}

trailers, err := commaSeparatedTrailers(req)
if err != nil {
return nil, err
return nil, false, err
}
hasTrailers := trailers != ""

cc.mu.Lock()
if err := cc.awaitOpenSlotForRequest(req); err != nil {
cc.mu.Unlock()
return nil, err
return nil, false, err
}

body := req.Body
Expand Down Expand Up @@ -802,7 +799,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
if err != nil {
cc.mu.Unlock()
return nil, err
return nil, false, err
}

cs := cc.newStream()
Expand All @@ -828,7 +825,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
// Don't bother sending a RST_STREAM (our write already failed;
// no need to keep writing)
traceWroteRequest(cs.trace, werr)
return nil, werr
return nil, false, werr
}

var respHeaderTimer <-chan time.Time
Expand All @@ -847,7 +844,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
bodyWritten := false
ctx := reqContext(req)

handleReadLoopResponse := func(re resAndError) (*http.Response, error) {
handleReadLoopResponse := func(re resAndError) (*http.Response, bool, error) {
res := re.res
if re.err != nil || res.StatusCode > 299 {
// On error or status code 3xx, 4xx, 5xx, etc abort any
Expand All @@ -863,18 +860,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWrite)
}
if re.err != nil {
cc.mu.Lock()
afterBodyWrite := cs.startedWrite
cc.mu.Unlock()
cc.forgetStreamID(cs.ID)
if afterBodyWrite {
return nil, afterReqBodyWriteError{re.err}
}
return nil, re.err
return nil, cs.getStartedWrite(), re.err
}
res.Request = req
res.TLS = cc.tlsState
return res, nil
return res, false, nil
}

for {
Expand All @@ -889,7 +880,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
cc.forgetStreamID(cs.ID)
return nil, errTimeout
return nil, cs.getStartedWrite(), errTimeout
case <-ctx.Done():
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
Expand All @@ -898,7 +889,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
cc.forgetStreamID(cs.ID)
return nil, ctx.Err()
return nil, cs.getStartedWrite(), ctx.Err()
case <-req.Cancel:
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
Expand All @@ -907,12 +898,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
cc.forgetStreamID(cs.ID)
return nil, errRequestCanceled
return nil, cs.getStartedWrite(), errRequestCanceled
case <-cs.peerReset:
// processResetStream already removed the
// stream from the streams map; no need for
// forgetStreamID.
return nil, cs.resetErr
return nil, cs.getStartedWrite(), cs.resetErr
case err := <-bodyWriter.resc:
// Prefer the read loop's response, if available. Issue 16102.
select {
Expand All @@ -921,7 +912,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
default:
}
if err != nil {
return nil, err
return nil, cs.getStartedWrite(), err
}
bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 {
Expand Down

0 comments on commit ea0da6f

Please sign in to comment.