diff --git a/services/horizon/CHANGELOG.md b/services/horizon/CHANGELOG.md index 7426381add..0aa07fa9fc 100644 --- a/services/horizon/CHANGELOG.md +++ b/services/horizon/CHANGELOG.md @@ -7,6 +7,8 @@ file. This project adheres to [Semantic Versioning](http://semver.org/). ### Changes * Return inner and outer result codes for fee bump transactions ([4081](https://github.com/stellar/go/pull/4081)) +* Generate HTTP status code 499 for client disconnects; this should propagate into the `horizon_http_requests_duration_seconds_count` + metric key with a status=499 label. ([4098](https://github.com/stellar/go/pull/4098)) ## v2.11.0 diff --git a/services/horizon/internal/actions/submit_transaction.go b/services/horizon/internal/actions/submit_transaction.go index 97b3dd3580..703412c801 100644 --- a/services/horizon/internal/actions/submit_transaction.go +++ b/services/horizon/internal/actions/submit_transaction.go @@ -1,6 +1,7 @@ package actions import ( + "context" "encoding/hex" "mime" "net/http" @@ -16,8 +17,16 @@ import ( "github.com/stellar/go/xdr" ) +type NetworkSubmitter interface { + Submit( + ctx context.Context, + rawTx string, + envelope xdr.TransactionEnvelope, + hash string) <-chan txsub.Result +} + type SubmitTransactionHandler struct { - Submitter *txsub.System + Submitter NetworkSubmitter NetworkPassphrase string CoreStateGetter } @@ -78,7 +87,7 @@ func (handler SubmitTransactionHandler) response(r *http.Request, info envelopeI } if result.Err == txsub.ErrCanceled { - return nil, &hProblem.Timeout + return nil, &hProblem.ClientDisconnected } switch err := result.Err.(type) { @@ -153,6 +162,9 @@ func (handler SubmitTransactionHandler) GetResource(w HeaderWriter, r *http.Requ case result := <-submission: return handler.response(r, info, result) case <-r.Context().Done(): - return nil, &hProblem.Timeout + if r.Context().Err() == context.Canceled { + return nil, hProblem.ClientDisconnected + } + return nil, hProblem.Timeout } } diff --git 
a/services/horizon/internal/actions/submit_transaction_test.go b/services/horizon/internal/actions/submit_transaction_test.go index 129f9de6f5..72ccb5a297 100644 --- a/services/horizon/internal/actions/submit_transaction_test.go +++ b/services/horizon/internal/actions/submit_transaction_test.go @@ -1,15 +1,20 @@ package actions import ( + "context" "net/http" "net/http/httptest" "net/url" "strings" "testing" + "time" "github.com/stellar/go/network" "github.com/stellar/go/services/horizon/internal/corestate" + hProblem "github.com/stellar/go/services/horizon/internal/render/problem" + "github.com/stellar/go/services/horizon/internal/txsub" "github.com/stellar/go/support/render/problem" + "github.com/stellar/go/xdr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -35,6 +40,19 @@ func (m *coreStateGetterMock) GetCoreState() corestate.State { return a.Get(0).(corestate.State) } +type networkSubmitterMock struct { + mock.Mock +} + +func (m *networkSubmitterMock) Submit( + ctx context.Context, + rawTx string, + envelope xdr.TransactionEnvelope, + hash string) <-chan txsub.Result { + a := m.Called() + return a.Get(0).(chan txsub.Result) +} + func TestStellarCoreNotSynced(t *testing.T) { mock := &coreStateGetterMock{} mock.On("GetCoreState").Return(corestate.State{ @@ -64,3 +82,78 @@ func TestStellarCoreNotSynced(t *testing.T) { assert.Equal(t, "stale_history", err.(problem.P).Type) assert.Equal(t, "Historical DB Is Too Stale", err.(problem.P).Title) } + +func TestTimeoutSubmission(t *testing.T) { + mockSubmitChannel := make(chan txsub.Result) + + mock := &coreStateGetterMock{} + mock.On("GetCoreState").Return(corestate.State{ + Synced: true, + }) + + mockSubmitter := &networkSubmitterMock{} + mockSubmitter.On("Submit").Return(mockSubmitChannel) + + handler := SubmitTransactionHandler{ + Submitter: mockSubmitter, + NetworkPassphrase: network.PublicNetworkPassphrase, + CoreStateGetter: mock, + } + + form := 
url.Values{} + form.Set("tx", "AAAAAAGUcmKO5465JxTSLQOQljwk2SfqAJmZSG6JH6wtqpwhAAABLAAAAAAAAAABAAAAAAAAAAEAAAALaGVsbG8gd29ybGQAAAAAAwAAAAAAAAAAAAAAABbxCy3mLg3hiTqX4VUEEp60pFOrJNxYM1JtxXTwXhY2AAAAAAvrwgAAAAAAAAAAAQAAAAAW8Qst5i4N4Yk6l+FVBBKetKRTqyTcWDNSbcV08F4WNgAAAAAN4Lazj4x61AAAAAAAAAAFAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABLaqcIQAAAEBKwqWy3TaOxoGnfm9eUjfTRBvPf34dvDA0Nf+B8z4zBob90UXtuCqmQqwMCyH+okOI3c05br3khkH0yP4kCwcE") + + request, err := http.NewRequest( + "POST", + "https://horizon.stellar.org/transactions", + strings.NewReader(form.Encode()), + ) + + require.NoError(t, err) + request.Header.Add("Content-Type", "application/x-www-form-urlencoded") + ctx, cancel := context.WithTimeout(request.Context(), time.Duration(0)) + defer cancel() + request = request.WithContext(ctx) + + w := httptest.NewRecorder() + _, err = handler.GetResource(w, request) + assert.Error(t, err) + assert.Equal(t, hProblem.Timeout, err) +} + +func TestClientDisconnectSubmission(t *testing.T) { + mockSubmitChannel := make(chan txsub.Result) + + mock := &coreStateGetterMock{} + mock.On("GetCoreState").Return(corestate.State{ + Synced: true, + }) + + mockSubmitter := &networkSubmitterMock{} + mockSubmitter.On("Submit").Return(mockSubmitChannel) + + handler := SubmitTransactionHandler{ + Submitter: mockSubmitter, + NetworkPassphrase: network.PublicNetworkPassphrase, + CoreStateGetter: mock, + } + + form := url.Values{} + form.Set("tx", "AAAAAAGUcmKO5465JxTSLQOQljwk2SfqAJmZSG6JH6wtqpwhAAABLAAAAAAAAAABAAAAAAAAAAEAAAALaGVsbG8gd29ybGQAAAAAAwAAAAAAAAAAAAAAABbxCy3mLg3hiTqX4VUEEp60pFOrJNxYM1JtxXTwXhY2AAAAAAvrwgAAAAAAAAAAAQAAAAAW8Qst5i4N4Yk6l+FVBBKetKRTqyTcWDNSbcV08F4WNgAAAAAN4Lazj4x61AAAAAAAAAAFAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABLaqcIQAAAEBKwqWy3TaOxoGnfm9eUjfTRBvPf34dvDA0Nf+B8z4zBob90UXtuCqmQqwMCyH+okOI3c05br3khkH0yP4kCwcE") + + request, err := http.NewRequest( + "POST", + "https://horizon.stellar.org/transactions", + 
strings.NewReader(form.Encode()), + ) + + require.NoError(t, err) + request.Header.Add("Content-Type", "application/x-www-form-urlencoded") + ctx, cancel := context.WithCancel(request.Context()) + cancel() + request = request.WithContext(ctx) + + w := httptest.NewRecorder() + _, err = handler.GetResource(w, request) + assert.Equal(t, hProblem.ClientDisconnected, err) +} diff --git a/services/horizon/internal/httpx/server.go b/services/horizon/internal/httpx/server.go index a5a51da5eb..c3f6983c2c 100644 --- a/services/horizon/internal/httpx/server.go +++ b/services/horizon/internal/httpx/server.go @@ -55,8 +55,9 @@ func init() { problem.RegisterError(db2.ErrInvalidOrder, problem.BadRequest) problem.RegisterError(sse.ErrRateLimited, hProblem.RateLimitExceeded) problem.RegisterError(context.DeadlineExceeded, hProblem.Timeout) - problem.RegisterError(context.Canceled, hProblem.ServiceUnavailable) - problem.RegisterError(db.ErrCancelled, hProblem.ServiceUnavailable) + problem.RegisterError(context.Canceled, hProblem.ClientDisconnected) + problem.RegisterError(db.ErrCancelled, hProblem.ClientDisconnected) + problem.RegisterError(db.ErrTimeout, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrConflictWithRecovery, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrBadConnection, hProblem.ServiceUnavailable) } diff --git a/services/horizon/internal/middleware_test.go b/services/horizon/internal/middleware_test.go index b411b23e3d..c6b617c2b3 100644 --- a/services/horizon/internal/middleware_test.go +++ b/services/horizon/internal/middleware_test.go @@ -306,6 +306,36 @@ func TestStateMiddleware(t *testing.T) { } } +func TestClientDisconnect(t *testing.T) { + tt := test.Start(t) + defer tt.Finish() + test.ResetHorizonDB(t, tt.HorizonDB) + + request, err := http.NewRequest("GET", "http://localhost/", nil) + tt.Assert.NoError(err) + + endpoint := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + stateMiddleware := 
&httpx.StateMiddleware{ + HorizonSession: tt.HorizonSession(), + NoStateVerification: true, + } + handler := chi.NewRouter() + handler.With(stateMiddleware.Wrap).MethodFunc("GET", "/", endpoint) + w := httptest.NewRecorder() + + ctx, cancel := context.WithCancel(request.Context()) + defer cancel() + request = request.WithContext(ctx) + // cancel invocation simulates client disconnect in the context + cancel() + + handler.ServeHTTP(w, request) + tt.Assert.Equal(499, w.Code) +} + func TestCheckHistoryStaleMiddleware(t *testing.T) { tt := test.Start(t) defer tt.Finish() diff --git a/services/horizon/internal/render/problem/problem.go b/services/horizon/internal/render/problem/problem.go index 046c3b6604..0d206f255c 100644 --- a/services/horizon/internal/render/problem/problem.go +++ b/services/horizon/internal/render/problem/problem.go @@ -8,6 +8,17 @@ import ( // Well-known and reused problems below: var ( + + // ClientDisconnected is represented by the non-standard HTTP status code 499, which was introduced by + // nginx (https://www.nginx.com/resources/wiki/extending/api/http/) as a way to capture this state. Use it as a shortcut + // in your actions. + ClientDisconnected = problem.P{ + Type: "client_disconnected", + Title: "Client Disconnected", + Status: 499, + Detail: "The client has closed the connection.", + } + // ServiceUnavailable is a well-known problem type. Use it as a shortcut // in your actions. 
ServiceUnavailable = problem.P{ diff --git a/services/horizon/internal/render/problem/problem_test.go b/services/horizon/internal/render/problem/problem_test.go index d193079f41..707c75b658 100644 --- a/services/horizon/internal/render/problem/problem_test.go +++ b/services/horizon/internal/render/problem/problem_test.go @@ -25,6 +25,7 @@ func TestCommonProblems(t *testing.T) { }{ {"NotFound", problem.NotFound, 404}, {"RateLimitExceeded", RateLimitExceeded, 429}, + {"ClientDisconneted", ClientDisconnected, 499}, } for _, tc := range testCases { diff --git a/services/horizon/internal/txsub/system.go b/services/horizon/internal/txsub/system.go index 52c344dfcc..9b0557d095 100644 --- a/services/horizon/internal/txsub/system.go +++ b/services/horizon/internal/txsub/system.go @@ -177,9 +177,8 @@ func (sys *System) Submit( sys.finish(ctx, hash, response, Result{Err: sr.Err}) return } - - if sys.waitUntilAccountSequence(ctx, db, sourceAddress, uint64(envelope.SeqNum())) { - sys.finish(ctx, hash, response, Result{Err: ErrCanceled}) + if err = sys.waitUntilAccountSequence(ctx, db, sourceAddress, uint64(envelope.SeqNum())); err != nil { + sys.finish(ctx, hash, response, Result{Err: err}) return } @@ -194,7 +193,7 @@ func (sys *System) Submit( } case <-ctx.Done(): - sys.finish(ctx, hash, response, Result{Err: ErrCanceled}) + sys.finish(ctx, hash, response, Result{Err: sys.deriveTxSubError(ctx)}) } return @@ -202,7 +201,7 @@ func (sys *System) Submit( // waitUntilAccountSequence blocks until either the context times out or the sequence number of the // given source account is greater than or equal to `seq` -func (sys *System) waitUntilAccountSequence(ctx context.Context, db HorizonDB, sourceAddress string, seq uint64) bool { +func (sys *System) waitUntilAccountSequence(ctx context.Context, db HorizonDB, sourceAddress string, seq uint64) error { timer := time.NewTimer(sys.accountSeqPollInterval) defer timer.Stop() @@ -222,19 +221,26 @@ func (sys *System) 
waitUntilAccountSequence(ctx context.Context, db HorizonDB, s Warn("missing sequence number for account") } if num >= seq { - return false + return nil } } select { case <-ctx.Done(): - return true + return sys.deriveTxSubError(ctx) case <-timer.C: timer.Reset(sys.accountSeqPollInterval) } } } +func (sys *System) deriveTxSubError(ctx context.Context) error { + if ctx.Err() == context.Canceled { + return ErrCanceled + } + return ErrTimeout +} + // Submit submits the provided base64 encoded transaction envelope to the // network using this submission system. func (sys *System) submitOnce(ctx context.Context, env string) SubmissionResult { diff --git a/services/horizon/internal/txsub/system_test.go b/services/horizon/internal/txsub/system_test.go index b6026b5adb..6f01134f39 100644 --- a/services/horizon/internal/txsub/system_test.go +++ b/services/horizon/internal/txsub/system_test.go @@ -148,6 +148,69 @@ func (suite *SystemTestSuite) TestSubmit_Basic() { assert.False(suite.T(), suite.submitter.WasSubmittedTo) } +func (suite *SystemTestSuite) TestTimeoutDuringSequnceLoop() { + var cancel context.CancelFunc + suite.ctx, cancel = context.WithTimeout(suite.ctx, time.Duration(0)) + defer cancel() + + suite.submitter.R = suite.badSeq + suite.db.On("BeginTx", &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }).Return(nil).Once() + suite.db.On("Rollback").Return(nil).Once() + suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("NoRows", sql.ErrNoRows).Return(true).Once() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). 
+ Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil) + + r := <-suite.system.Submit( + suite.ctx, + suite.successTx.Transaction.TxEnvelope, + suite.successXDR, + suite.successTx.Transaction.TransactionHash, + ) + + assert.NotNil(suite.T(), r.Err) + assert.Equal(suite.T(), ErrTimeout, r.Err) +} + +func (suite *SystemTestSuite) TestClientDisconnectedDuringSequnceLoop() { + var cancel context.CancelFunc + suite.ctx, cancel = context.WithCancel(suite.ctx) + + suite.submitter.R = suite.badSeq + suite.db.On("BeginTx", &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }).Return(nil).Once() + suite.db.On("Rollback").Return(nil).Once() + suite.db.On("TransactionByHash", suite.ctx, mock.Anything, suite.successTx.Transaction.TransactionHash). + Return(sql.ErrNoRows).Once() + suite.db.On("NoRows", sql.ErrNoRows).Return(true).Once() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). + Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil). + Run(func(args mock.Arguments) { + // simulate client disconnecting while looping on sequnce number check + cancel() + suite.ctx.Deadline() + }). + Once() + suite.db.On("GetSequenceNumbers", suite.ctx, []string{suite.unmuxedSource.Address()}). 
+ Return(map[string]uint64{suite.unmuxedSource.Address(): 0}, nil) + + r := <-suite.system.Submit( + suite.ctx, + suite.successTx.Transaction.TxEnvelope, + suite.successXDR, + suite.successTx.Transaction.TransactionHash, + ) + + assert.NotNil(suite.T(), r.Err) + assert.Equal(suite.T(), ErrCanceled, r.Err) +} + func getMetricValue(metric prometheus.Metric) *dto.Metric { value := &dto.Metric{} err := metric.Write(value) diff --git a/support/db/main.go b/support/db/main.go index 2d0eaec21e..ed6e65285d 100644 --- a/support/db/main.go +++ b/support/db/main.go @@ -31,6 +31,9 @@ const ( ) var ( + // ErrTimeout is an error returned by Session methods when request has + // taken longer than context's deadline max duration + ErrTimeout = errors.New("canceling statement due to lack of response within timeout period") // ErrCancelled is an error returned by Session methods when request has // been cancelled (ex. context cancelled). ErrCancelled = errors.New("canceling statement due to user request") diff --git a/support/db/session.go b/support/db/session.go index 43e7121be0..4bc0218f90 100644 --- a/support/db/session.go +++ b/support/db/session.go @@ -23,7 +23,7 @@ func (s *Session) Begin() error { tx, err := s.DB.BeginTxx(context.Background(), nil) if err != nil { - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil { return knownErr } @@ -44,7 +44,7 @@ func (s *Session) BeginTx(opts *sql.TxOptions) error { tx, err := s.DB.BeginTxx(context.Background(), opts) if err != nil { - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil { return knownErr } @@ -142,7 +142,7 @@ func (s *Session) GetRaw(ctx context.Context, dest interface{}, query string, ar return nil } - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, ctx); knownErr != 
nil { return knownErr } @@ -211,7 +211,7 @@ func (s *Session) ExecRaw(ctx context.Context, query string, args ...interface{} return result, nil } - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, ctx); knownErr != nil { return nil, knownErr } @@ -230,10 +230,15 @@ func (s *Session) NoRows(err error) bool { // replaceWithKnownError tries to replace Postgres error with package error. // Returns a new error if the err is known. -func (s *Session) replaceWithKnownError(err error) error { +func (s *Session) replaceWithKnownError(err error, ctx context.Context) error { switch { - case strings.Contains(err.Error(), "pq: canceling statement due to user request"): + case ctx.Err() == context.Canceled: return ErrCancelled + case ctx.Err() == context.DeadlineExceeded: + // if libpq waits too long to obtain conn from pool, can get ctx timeout before server trip + return ErrTimeout + case strings.Contains(err.Error(), "pq: canceling statement due to user request"): + return ErrTimeout case strings.Contains(err.Error(), "pq: canceling statement due to conflict with recovery"): return ErrConflictWithRecovery case strings.Contains(err.Error(), "driver: bad connection"): @@ -267,7 +272,7 @@ func (s *Session) QueryRaw(ctx context.Context, query string, args ...interface{ return result, nil } - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, ctx); knownErr != nil { return nil, knownErr } @@ -341,7 +346,7 @@ func (s *Session) SelectRaw( return nil } - if knownErr := s.replaceWithKnownError(err); knownErr != nil { + if knownErr := s.replaceWithKnownError(err, ctx); knownErr != nil { return knownErr } diff --git a/support/db/session_test.go b/support/db/session_test.go index 1cbccb9301..742167bee0 100644 --- a/support/db/session_test.go +++ b/support/db/session_test.go @@ -3,12 +3,50 @@ package db import ( "context" "testing" + "time" 
"github.com/stellar/go/support/db/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestServerTimeout(t *testing.T) { + db := dbtest.Postgres(t).Load(testSchema) + defer db.Close() + + var cancel context.CancelFunc + ctx := context.Background() + ctx, cancel = context.WithTimeout(ctx, time.Duration(1)) + assert := assert.New(t) + + sess := &Session{DB: db.Open()} + defer sess.DB.Close() + defer cancel() + + var count int + err := sess.GetRaw(ctx, &count, "SELECT pg_sleep(2), COUNT(*) FROM people") + assert.ErrorIs(err, ErrTimeout, "long running db server operation past context timeout, should return timeout") +} + +func TestUserCancel(t *testing.T) { + db := dbtest.Postgres(t).Load(testSchema) + defer db.Close() + + var cancel context.CancelFunc + ctx := context.Background() + ctx, cancel = context.WithCancel(ctx) + assert := assert.New(t) + + sess := &Session{DB: db.Open()} + defer sess.DB.Close() + defer cancel() + + var count int + cancel() + err := sess.GetRaw(ctx, &count, "SELECT pg_sleep(2), COUNT(*) FROM people") + assert.ErrorIs(err, ErrCancelled, "any ongoing db server operation should return error immediately after user cancel") +} + func TestSession(t *testing.T) { db := dbtest.Postgres(t).Load(testSchema) defer db.Close()