diff --git a/services/horizon/internal/app.go b/services/horizon/internal/app.go index b1cd7a1c85..27ca8db6ef 100644 --- a/services/horizon/internal/app.go +++ b/services/horizon/internal/app.go @@ -494,7 +494,7 @@ func (a *App) init() error { a.UpdateStellarCoreInfo(a.ctx) // horizon-db and core-db - mustInitHorizonDB(a) + dbServerSideTimeout := mustInitHorizonDB(a) if a.config.Ingest { // ingester @@ -532,6 +532,7 @@ func (a *App) init() error { SSEUpdateFrequency: a.config.SSEUpdateFrequency, StaleThreshold: a.config.StaleThreshold, ConnectionTimeout: a.config.ConnectionTimeout, + DBServerSideTimeout: dbServerSideTimeout, MaxHTTPRequestSize: a.config.MaxHTTPRequestSize, NetworkPassphrase: a.config.NetworkPassphrase, MaxPathLength: a.config.MaxPathLength, diff --git a/services/horizon/internal/httpx/middleware.go b/services/horizon/internal/httpx/middleware.go index cdcd7f4e3c..9a75db9a7c 100644 --- a/services/horizon/internal/httpx/middleware.go +++ b/services/horizon/internal/httpx/middleware.go @@ -188,7 +188,7 @@ func recoverMiddleware(h http.Handler) http.Handler { // NewHistoryMiddleware adds session to the request context and ensures Horizon // is not in a stale state, which is when the difference between latest core // ledger and latest history ledger is higher than the given threshold -func NewHistoryMiddleware(ledgerState *ledger.State, staleThreshold int32, session db.SessionInterface) func(http.Handler) http.Handler { +func NewHistoryMiddleware(ledgerState *ledger.State, staleThreshold int32, session db.SessionInterface, contextDBTimeout time.Duration) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -196,6 +196,7 @@ func NewHistoryMiddleware(ledgerState *ledger.State, staleThreshold int32, sessi if routePattern := supportHttp.GetChiRoutePattern(r); routePattern != "" { ctx = context.WithValue(ctx, &db.RouteContextKey, routePattern) } + ctx = setContextDBTimeout(contextDBTimeout, ctx) if staleThreshold > 0 { ls := ledgerState.CurrentStatus() isStale := (ls.CoreLatest - ls.HistoryLatest) > int32(staleThreshold) @@ -229,6 +230,7 @@ func NewHistoryMiddleware(ledgerState *ledger.State, staleThreshold int32, sessi // returning invalid data to the user) type StateMiddleware struct { HorizonSession db.SessionInterface + ContextDBTimeout time.Duration NoStateVerification bool } @@ -267,6 +269,7 @@ func (m *StateMiddleware) WrapFunc(h http.HandlerFunc) http.HandlerFunc { if routePattern := supportHttp.GetChiRoutePattern(r); routePattern != "" { ctx = context.WithValue(ctx, &db.RouteContextKey, routePattern) } + ctx = setContextDBTimeout(m.ContextDBTimeout, ctx) session := m.HorizonSession.Clone() q := &history.Q{session} sseRequest := render.Negotiate(r) == render.MimeEventStream @@ -335,6 +338,13 @@ func (m *StateMiddleware) WrapFunc(h http.HandlerFunc) http.HandlerFunc { } } +func setContextDBTimeout(timeout time.Duration, ctx context.Context) context.Context { + if timeout > 0 { + ctx = context.WithValue(ctx, &db.DeadlineCtxKey, time.Now().Add(timeout)) + } + return ctx +} + // WrapFunc executes the middleware on a given HTTP handler function func (m *StateMiddleware) Wrap(h http.Handler) http.Handler { return m.WrapFunc(h.ServeHTTP) diff --git a/services/horizon/internal/httpx/router.go b/services/horizon/internal/httpx/router.go index 8fa57d0379..8e2c9de078 100644 --- a/services/horizon/internal/httpx/router.go +++ b/services/horizon/internal/httpx/router.go @@ -38,6 +38,7 @@ type RouterConfig struct { SSEUpdateFrequency time.Duration StaleThreshold uint ConnectionTimeout time.Duration + DBServerSideTimeout bool MaxHTTPRequestSize uint NetworkPassphrase string MaxPathLength uint @@ -137,8 +138,13 @@ func (r *Router) addMiddleware(config *RouterConfig, } func (r *Router) addRoutes(config *RouterConfig, rateLimiter *throttled.HTTPRateLimiter, ledgerState *ledger.State) { + var contextDBTimeout time.Duration + if config.DBServerSideTimeout { + contextDBTimeout = config.ConnectionTimeout * 15 + } stateMiddleware := StateMiddleware{ - HorizonSession: config.DBSession, + HorizonSession: config.DBSession, + ContextDBTimeout: contextDBTimeout, } r.Method(http.MethodGet, "/health", config.HealthCheck) @@ -156,7 +162,7 @@ func (r *Router) addRoutes(config *RouterConfig, rateLimiter *throttled.HTTPRate LedgerSourceFactory: historyLedgerSourceFactory{ledgerState: ledgerState, updateFrequency: config.SSEUpdateFrequency}, } - historyMiddleware := NewHistoryMiddleware(ledgerState, int32(config.StaleThreshold), config.DBSession) + historyMiddleware := NewHistoryMiddleware(ledgerState, int32(config.StaleThreshold), config.DBSession, contextDBTimeout) // State endpoints behind stateMiddleware r.Group(func(r chi.Router) { r.Route("/accounts", func(r chi.Router) { diff --git a/services/horizon/internal/init.go b/services/horizon/internal/init.go index 0c0fe2c2cd..d7fedaf32c 100644 --- a/services/horizon/internal/init.go +++ b/services/horizon/internal/init.go @@ -30,8 +30,9 @@ func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, max return db.RegisterMetrics(session, "horizon", subservice, registry) } -func mustInitHorizonDB(app *App) { +func mustInitHorizonDB(app *App) bool { log.Infof("Initializing database...") + var dbServerSideTimeout bool maxIdle := app.config.HorizonDBMaxIdleConnections maxOpen := app.config.HorizonDBMaxOpenConnections @@ -55,6 +56,7 @@ func mustInitHorizonDB(app *App) { db.StatementTimeout(app.config.ConnectionTimeout), db.IdleTransactionTimeout(app.config.ConnectionTimeout), ) + dbServerSideTimeout = true } app.historyQ = &history.Q{mustNewDBSession( db.HistorySubservice, @@ -70,6 +72,7 @@ func mustInitHorizonDB(app *App) { db.StatementTimeout(app.config.ConnectionTimeout), db.IdleTransactionTimeout(app.config.ConnectionTimeout), } + dbServerSideTimeout = true app.historyQ = &history.Q{mustNewDBSession( db.HistorySubservice, app.config.RoDatabaseURL, @@ -87,6 +90,8 @@ func mustInitHorizonDB(app *App) { app.prometheusRegistry, )} } + + return dbServerSideTimeout } func initIngester(app *App) { diff --git a/services/horizon/internal/middleware_test.go b/services/horizon/internal/middleware_test.go index 40a74ea69e..269269bf8e 100644 --- a/services/horizon/internal/middleware_test.go +++ b/services/horizon/internal/middleware_test.go @@ -402,7 +402,7 @@ func TestCheckHistoryStaleMiddleware(t *testing.T) { } ledgerState := &ledger.State{} ledgerState.SetStatus(state) - historyMiddleware := httpx.NewHistoryMiddleware(ledgerState, testCase.staleThreshold, tt.HorizonSession()) + historyMiddleware := httpx.NewHistoryMiddleware(ledgerState, testCase.staleThreshold, tt.HorizonSession(), 0) handler := chi.NewRouter() handler.With(historyMiddleware).MethodFunc("GET", "/", endpoint) w := httptest.NewRecorder() diff --git a/support/db/main.go b/support/db/main.go index dca23526ee..e3851cbd18 100644 --- a/support/db/main.go +++ b/support/db/main.go @@ -21,6 +21,7 @@ import ( "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" + "github.com/stellar/go/support/errors" // Enable postgres @@ -119,6 +120,7 @@ type Session struct { DB *sqlx.DB tx *sqlx.Tx + txCancel context.CancelFunc txOptions *sql.TxOptions errorHandlers []ErrorHandlerFunc } diff --git a/support/db/session.go b/support/db/session.go index 472fc40a37..90af80e93f 100644 --- a/support/db/session.go +++ b/support/db/session.go @@ -12,28 +12,65 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "github.com/lib/pq" + "github.com/stellar/go/support/db/sqlutils" "github.com/stellar/go/support/errors" "github.com/stellar/go/support/log" ) +var DeadlineCtxKey = CtxKey("deadline") + +func noop() {} + +// context() checks if there is a override on the context timeout which is configured using DeadlineCtxKey. +// If the override exists, we return a new context with the desired deadline. Otherwise, we return the +// original context. +// Note that the override will not be applied if requestCtx has already been terminated. +func (s *Session) context(requestCtx context.Context) (context.Context, context.CancelFunc, error) { + deadline, ok := requestCtx.Value(&DeadlineCtxKey).(time.Time) + if !ok { + return requestCtx, noop, nil + } + + // if requestCtx is already terminated don't proceed with the db statement + switch { + case requestCtx.Err() == context.Canceled: + return requestCtx, noop, ErrCancelled + case requestCtx.Err() == context.DeadlineExceeded: + return requestCtx, noop, ErrTimeout + case requestCtx.Err() != nil: + return requestCtx, noop, requestCtx.Err() + } + + ctx, cancel := context.WithDeadline(context.Background(), deadline) + return ctx, cancel, nil +} + // Begin binds this session to a new transaction. func (s *Session) Begin(ctx context.Context) error { if s.tx != nil { return errors.New("already in transaction") } + ctx, cancel, err := s.context(ctx) + if err != nil { + cancel() + return err + } tx, err := s.DB.BeginTxx(ctx, nil) if err != nil { if knownErr := s.handleError(err, ctx); knownErr != nil { + cancel() return knownErr } + cancel() return errors.Wrap(err, "beginx failed") } log.Debug("sql: begin") s.tx = tx s.txOptions = nil + s.txCancel = cancel return nil } @@ -43,19 +80,27 @@ func (s *Session) BeginTx(ctx context.Context, opts *sql.TxOptions) error { if s.tx != nil { return errors.New("already in transaction") } + ctx, cancel, err := s.context(ctx) + if err != nil { + cancel() + return err + } tx, err := s.DB.BeginTxx(ctx, opts) if err != nil { if knownErr := s.handleError(err, ctx); knownErr != nil { + cancel() return knownErr } + cancel() return errors.Wrap(err, "beginTx failed") } log.Debug("sql: begin") s.tx = tx s.txOptions = opts + s.txCancel = cancel return nil } @@ -93,6 +138,8 @@ func (s *Session) Commit() error { log.Debug("sql: commit") s.tx = nil s.txOptions = nil + s.txCancel() + s.txCancel = nil if knownErr := s.handleError(err, context.Background()); knownErr != nil { return knownErr @@ -135,7 +182,13 @@ func (s *Session) Get(ctx context.Context, dest interface{}, query sq.Sqlizer) e // GetRaw runs `query` with `args`, setting the first result found on // `dest`, if any. func (s *Session) GetRaw(ctx context.Context, dest interface{}, query string, args ...interface{}) error { - query, err := s.ReplacePlaceholders(query) + ctx, cancel, err := s.context(ctx) + defer cancel() + if err != nil { + return err + } + + query, err = s.ReplacePlaceholders(query) if err != nil { return errors.Wrap(err, "replace placeholders failed") } @@ -204,7 +257,13 @@ func (s *Session) ExecAll(ctx context.Context, script string) error { // ExecRaw runs `query` with `args` func (s *Session) ExecRaw(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - query, err := s.ReplacePlaceholders(query) + ctx, cancel, err := s.context(ctx) + defer cancel() + if err != nil { + return nil, err + } + + query, err = s.ReplacePlaceholders(query) if err != nil { return nil, errors.Wrap(err, "replace placeholders failed") } @@ -304,7 +363,13 @@ func (s *Session) Query(ctx context.Context, query sq.Sqlizer) (*sqlx.Rows, erro // QueryRaw runs `query` with `args` func (s *Session) QueryRaw(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { - query, err := s.ReplacePlaceholders(query) + ctx, cancel, err := s.context(ctx) + defer cancel() + if err != nil { + return nil, err + } + + query, err = s.ReplacePlaceholders(query) if err != nil { return nil, errors.Wrap(err, "replace placeholders failed") } @@ -350,6 +415,8 @@ func (s *Session) Rollback() error { log.Debug("sql: rollback") s.tx = nil s.txOptions = nil + s.txCancel() + s.txCancel = nil if knownErr := s.handleError(err, context.Background()); knownErr != nil { return knownErr @@ -381,8 +448,14 @@ func (s *Session) SelectRaw( query string, args ...interface{}, ) error { + ctx, cancel, err := s.context(ctx) + defer cancel() + if err != nil { + return err + } + s.clearSliceIfPossible(dest) - query, err := s.ReplacePlaceholders(query) + query, err = s.ReplacePlaceholders(query) if err != nil { return errors.Wrap(err, "replace placeholders failed") } diff --git a/support/db/session_test.go b/support/db/session_test.go index 1fd2a3902b..00718c9b2e 100644 --- a/support/db/session_test.go +++ b/support/db/session_test.go @@ -6,12 +6,12 @@ import ( "testing" "time" - //"github.com/lib/pq" "github.com/lib/pq" "github.com/prometheus/client_golang/prometheus" - "github.com/stellar/go/support/db/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stellar/go/support/db/dbtest" ) func TestContextTimeoutDuringSql(t *testing.T) { @@ -140,6 +140,38 @@ func TestStatementTimeout(t *testing.T) { assertDbErrorMetrics(reg, "n/a", "57014", "statement_timeout", assert) } +func TestDeadlineOverride(t *testing.T) { + db := dbtest.Postgres(t).Load(testSchema) + defer db.Close() + + sess := &Session{DB: db.Open()} + defer sess.DB.Close() + + resultCtx, _, err := sess.context(context.Background()) + assert.NoError(t, err) + _, ok := resultCtx.Deadline() + assert.False(t, ok) + + deadline := time.Now().Add(time.Hour) + requestCtx := context.WithValue(context.Background(), &DeadlineCtxKey, deadline) + resultCtx, _, err = sess.context(requestCtx) + assert.NoError(t, err) + d, ok := resultCtx.Deadline() + assert.True(t, ok) + assert.Equal(t, deadline, d) + + requestCtx, cancel := context.WithDeadline(requestCtx, time.Now().Add(time.Minute*30)) + resultCtx, _, err = sess.context(requestCtx) + assert.NoError(t, err) + d, ok = resultCtx.Deadline() + assert.True(t, ok) + assert.Equal(t, deadline, d) + + cancel() + _, _, err = sess.context(requestCtx) + assert.EqualError(t, err, "canceling statement due to user request") +} + func TestSession(t *testing.T) { db := dbtest.Postgres(t).Load(testSchema) defer db.Close()