diff --git a/internal/component/postgresql/dbaccess.go b/internal/component/postgresql/interfaces/interfaces.go similarity index 56% rename from internal/component/postgresql/dbaccess.go rename to internal/component/postgresql/interfaces/interfaces.go index d9e73e0853..fd5d23fc79 100644 --- a/internal/component/postgresql/dbaccess.go +++ b/internal/component/postgresql/interfaces/interfaces.go @@ -11,33 +11,30 @@ See the License for the specific language governing permissions and limitations under the License. */ -package postgresql +package pginterfaces import ( "context" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - - "github.com/dapr/components-contrib/state" ) -// dbAccess is a private interface which enables unit testing of PostgreSQL. -type dbAccess interface { - Init(ctx context.Context, metadata state.Metadata) error - Set(ctx context.Context, req *state.SetRequest) error - Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) - BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) - Delete(ctx context.Context, req *state.DeleteRequest) error - ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error - Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) - Close() error // io.Closer -} - // Interface that contains methods for querying. // Applies to *pgx.Conn, *pgxpool.Pool, and pgx.Tx -type dbquerier interface { +type DBQuerier interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row } + +// Interface that applies to *pgxpool.Pool. +// We need this to be able to mock the connection in tests. 
+type PGXPoolConn interface { + DBQuerier + + Begin(context.Context) (pgx.Tx, error) + BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) + Ping(context.Context) error + Close() +} diff --git a/internal/component/postgresql/metadata.go b/internal/component/postgresql/metadata.go index a4dd06f076..891d6d65b5 100644 --- a/internal/component/postgresql/metadata.go +++ b/internal/component/postgresql/metadata.go @@ -33,7 +33,7 @@ const ( defaultTimeout = 20 // Default timeout for network requests, in seconds ) -type postgresMetadataStruct struct { +type pgMetadata struct { pgauth.PostgresAuthMetadata `mapstructure:",squash"` TableName string `mapstructure:"tableName"` // Could be in the format "schema.table" or just "table" @@ -42,7 +42,7 @@ type postgresMetadataStruct struct { CleanupInterval *time.Duration `mapstructure:"cleanupIntervalInSeconds"` } -func (m *postgresMetadataStruct) InitWithMetadata(meta state.Metadata, azureADEnabled bool) error { +func (m *pgMetadata) InitWithMetadata(meta state.Metadata, azureADEnabled bool) error { // Reset the object m.PostgresAuthMetadata.Reset() m.TableName = defaultTableName diff --git a/internal/component/postgresql/metadata_test.go b/internal/component/postgresql/metadata_test.go index bb4b612f18..b87be22d3a 100644 --- a/internal/component/postgresql/metadata_test.go +++ b/internal/component/postgresql/metadata_test.go @@ -25,7 +25,7 @@ import ( func TestMetadata(t *testing.T) { t.Run("missing connection string", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{} err := m.InitWithMetadata(state.Metadata{Base: metadata.Base{Properties: props}}, false) @@ -34,7 +34,7 @@ func TestMetadata(t *testing.T) { }) t.Run("has connection string", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", } @@ -44,7 +44,7 @@ func TestMetadata(t *testing.T) { }) t.Run("default table name", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", } @@ -55,7 +55,7 @@ func TestMetadata(t *testing.T) { }) t.Run("custom table name", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "tableName": "mytable", @@ -67,7 +67,7 @@ func TestMetadata(t *testing.T) { }) t.Run("default timeout", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", } @@ -78,7 +78,7 @@ func TestMetadata(t *testing.T) { }) t.Run("invalid timeout", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "timeoutInSeconds": "NaN", @@ -89,7 +89,7 @@ func TestMetadata(t *testing.T) { }) t.Run("positive timeout", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "timeoutInSeconds": "42", @@ -101,7 +101,7 @@ func TestMetadata(t *testing.T) { }) t.Run("zero timeout", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "timeoutInSeconds": "0", @@ -112,7 +112,7 @@ func TestMetadata(t *testing.T) { }) t.Run("default cleanupIntervalInSeconds", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", } @@ -124,7 +124,7 @@ func TestMetadata(t *testing.T) { }) 
t.Run("invalid cleanupIntervalInSeconds", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "cleanupIntervalInSeconds": "NaN", @@ -135,7 +135,7 @@ func TestMetadata(t *testing.T) { }) t.Run("positive cleanupIntervalInSeconds", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "cleanupIntervalInSeconds": "42", @@ -148,7 +148,7 @@ func TestMetadata(t *testing.T) { }) t.Run("zero cleanupIntervalInSeconds", func(t *testing.T) { - m := postgresMetadataStruct{} + m := pgMetadata{} props := map[string]string{ "connectionString": "foo", "cleanupIntervalInSeconds": "0", diff --git a/internal/component/postgresql/postgresdbaccess.go b/internal/component/postgresql/postgresdbaccess.go deleted file mode 100644 index 431279569c..0000000000 --- a/internal/component/postgresql/postgresdbaccess.go +++ /dev/null @@ -1,527 +0,0 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package postgresql - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - - internalsql "github.com/dapr/components-contrib/internal/component/sql" - "github.com/dapr/components-contrib/state" - "github.com/dapr/components-contrib/state/query" - stateutils "github.com/dapr/components-contrib/state/utils" - "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" -) - -// Interface that applies to *pgxpool.Pool. -// We need this to be able to mock the connection in tests. -type PGXPoolConn interface { - Begin(context.Context) (pgx.Tx, error) - BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - Query(context.Context, string, ...interface{}) (pgx.Rows, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row - Ping(context.Context) error - Close() -} - -// PostgresDBAccess implements dbaccess. -type PostgresDBAccess struct { - logger logger.Logger - metadata postgresMetadataStruct - db PGXPoolConn - - gc internalsql.GarbageCollector - - migrateFn func(context.Context, PGXPoolConn, MigrateOptions) error - setQueryFn func(*state.SetRequest, SetQueryOptions) string - etagColumn string - enableAzureAD bool -} - -// newPostgresDBAccess creates a new instance of postgresAccess. -func newPostgresDBAccess(logger logger.Logger, opts Options) *PostgresDBAccess { - logger.Debug("Instantiating new Postgres state store") - - return &PostgresDBAccess{ - logger: logger, - migrateFn: opts.MigrateFn, - setQueryFn: opts.SetQueryFn, - etagColumn: opts.ETagColumn, - enableAzureAD: opts.EnableAzureAD, - } -} - -// Init sets up Postgres connection and ensures that the state table exists. 
-func (p *PostgresDBAccess) Init(ctx context.Context, meta state.Metadata) error { - p.logger.Debug("Initializing Postgres state store") - - err := p.metadata.InitWithMetadata(meta, p.enableAzureAD) - if err != nil { - p.logger.Errorf("Failed to parse metadata: %v", err) - return err - } - - config, err := p.metadata.GetPgxPoolConfig() - if err != nil { - p.logger.Error(err) - return err - } - - connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) - p.db, err = pgxpool.NewWithConfig(connCtx, config) - connCancel() - if err != nil { - err = fmt.Errorf("failed to connect to the database: %w", err) - p.logger.Error(err) - return err - } - - pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.Timeout) - err = p.db.Ping(pingCtx) - pingCancel() - if err != nil { - err = fmt.Errorf("failed to ping the database: %w", err) - p.logger.Error(err) - return err - } - - err = p.migrateFn(ctx, p.db, MigrateOptions{ - Logger: p.logger, - StateTableName: p.metadata.TableName, - MetadataTableName: p.metadata.MetadataTableName, - }) - if err != nil { - return err - } - - if p.metadata.CleanupInterval != nil { - gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{ - Logger: p.logger, - UpdateLastCleanupQuery: fmt.Sprintf( - `INSERT INTO %[1]s (key, value) - VALUES ('last-cleanup', CURRENT_TIMESTAMP::text) - ON CONFLICT (key) - DO UPDATE SET value = CURRENT_TIMESTAMP::text - WHERE (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - %[1]s.value::timestamp with time zone) * 1000)::bigint > $1`, - p.metadata.MetadataTableName, - ), - DeleteExpiredValuesQuery: fmt.Sprintf( - `DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, - p.metadata.TableName, - ), - CleanupInterval: *p.metadata.CleanupInterval, - DBPgx: p.db, - }) - if err != nil { - return err - } - p.gc = gc - } - - return nil -} - -func (p *PostgresDBAccess) GetDB() *pgxpool.Pool { - // We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB - return p.db.(*pgxpool.Pool) -} - -// Set makes an insert or update to the database. 
-func (p *PostgresDBAccess) Set(ctx context.Context, req *state.SetRequest) error { - return p.doSet(ctx, p.db, req) -} - -func (p *PostgresDBAccess) doSet(parentCtx context.Context, db dbquerier, req *state.SetRequest) error { - err := state.CheckRequestOptions(req.Options) - if err != nil { - return err - } - - if req.Key == "" { - return errors.New("missing key in set operation") - } - - v := req.Value - byteArray, isBinary := req.Value.([]uint8) - if isBinary { - v = base64.StdEncoding.EncodeToString(byteArray) - } - - // Convert to json string - bt, _ := stateutils.Marshal(v, json.Marshal) - value := string(bt) - - // TTL - var ttlSeconds int - ttl, ttlerr := stateutils.ParseTTL(req.Metadata) - if ttlerr != nil { - return fmt.Errorf("error parsing TTL: %w", ttlerr) - } - if ttl != nil { - ttlSeconds = *ttl - } - - var ( - queryExpiredate string - params []any - ) - - if !req.HasETag() { - params = []any{req.Key, value, isBinary} - } else { - var etag64 uint64 - etag64, err = strconv.ParseUint(*req.ETag, 10, 32) - if err != nil { - return state.NewETagError(state.ETagInvalid, err) - } - params = []any{req.Key, value, isBinary, uint32(etag64)} - } - - if ttlSeconds > 0 { - queryExpiredate = "CURRENT_TIMESTAMP + interval '" + strconv.Itoa(ttlSeconds) + " seconds'" - } else { - queryExpiredate = "NULL" - } - - query := p.setQueryFn(req, SetQueryOptions{ - TableName: p.metadata.TableName, - ExpireDateValue: queryExpiredate, - }) - - result, err := db.Exec(parentCtx, query, params...) - if err != nil { - return err - } - if result.RowsAffected() != 1 { - if req.HasETag() { - return state.NewETagError(state.ETagMismatch, nil) - } - return errors.New("no item was updated") - } - - return nil -} - -// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. -func (p *PostgresDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) { - if req.Key == "" { - return nil, errors.New("missing key in get operation") - } - - query := `SELECT - key, value, isbinary, ` + p.etagColumn + ` AS etag, expiredate - FROM ` + p.metadata.TableName + ` - WHERE - key = $1 - AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - defer cancel() - row := p.db.QueryRow(ctx, query, req.Key) - _, value, etag, expireTime, err := readRow(row) - if err != nil { - // If no rows exist, return an empty response, otherwise return the error. 
- if errors.Is(err, pgx.ErrNoRows) { - return &state.GetResponse{}, nil - } - return nil, err - } - - resp := &state.GetResponse{ - Data: value, - ETag: etag, - } - - if expireTime != nil { - resp.Metadata = map[string]string{ - state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), - } - } - - return resp, nil -} - -func (p *PostgresDBAccess) BulkGet(parentCtx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) { - if len(req) == 0 { - return []state.BulkGetResponse{}, nil - } - - // Get all keys - keys := make([]string, len(req)) - for i, r := range req { - keys[i] = r.Key - } - - // Execute the query - query := `SELECT - key, value, isbinary, ` + p.etagColumn + ` AS etag, expiredate - FROM ` + p.metadata.TableName + ` - WHERE - key = ANY($1) - AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - defer cancel() - rows, err := p.db.Query(ctx, query, keys) - if err != nil { - // If no rows exist, return an empty response, otherwise return the error. - if errors.Is(err, pgx.ErrNoRows) { - return []state.BulkGetResponse{}, nil - } - return nil, err - } - - // Scan all rows - var n int - res := make([]state.BulkGetResponse, len(req)) - foundKeys := make(map[string]struct{}, len(req)) - for ; rows.Next(); n++ { - if n >= len(req) { - // Sanity check to prevent panics, which should never happen - return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) - } - - r := state.BulkGetResponse{} - var expireTime *time.Time - r.Key, r.Data, r.ETag, expireTime, err = readRow(rows) - if err != nil { - r.Error = err.Error() - } - if expireTime != nil { - r.Metadata = map[string]string{ - state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), - } - } - res[n] = r - foundKeys[r.Key] = struct{}{} - } - - // Populate missing keys with empty values - // This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method - if len(foundKeys) < len(req) { - var ok bool - for _, r := range req { - _, ok = foundKeys[r.Key] - if !ok { - if n >= len(req) { - // Sanity check to prevent panics, which should never happen - return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) - } - res[n] = state.BulkGetResponse{ - Key: r.Key, - } - n++ - } - } - } - - return res[:n], nil -} - -func readRow(row pgx.Row) (key string, value []byte, etagS *string, expireTime *time.Time, err error) { - var ( - isBinary bool - etag pgtype.Int8 - expT pgtype.Timestamp - ) - err = row.Scan(&key, &value, &isBinary, &etag, &expT) - if err != nil { - return key, nil, nil, nil, err - } - - if etag.Valid { - etagS = ptr.Of(strconv.FormatInt(etag.Int64, 10)) - } - - if expT.Valid { - expireTime = &expT.Time - } - - if isBinary { - var ( - s string - data []byte - ) - - err = json.Unmarshal(value, &s) - if err != nil { - return key, nil, nil, nil, fmt.Errorf("failed to unmarshal JSON data: %w", err) - } - - data, err = base64.StdEncoding.DecodeString(s) - if err != nil { - return key, nil, nil, nil, fmt.Errorf("failed to decode base64 data: %w", err) - } - - return key, data, etagS, expireTime, nil - } - - return key, value, etagS, expireTime, nil -} - -// Delete removes an item from the state store. 
-func (p *PostgresDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) (err error) { - return p.doDelete(ctx, p.db, req) -} - -func (p *PostgresDBAccess) doDelete(parentCtx context.Context, db dbquerier, req *state.DeleteRequest) (err error) { - if req.Key == "" { - return errors.New("missing key in delete operation") - } - - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - defer cancel() - var result pgconn.CommandTag - if !req.HasETag() { - result, err = db.Exec(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key) - } else { - // Convert req.ETag to uint32 for postgres XID compatibility - var etag64 uint64 - etag64, err = strconv.ParseUint(*req.ETag, 10, 32) - if err != nil { - return state.NewETagError(state.ETagInvalid, err) - } - - result, err = db.Exec(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 AND $2 = "+p.etagColumn, req.Key, uint32(etag64)) - } - if err != nil { - return err - } - - rows := result.RowsAffected() - if rows != 1 && req.ETag != nil && *req.ETag != "" { - return state.NewETagError(state.ETagMismatch, nil) - } - - return nil -} - -func (p *PostgresDBAccess) ExecuteMulti(parentCtx context.Context, request *state.TransactionalStateRequest) error { - tx, err := p.beginTx(parentCtx) - if err != nil { - return err - } - defer p.rollbackTx(parentCtx, tx, "ExecMulti") - - for _, o := range request.Operations { - switch x := o.(type) { - case state.SetRequest: - err = p.doSet(parentCtx, tx, &x) - if err != nil { - return err - } - - case state.DeleteRequest: - err = p.doDelete(parentCtx, tx, &x) - if err != nil { - return err - } - - default: - return fmt.Errorf("unsupported operation: %s", o.Operation()) - } - } - - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - err = tx.Commit(ctx) - cancel() - if err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) - } - - return nil -} - -// Query executes a query against store. -func (p *PostgresDBAccess) Query(parentCtx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { - q := &Query{ - query: "", - params: []any{}, - tableName: p.metadata.TableName, - etagColumn: p.etagColumn, - } - qbuilder := query.NewQueryBuilder(q) - if err := qbuilder.BuildQuery(&req.Query); err != nil { - return &state.QueryResponse{}, err - } - data, token, err := q.execute(parentCtx, p.logger, p.db) - if err != nil { - return &state.QueryResponse{}, err - } - - return &state.QueryResponse{ - Results: data, - Token: token, - }, nil -} - -func (p *PostgresDBAccess) CleanupExpired() error { - if p.gc != nil { - return p.gc.CleanupExpired() - } - return nil -} - -// Close implements io.Close. -func (p *PostgresDBAccess) Close() error { - if p.db != nil { - p.db.Close() - p.db = nil - } - - if p.gc != nil { - return p.gc.Close() - } - - return nil -} - -// GetCleanupInterval returns the cleanupInterval property. -// This is primarily used for tests. -func (p *PostgresDBAccess) GetCleanupInterval() *time.Duration { - return p.metadata.CleanupInterval -} - -// Internal function that begins a transaction. -func (p *PostgresDBAccess) beginTx(parentCtx context.Context) (pgx.Tx, error) { - ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - tx, err := p.db.Begin(ctx) - cancel() - if err != nil { - return nil, fmt.Errorf("failed to begin transaction: %w", err) - } - return tx, nil -} - -// Internal function that rolls back a transaction. -// Normally called as a deferred function in methods that use transactions. 
-// In case of errors, they are logged but not actioned upon. -func (p *PostgresDBAccess) rollbackTx(parentCtx context.Context, tx pgx.Tx, methodName string) { - rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, p.metadata.Timeout) - rollbackErr := tx.Rollback(rollbackCtx) - rollbackCancel() - if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) { - p.logger.Errorf("Failed to rollback transaction in %s: %v", methodName, rollbackErr) - } -} diff --git a/internal/component/postgresql/postgresdbaccess_test.go b/internal/component/postgresql/postgresdbaccess_test.go deleted file mode 100644 index dfccd3313e..0000000000 --- a/internal/component/postgresql/postgresdbaccess_test.go +++ /dev/null @@ -1,234 +0,0 @@ -/* -Copyright 2021 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package postgresql - -import ( - "context" - "encoding/json" - "testing" - "time" - - "github.com/google/uuid" - pgxmock "github.com/pashagolub/pgxmock/v2" - "github.com/stretchr/testify/assert" - - "github.com/dapr/components-contrib/state" - "github.com/dapr/kit/logger" -) - -type mocks struct { - db pgxmock.PgxPoolIface - pgDba *PostgresDBAccess -} - -type fakeItem struct { - Color string -} - -func TestMultiWithNoRequests(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - m.db.ExpectBegin() - m.db.ExpectCommit() - // There's also a rollback called after a commit, which is expected and will not have effect - m.db.ExpectRollback() - - var operations []state.TransactionalStateOperation - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.NoError(t, err) -} - -func TestValidSetRequest(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - setReq := createSetRequest() - operations := []state.TransactionalStateOperation{setReq} - val, _ := json.Marshal(setReq.Value) - - m.db.ExpectBegin() - m.db.ExpectExec("INSERT INTO"). - WithArgs(setReq.Key, string(val), false). 
- WillReturnResult(pgxmock.NewResult("INSERT", 1)) - m.db.ExpectCommit() - // There's also a rollback called after a commit, which is expected and will not have effect - m.db.ExpectRollback() - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.NoError(t, err) -} - -func TestInvalidMultiSetRequestNoKey(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - m.db.ExpectBegin() - m.db.ExpectRollback() - - operations := []state.TransactionalStateOperation{ - state.SetRequest{Value: "value1"}, // Set request without key is not valid for Upsert operation - } - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.Error(t, err) -} - -func TestValidMultiDeleteRequest(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - deleteReq := createDeleteRequest() - operations := []state.TransactionalStateOperation{deleteReq} - - m.db.ExpectBegin() - m.db.ExpectExec("DELETE FROM"). - WithArgs(deleteReq.Key). - WillReturnResult(pgxmock.NewResult("DELETE", 1)) - m.db.ExpectCommit() - // There's also a rollback called after a commit, which is expected and will not have effect - m.db.ExpectRollback() - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.NoError(t, err) -} - -func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - m.db.ExpectBegin() - m.db.ExpectRollback() - - operations := []state.TransactionalStateOperation{state.DeleteRequest{}} // Delete request without key is not valid for Delete operation - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.Error(t, err) -} - -func TestMultiOperationOrder(t *testing.T) { - // Arrange - m, _ := mockDatabase(t) - defer m.db.Close() - - operations := []state.TransactionalStateOperation{ - state.SetRequest{Key: "key1", Value: "value1"}, - state.DeleteRequest{Key: "key1"}, - } - - m.db.ExpectBegin() - m.db.ExpectExec("INSERT INTO"). - WithArgs("key1", `"value1"`, false). - WillReturnResult(pgxmock.NewResult("INSERT", 1)) - m.db.ExpectExec("DELETE FROM"). - WithArgs("key1"). 
- WillReturnResult(pgxmock.NewResult("DELETE", 1)) - m.db.ExpectCommit() - // There's also a rollback called after a commit, which is expected and will not have effect - m.db.ExpectRollback() - - // Act - err := m.pgDba.ExecuteMulti(context.Background(), &state.TransactionalStateRequest{ - Operations: operations, - }) - - // Assert - assert.NoError(t, err) -} - -func createSetRequest() state.SetRequest { - return state.SetRequest{ - Key: randomKey(), - Value: randomJSON(), - } -} - -func createDeleteRequest() state.DeleteRequest { - return state.DeleteRequest{ - Key: randomKey(), - } -} - -func mockDatabase(t *testing.T) (*mocks, error) { - logger := logger.NewLogger("test") - - db, err := pgxmock.NewPool() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - dba := &PostgresDBAccess{ - metadata: postgresMetadataStruct{ - TableName: "state", - Timeout: 30 * time.Second, - }, - logger: logger, - db: db, - migrateFn: func(context.Context, PGXPoolConn, MigrateOptions) error { - return nil - }, - setQueryFn: func(*state.SetRequest, SetQueryOptions) string { - return `INSERT INTO state - (key, value, isbinary, expiredate) - VALUES - ($1, $2, $3, NULL)` - }, - } - - return &mocks{ - db: db, - pgDba: dba, - }, err -} - -func randomKey() string { - return uuid.New().String() -} - -func randomJSON() *fakeItem { - return &fakeItem{Color: randomKey()} -} diff --git a/internal/component/postgresql/postgresql.go b/internal/component/postgresql/postgresql.go index d7d0c3375f..fcebaee044 100644 --- a/internal/component/postgresql/postgresql.go +++ b/internal/component/postgresql/postgresql.go @@ -15,11 +15,27 @@ package postgresql import ( "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" "reflect" + "strconv" + "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" + internalsql "github.com/dapr/components-contrib/internal/component/sql" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" + "github.com/dapr/components-contrib/state/query" + stateutils "github.com/dapr/components-contrib/state/utils" "github.com/dapr/kit/logger" + "github.com/dapr/kit/ptr" ) // PostgreSQL state store. @@ -27,11 +43,19 @@ type PostgreSQL struct { state.BulkStore logger logger.Logger - dbaccess dbAccess + metadata pgMetadata + db pginterfaces.PGXPoolConn + + gc internalsql.GarbageCollector + + migrateFn func(context.Context, pginterfaces.PGXPoolConn, MigrateOptions) error + setQueryFn func(*state.SetRequest, SetQueryOptions) string + etagColumn string + enableAzureAD bool } type Options struct { - MigrateFn func(context.Context, PGXPoolConn, MigrateOptions) error + MigrateFn func(context.Context, pginterfaces.PGXPoolConn, MigrateOptions) error SetQueryFn func(*state.SetRequest, SetQueryOptions) string ETagColumn string EnableAzureAD bool @@ -50,24 +74,90 @@ type SetQueryOptions struct { // NewPostgreSQLStateStore creates a new instance of PostgreSQL state store. func NewPostgreSQLStateStore(logger logger.Logger, opts Options) state.Store { - dba := newPostgresDBAccess(logger, opts) - s := newPostgreSQLStateStore(logger, dba) + s := newPostgreSQLStateStore(logger, opts) s.BulkStore = state.NewDefaultBulkStore(s) return s } // newPostgreSQLStateStore creates a newPostgreSQLStateStore instance of a PostgreSQL state store. 
-// This unexported constructor allows injecting a dbAccess instance for unit testing. -func newPostgreSQLStateStore(logger logger.Logger, dba dbAccess) *PostgreSQL { +func newPostgreSQLStateStore(logger logger.Logger, opts Options) *PostgreSQL { return &PostgreSQL{ - logger: logger, - dbaccess: dba, + logger: logger, + migrateFn: opts.MigrateFn, + setQueryFn: opts.SetQueryFn, + etagColumn: opts.ETagColumn, + enableAzureAD: opts.EnableAzureAD, } } -// Init initializes the SQL server state store. -func (p *PostgreSQL) Init(ctx context.Context, metadata state.Metadata) error { - return p.dbaccess.Init(ctx, metadata) +// Init sets up Postgres connection and performs migrations. +func (p *PostgreSQL) Init(ctx context.Context, meta state.Metadata) error { + err := p.metadata.InitWithMetadata(meta, p.enableAzureAD) + if err != nil { + p.logger.Errorf("Failed to parse metadata: %v", err) + return err + } + + config, err := p.metadata.GetPgxPoolConfig() + if err != nil { + p.logger.Error(err) + return err + } + + connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) + p.db, err = pgxpool.NewWithConfig(connCtx, config) + connCancel() + if err != nil { + err = fmt.Errorf("failed to connect to the database: %w", err) + p.logger.Error(err) + return err + } + + pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.Timeout) + err = p.db.Ping(pingCtx) + pingCancel() + if err != nil { + err = fmt.Errorf("failed to ping the database: %w", err) + p.logger.Error(err) + return err + } + + err = p.migrateFn(ctx, p.db, MigrateOptions{ + Logger: p.logger, + StateTableName: p.metadata.TableName, + MetadataTableName: p.metadata.MetadataTableName, + }) + if err != nil { + return err + } + + if p.metadata.CleanupInterval != nil { + gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{ + Logger: p.logger, + UpdateLastCleanupQuery: func(arg any) (string, any) { + return fmt.Sprintf( + `INSERT INTO %[1]s (key, value) + VALUES ('last-cleanup', CURRENT_TIMESTAMP::text) + ON CONFLICT (key) + DO UPDATE SET value = CURRENT_TIMESTAMP::text + WHERE (EXTRACT('epoch' FROM CURRENT_TIMESTAMP - %[1]s.value::timestamp with time zone) * 1000)::bigint > $1`, + p.metadata.MetadataTableName, + ), arg + }, + DeleteExpiredValuesQuery: fmt.Sprintf( + `DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, + p.metadata.TableName, + ), + CleanupInterval: *p.metadata.CleanupInterval, + DB: internalsql.AdaptPgxConn(p.db), + }) + if err != nil { + return err + } + p.gc = gc + } + + return nil } // Features returns the features available in this state store. @@ -80,53 +170,390 @@ func (p *PostgreSQL) Features() []state.Feature { } } -// Delete removes an entity from the store. -func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) error { - return p.dbaccess.Delete(ctx, req) +func (p *PostgreSQL) GetDB() *pgxpool.Pool { + // We can safely cast to *pgxpool.Pool because this method is never used in unit tests where we mock the DB + return p.db.(*pgxpool.Pool) } -// Get returns an entity from store. -func (p *PostgreSQL) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { - return p.dbaccess.Get(ctx, req) +// Set makes an insert or update to the database. +func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error { + return p.doSet(ctx, p.db, req) } -// BulkGet performs a bulks get operations. -// Options are ignored because this component requests all values in a single query. 
-func (p *PostgreSQL) BulkGet(ctx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { - return p.dbaccess.BulkGet(ctx, req) +func (p *PostgreSQL) doSet(parentCtx context.Context, db pginterfaces.DBQuerier, req *state.SetRequest) error { + err := state.CheckRequestOptions(req.Options) + if err != nil { + return err + } + + if req.Key == "" { + return errors.New("missing key in set operation") + } + + v := req.Value + byteArray, isBinary := req.Value.([]uint8) + if isBinary { + v = base64.StdEncoding.EncodeToString(byteArray) + } + + // Convert to json string + bt, _ := stateutils.Marshal(v, json.Marshal) + value := string(bt) + + // TTL + var ttlSeconds int + ttl, ttlerr := stateutils.ParseTTL(req.Metadata) + if ttlerr != nil { + return fmt.Errorf("error parsing TTL: %w", ttlerr) + } + if ttl != nil { + ttlSeconds = *ttl + } + + var ( + queryExpiredate string + params []any + ) + + if !req.HasETag() { + params = []any{req.Key, value, isBinary} + } else { + var etag64 uint64 + etag64, err = strconv.ParseUint(*req.ETag, 10, 32) + if err != nil { + return state.NewETagError(state.ETagInvalid, err) + } + params = []any{req.Key, value, isBinary, uint32(etag64)} + } + + if ttlSeconds > 0 { + queryExpiredate = "CURRENT_TIMESTAMP + interval '" + strconv.Itoa(ttlSeconds) + " seconds'" + } else { + queryExpiredate = "NULL" + } + + query := p.setQueryFn(req, SetQueryOptions{ + TableName: p.metadata.TableName, + ExpireDateValue: queryExpiredate, + }) + + result, err := db.Exec(parentCtx, query, params...) + if err != nil { + return err + } + if result.RowsAffected() != 1 { + if req.HasETag() { + return state.NewETagError(state.ETagMismatch, nil) + } + return errors.New("no item was updated") + } + + return nil } -// Set adds/updates an entity on store. -func (p *PostgreSQL) Set(ctx context.Context, req *state.SetRequest) error { - return p.dbaccess.Set(ctx, req) +// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned. +func (p *PostgreSQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) { + if req.Key == "" { + return nil, errors.New("missing key in get operation") + } + + query := `SELECT + key, value, isbinary, ` + p.etagColumn + ` AS etag, expiredate + FROM ` + p.metadata.TableName + ` + WHERE + key = $1 + AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` + ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) + defer cancel() + row := p.db.QueryRow(ctx, query, req.Key) + _, value, etag, expireTime, err := readRow(row) + if err != nil { + // If no rows exist, return an empty response, otherwise return the error. + if errors.Is(err, pgx.ErrNoRows) { + return &state.GetResponse{}, nil + } + return nil, err + } + + resp := &state.GetResponse{ + Data: value, + ETag: etag, + } + + if expireTime != nil { + resp.Metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } + + return resp, nil } -// Multi handles multiple transactions. Implements TransactionalStore. 
-func (p *PostgreSQL) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { - return p.dbaccess.ExecuteMulti(ctx, request) +func (p *PostgreSQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ state.BulkGetOpts) ([]state.BulkGetResponse, error) { + if len(req) == 0 { + return []state.BulkGetResponse{}, nil + } + + // Get all keys + keys := make([]string, len(req)) + for i, r := range req { + keys[i] = r.Key + } + + // Execute the query + query := `SELECT + key, value, isbinary, ` + p.etagColumn + ` AS etag, expiredate + FROM ` + p.metadata.TableName + ` + WHERE + key = ANY($1) + AND (expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP)` + ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) + defer cancel() + rows, err := p.db.Query(ctx, query, keys) + if err != nil { + // If no rows exist, return an empty response, otherwise return the error. + if errors.Is(err, pgx.ErrNoRows) { + return []state.BulkGetResponse{}, nil + } + return nil, err + } + + // Scan all rows + var n int + res := make([]state.BulkGetResponse, len(req)) + foundKeys := make(map[string]struct{}, len(req)) + for ; rows.Next(); n++ { + if n >= len(req) { + // Sanity check to prevent panics, which should never happen + return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) + } + + r := state.BulkGetResponse{} + var expireTime *time.Time + r.Key, r.Data, r.ETag, expireTime, err = readRow(rows) + if err != nil { + r.Error = err.Error() + } + if expireTime != nil { + r.Metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } + res[n] = r + foundKeys[r.Key] = struct{}{} + } + + // Populate missing keys with empty values + // This is to ensure consistency with the other state stores that implement BulkGet as a loop over Get, and with the Get method + if len(foundKeys) < len(req) { + var ok bool + for _, r := range req { + _, ok = foundKeys[r.Key] + if !ok { + if n >= len(req) { + // Sanity check to prevent panics, which should never happen + return nil, fmt.Errorf("query returned more records than expected (expected %d)", len(req)) + } + res[n] = state.BulkGetResponse{ + Key: r.Key, + } + n++ + } + } + } + + return res[:n], nil +} + +func readRow(row pgx.Row) (key string, value []byte, etagS *string, expireTime *time.Time, err error) { + var ( + isBinary bool + etag pgtype.Int8 + expT pgtype.Timestamp + ) + err = row.Scan(&key, &value, &isBinary, &etag, &expT) + if err != nil { + return key, nil, nil, nil, err + } + + if etag.Valid { + etagS = ptr.Of(strconv.FormatInt(etag.Int64, 10)) + } + + if expT.Valid { + expireTime = &expT.Time + } + + if isBinary { + var ( + s string + data []byte + ) + + err = json.Unmarshal(value, &s) + if err != nil { + return key, nil, nil, nil, fmt.Errorf("failed to unmarshal JSON data: %w", err) + } + + data, err = base64.StdEncoding.DecodeString(s) + if err != nil { + return key, nil, nil, nil, fmt.Errorf("failed to decode base64 data: %w", err) + } + + return key, data, etagS, expireTime, nil + } + + return key, value, etagS, expireTime, nil +} + +// Delete removes an item from the state store. 
+func (p *PostgreSQL) Delete(ctx context.Context, req *state.DeleteRequest) (err error) {
+	return p.doDelete(ctx, p.db, req)
+}
+
+func (p *PostgreSQL) doDelete(parentCtx context.Context, db pginterfaces.DBQuerier, req *state.DeleteRequest) (err error) {
+	if req.Key == "" {
+		return errors.New("missing key in delete operation")
+	}
+
+	ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout)
+	defer cancel()
+	var result pgconn.CommandTag
+	if !req.HasETag() {
+		result, err = db.Exec(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1", req.Key)
+	} else {
+		// Convert req.ETag to uint32 for postgres XID compatibility
+		var etag64 uint64
+		etag64, err = strconv.ParseUint(*req.ETag, 10, 32)
+		if err != nil {
+			return state.NewETagError(state.ETagInvalid, err)
+		}
+
+		result, err = db.Exec(ctx, "DELETE FROM "+p.metadata.TableName+" WHERE key = $1 AND $2 = "+p.etagColumn, req.Key, uint32(etag64))
+	}
+	if err != nil {
+		return err
+	}
+
+	rows := result.RowsAffected()
+	if rows != 1 && req.ETag != nil && *req.ETag != "" {
+		return state.NewETagError(state.ETagMismatch, nil)
+	}
+
+	return nil
+}
+
+func (p *PostgreSQL) Multi(parentCtx context.Context, request *state.TransactionalStateRequest) error {
+	tx, err := p.beginTx(parentCtx)
+	if err != nil {
+		return err
+	}
+	defer p.rollbackTx(parentCtx, tx, "ExecMulti")
+
+	for _, o := range request.Operations {
+		switch x := o.(type) {
+		case state.SetRequest:
+			err = p.doSet(parentCtx, tx, &x)
+			if err != nil {
+				return err
+			}
+
+		case state.DeleteRequest:
+			err = p.doDelete(parentCtx, tx, &x)
+			if err != nil {
+				return err
+			}
+
+		default:
+			return fmt.Errorf("unsupported operation: %s", o.Operation())
+		}
+	}
+
+	ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout)
+	err = tx.Commit(ctx)
+	cancel()
+	if err != nil {
+		return fmt.Errorf("failed to commit transaction: %w", err)
+	}
+
+	return nil
 }
 
 // Query executes a query against store.
-func (p *PostgreSQL) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
-	return p.dbaccess.Query(ctx, req)
+func (p *PostgreSQL) Query(parentCtx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) {
+	q := &Query{
+		query:      "",
+		params:     []any{},
+		tableName:  p.metadata.TableName,
+		etagColumn: p.etagColumn,
+	}
+	qbuilder := query.NewQueryBuilder(q)
+	if err := qbuilder.BuildQuery(&req.Query); err != nil {
+		return &state.QueryResponse{}, err
+	}
+	data, token, err := q.execute(parentCtx, p.logger, p.db)
+	if err != nil {
+		return &state.QueryResponse{}, err
+	}
+
+	return &state.QueryResponse{
+		Results: data,
+		Token:   token,
+	}, nil
 }
 
-// Close implements io.Closer.
+func (p *PostgreSQL) CleanupExpired() error {
+	if p.gc != nil {
+		return p.gc.CleanupExpired()
+	}
+	return nil
+}
+
+// Close implements io.Closer.
 func (p *PostgreSQL) Close() error {
-	if p.dbaccess != nil {
-		return p.dbaccess.Close()
+	if p.db != nil {
+		p.db.Close()
+		p.db = nil
 	}
+
+	if p.gc != nil {
+		return p.gc.Close()
+	}
+
 	return nil
 }
 
-// Returns the dbaccess property.
-// This method is used in tests.
-func (p *PostgreSQL) GetDBAccess() dbAccess {
-	return p.dbaccess
+// GetCleanupInterval returns the cleanupInterval property.
+// This is primarily used for tests.
+func (p *PostgreSQL) GetCleanupInterval() *time.Duration {
+	return p.metadata.CleanupInterval
+}
+
+// Internal function that begins a transaction.
+func (p *PostgreSQL) beginTx(parentCtx context.Context) (pgx.Tx, error) { + ctx, cancel := context.WithTimeout(parentCtx, p.metadata.Timeout) + tx, err := p.db.Begin(ctx) + cancel() + if err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", err) + } + return tx, nil +} + +// Internal function that rolls back a transaction. +// Normally called as a deferred function in methods that use transactions. +// In case of errors, they are logged but not actioned upon. +func (p *PostgreSQL) rollbackTx(parentCtx context.Context, tx pgx.Tx, methodName string) { + rollbackCtx, rollbackCancel := context.WithTimeout(parentCtx, p.metadata.Timeout) + rollbackErr := tx.Rollback(rollbackCtx) + rollbackCancel() + if rollbackErr != nil && !errors.Is(rollbackErr, pgx.ErrTxClosed) { + p.logger.Errorf("Failed to rollback transaction in %s: %v", methodName, rollbackErr) + } } func (p *PostgreSQL) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { - metadataStruct := postgresMetadataStruct{} + metadataStruct := pgMetadata{} metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.StateStoreType) return } diff --git a/internal/component/postgresql/postgresql_query.go b/internal/component/postgresql/postgresql_query.go index 0308b26fd3..abfb165ab6 100644 --- a/internal/component/postgresql/postgresql_query.go +++ b/internal/component/postgresql/postgresql_query.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state/query" "github.com/dapr/kit/logger" @@ -140,7 +141,7 @@ func (q *Query) Finalize(filters string, qq *query.Query) error { return nil } -func (q *Query) execute(ctx context.Context, logger logger.Logger, db dbquerier) ([]state.QueryItem, string, error) { +func (q *Query) execute(ctx context.Context, logger logger.Logger, db pginterfaces.DBQuerier) ([]state.QueryItem, string, error) { rows, err := db.Query(ctx, q.query, q.params...) if err != nil { return nil, "", err diff --git a/internal/component/postgresql/postgresql_test.go b/internal/component/postgresql/postgresql_test.go index 971b933349..f2170b4203 100644 --- a/internal/component/postgresql/postgresql_test.go +++ b/internal/component/postgresql/postgresql_test.go @@ -12,104 +12,224 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package postgresql import ( "context" + "encoding/json" "testing" + "time" + "github.com/google/uuid" + pgxmock "github.com/pashagolub/pgxmock/v2" "github.com/stretchr/testify/assert" - "github.com/dapr/components-contrib/metadata" + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) -const ( - fakeConnectionString = "not a real connection" -) - -// Fake implementation of interface postgressql.dbaccess. 
-type fakeDBaccess struct { - logger logger.Logger - initExecuted bool - setExecuted bool - getExecuted bool - deleteExecuted bool +type mocks struct { + db pgxmock.PgxPoolIface + pg *PostgreSQL } -func (m *fakeDBaccess) Init(ctx context.Context, metadata state.Metadata) error { - m.initExecuted = true - - return nil +type fakeItem struct { + Color string } -func (m *fakeDBaccess) Set(ctx context.Context, req *state.SetRequest) error { - m.setExecuted = true +func TestMultiWithNoRequests(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() - return nil -} + m.db.ExpectBegin() + m.db.ExpectCommit() + // There's also a rollback called after a commit, which is expected and will not have effect + m.db.ExpectRollback() + + var operations []state.TransactionalStateOperation -func (m *fakeDBaccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { - m.getExecuted = true + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) - return nil, nil + // Assert + assert.NoError(t, err) } -func (m *fakeDBaccess) BulkGet(ctx context.Context, req []state.GetRequest) ([]state.BulkGetResponse, error) { - return nil, nil +func TestValidSetRequest(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() + + setReq := createSetRequest() + operations := []state.TransactionalStateOperation{setReq} + val, _ := json.Marshal(setReq.Value) + + m.db.ExpectBegin() + m.db.ExpectExec("INSERT INTO"). + WithArgs(setReq.Key, string(val), false). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + m.db.ExpectCommit() + // There's also a rollback called after a commit, which is expected and will not have effect + m.db.ExpectRollback() + + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + + // Assert + assert.NoError(t, err) } -func (m *fakeDBaccess) Delete(ctx context.Context, req *state.DeleteRequest) error { - m.deleteExecuted = true +func TestInvalidMultiSetRequestNoKey(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() - return nil -} + m.db.ExpectBegin() + m.db.ExpectRollback() -func (m *fakeDBaccess) ExecuteMulti(ctx context.Context, req *state.TransactionalStateRequest) error { - return nil + operations := []state.TransactionalStateOperation{ + state.SetRequest{Value: "value1"}, // Set request without key is not valid for Upsert operation + } + + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + + // Assert + assert.Error(t, err) } -func (m *fakeDBaccess) Query(ctx context.Context, req *state.QueryRequest) (*state.QueryResponse, error) { - return nil, nil +func TestValidMultiDeleteRequest(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() + + deleteReq := createDeleteRequest() + operations := []state.TransactionalStateOperation{deleteReq} + + m.db.ExpectBegin() + m.db.ExpectExec("DELETE FROM"). + WithArgs(deleteReq.Key). 
+ WillReturnResult(pgxmock.NewResult("DELETE", 1)) + m.db.ExpectCommit() + // There's also a rollback called after a commit, which is expected and will not have effect + m.db.ExpectRollback() + + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + + // Assert + assert.NoError(t, err) } -func (m *fakeDBaccess) Close() error { - return nil +func TestInvalidMultiDeleteRequestNoKey(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() + + m.db.ExpectBegin() + m.db.ExpectRollback() + + operations := []state.TransactionalStateOperation{state.DeleteRequest{}} // Delete request without key is not valid for Delete operation + + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + + // Assert + assert.Error(t, err) } -// Proves that the Init method runs the init method. -func TestInitRunsDBAccessInit(t *testing.T) { - t.Parallel() - _, fake := createPostgreSQLWithFake(t) - assert.True(t, fake.initExecuted) +func TestMultiOperationOrder(t *testing.T) { + // Arrange + m, _ := mockDatabase(t) + defer m.db.Close() + + operations := []state.TransactionalStateOperation{ + state.SetRequest{Key: "key1", Value: "value1"}, + state.DeleteRequest{Key: "key1"}, + } + + m.db.ExpectBegin() + m.db.ExpectExec("INSERT INTO"). + WithArgs("key1", `"value1"`, false). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + m.db.ExpectExec("DELETE FROM"). + WithArgs("key1"). + WillReturnResult(pgxmock.NewResult("DELETE", 1)) + m.db.ExpectCommit() + // There's also a rollback called after a commit, which is expected and will not have effect + m.db.ExpectRollback() + + // Act + err := m.pg.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + + // Assert + assert.NoError(t, err) } -func createPostgreSQLWithFake(t *testing.T) (*PostgreSQL, *fakeDBaccess) { - pgs := createPostgreSQL(t) - fake := pgs.dbaccess.(*fakeDBaccess) +func createSetRequest() state.SetRequest { + return state.SetRequest{ + Key: randomKey(), + Value: randomJSON(), + } +} - return pgs, fake +func createDeleteRequest() state.DeleteRequest { + return state.DeleteRequest{ + Key: randomKey(), + } } -func createPostgreSQL(t *testing.T) *PostgreSQL { +func mockDatabase(t *testing.T) (*mocks, error) { logger := logger.NewLogger("test") - dba := &fakeDBaccess{ - logger: logger, + db, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } - pgs := newPostgreSQLStateStore(logger, dba) - assert.NotNil(t, pgs) - - metadata := &state.Metadata{ - Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}}, + dba := &PostgreSQL{ + metadata: pgMetadata{ + TableName: "state", + Timeout: 30 * time.Second, + }, + logger: logger, + db: db, + migrateFn: func(context.Context, pginterfaces.PGXPoolConn, MigrateOptions) error { + return nil + }, + setQueryFn: func(*state.SetRequest, SetQueryOptions) string { + return `INSERT INTO state + (key, value, isbinary, expiredate) + VALUES + ($1, $2, $3, NULL)` + }, } - err := pgs.Init(context.Background(), *metadata) + return &mocks{ + db: db, + pg: dba, + }, err +} - assert.Nil(t, err) - assert.NotNil(t, pgs.dbaccess) +func randomKey() string { + return uuid.New().String() +} - return pgs +func randomJSON() *fakeItem { + return &fakeItem{Color: randomKey()} } diff --git a/internal/component/sql/adapter.go b/internal/component/sql/adapter.go 
new file mode 100644 index 0000000000..7bdc000b5f --- /dev/null +++ b/internal/component/sql/adapter.go @@ -0,0 +1,181 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sql + +import ( + "context" + "database/sql" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// AdaptDatabaseSQLConn returns a databaseConn based on a database/sql connection. +// +// Note: when using transactions with database/sql, the context bassed to Begin impacts the entire transaction. +// Canceling the context automatically rolls back the transaction. +func AdaptDatabaseSQLConn(db DatabaseSQLConn) DatabaseConn { + return &DatabaseSQLAdapter{db} +} + +// AdaptPgxConn returns a databaseConn based on a pgx connection. +// +// Note: when using transactions with pgx, the context bassed to Begin impacts the creation of the transaction only. +func AdaptPgxConn(db PgxConn) DatabaseConn { + return &PgxAdapter{db} +} + +// DatabaseSQLConn is the interface for connections that use database/sql. +type DatabaseSQLConn interface { + BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) + QueryContext(context.Context, string, ...any) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...any) *sql.Row + ExecContext(context.Context, string, ...any) (sql.Result, error) +} + +// PgxConn is the interface for connections that use pgx. +type PgxConn interface { + Begin(context.Context) (pgx.Tx, error) + Query(context.Context, string, ...any) (pgx.Rows, error) + QueryRow(context.Context, string, ...any) pgx.Row + Exec(context.Context, string, ...any) (pgconn.CommandTag, error) +} + +// DatabaseConn is the interface matched by all adapters. +type DatabaseConn interface { + Begin(context.Context) (databaseConnTx, error) + QueryRow(context.Context, string, ...any) databaseConnRow + Exec(context.Context, string, ...any) (int64, error) + IsNoRowsError(err error) bool +} + +type databaseConnRow interface { + Scan(...any) error +} + +type databaseConnTx interface { + Commit(context.Context) error + Rollback(context.Context) error + QueryRow(context.Context, string, ...any) databaseConnRow + Exec(context.Context, string, ...any) (int64, error) +} + +// DatabaseSQLAdapter is an adapter for database/sql connections. +type DatabaseSQLAdapter struct { + conn DatabaseSQLConn +} + +func (sqla *DatabaseSQLAdapter) Begin(ctx context.Context) (databaseConnTx, error) { + tx, err := sqla.conn.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + return &databaseSQLTxAdapter{tx}, nil +} + +func (sqla *DatabaseSQLAdapter) Exec(ctx context.Context, query string, args ...any) (int64, error) { + res, err := sqla.conn.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (sqla *DatabaseSQLAdapter) QueryRow(ctx context.Context, query string, args ...any) databaseConnRow { + return sqla.conn.QueryRowContext(ctx, query, args...) 
+} + +func (sqla *DatabaseSQLAdapter) IsNoRowsError(err error) bool { + return errors.Is(err, sql.ErrNoRows) +} + +type databaseSQLTxAdapter struct { + tx *sql.Tx +} + +func (sqltx *databaseSQLTxAdapter) Rollback(_ context.Context) error { + return sqltx.tx.Rollback() +} + +func (sqltx *databaseSQLTxAdapter) Commit(_ context.Context) error { + return sqltx.tx.Commit() +} + +func (sqltx *databaseSQLTxAdapter) Exec(ctx context.Context, query string, args ...any) (int64, error) { + res, err := sqltx.tx.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (sqltx *databaseSQLTxAdapter) QueryRow(ctx context.Context, query string, args ...any) databaseConnRow { + return sqltx.tx.QueryRowContext(ctx, query, args...) +} + +// PgxAdapter is an adapter for pgx connections. +type PgxAdapter struct { + conn PgxConn +} + +func (pga *PgxAdapter) Begin(ctx context.Context) (databaseConnTx, error) { + tx, err := pga.conn.Begin(ctx) + if err != nil { + return nil, err + } + + return &pgxTxAdapter{tx}, nil +} + +func (pga *PgxAdapter) Exec(ctx context.Context, query string, args ...any) (int64, error) { + res, err := pga.conn.Exec(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected(), nil +} + +func (pga *PgxAdapter) QueryRow(ctx context.Context, query string, args ...any) databaseConnRow { + return pga.conn.QueryRow(ctx, query, args...) +} + +func (pga *PgxAdapter) IsNoRowsError(err error) bool { + return errors.Is(err, pgx.ErrNoRows) +} + +type pgxTxAdapter struct { + tx pgx.Tx +} + +func (pgtx *pgxTxAdapter) Rollback(ctx context.Context) error { + return pgtx.tx.Rollback(ctx) +} + +func (pgtx *pgxTxAdapter) Commit(ctx context.Context) error { + return pgtx.tx.Commit(ctx) +} + +func (pgtx *pgxTxAdapter) Exec(ctx context.Context, query string, args ...any) (int64, error) { + res, err := pgtx.tx.Exec(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected(), nil +} + +func (pgtx *pgxTxAdapter) QueryRow(ctx context.Context, query string, args ...any) databaseConnRow { + return pgtx.tx.QueryRow(ctx, query, args...) +} diff --git a/internal/component/sql/cleanup.go b/internal/component/sql/cleanup.go index ab57f4321f..8f66d2a4da 100644 --- a/internal/component/sql/cleanup.go +++ b/internal/component/sql/cleanup.go @@ -15,7 +15,6 @@ package sql import ( "context" - "database/sql" "errors" "fmt" "io" @@ -23,9 +22,6 @@ import ( "sync/atomic" "time" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/dapr/kit/logger" ) @@ -40,12 +36,8 @@ type GCOptions struct { // Query that must atomically update the "last cleanup time" in the metadata table, but only if the garbage collector hasn't run already. // The caller will check the nuber of affected rows. If zero, it assumes that the GC has ran too recently, and will not proceed to delete expired records. // The query receives one parameter that is the last cleanup interval, in milliseconds. - UpdateLastCleanupQuery string - - // Name of the parameter passed to the UpdateLeastCleanupQuery query, to use named parameters (via `sql.Named`). The parameter is the time interval. - // If empty, assumes the parameter is positional and not named. - // This is ignored when using the pgx querier. - UpdateLastCleanupQueryParameterName string + // The function must return both the query and the argument. + UpdateLastCleanupQuery func(arg any) (string, any) // Query that performs the cleanup of all expired rows. 
DeleteExpiredValuesQuery string @@ -53,32 +45,17 @@ type GCOptions struct { // Interval to perform the cleanup. CleanupInterval time.Duration - // Database connection when using pgx. - DBPgx PgxConn - // Database connection when using database/sql. - DBSql DatabaseSQLConn -} - -// Interface for connections that use pgx. -type PgxConn interface { - Begin(context.Context) (pgx.Tx, error) - Exec(context.Context, string, ...any) (pgconn.CommandTag, error) -} - -// Interface for connections that use database/sql. -type DatabaseSQLConn interface { - BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + // Database connection. + // Must be adapted using AdaptDatabaseSQLConn or AdaptPgxConn. + DB DatabaseConn } type gc struct { log logger.Logger - updateLastCleanupQuery string - ulcqParamName string + updateLastCleanupQuery func(arg any) (string, any) deleteExpiredValuesQuery string cleanupInterval time.Duration - dbPgx PgxConn - dbSQL DatabaseSQLConn + db DatabaseConn closed atomic.Bool closedCh chan struct{} @@ -90,21 +67,16 @@ func ScheduleGarbageCollector(opts GCOptions) (GarbageCollector, error) { return new(gcNoOp), nil } - if opts.DBPgx == nil && opts.DBSql == nil { - return nil, errors.New("either DBPgx or DBSql must be provided") - } - if opts.DBPgx != nil && opts.DBSql != nil { - return nil, errors.New("only one of DBPgx or DBSql must be provided") + if opts.DB == nil { + return nil, errors.New("property DB must be provided") } gc := &gc{ log: opts.Logger, updateLastCleanupQuery: opts.UpdateLastCleanupQuery, - ulcqParamName: opts.UpdateLastCleanupQueryParameterName, deleteExpiredValuesQuery: opts.DeleteExpiredValuesQuery, cleanupInterval: opts.CleanupInterval, - dbPgx: opts.DBPgx, - dbSQL: opts.DBSql, + db: opts.DB, closedCh: make(chan struct{}), } @@ -167,51 +139,19 @@ func (g *gc) CleanupExpired() error { return nil } - var ( - tx pgx.Tx - txwc *sql.Tx - ) - - if g.dbPgx != nil { - tx, err = g.dbPgx.Begin(ctx) - if err != nil { - return fmt.Errorf("failed to start transaction: %w", err) - } - defer tx.Rollback(ctx) - } else { - txwc, err = g.dbSQL.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to start transaction: %w", err) - } - defer txwc.Rollback() + tx, err := g.db.Begin(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) } + defer tx.Rollback(ctx) - var rowsAffected int64 - if tx != nil { - var res pgconn.CommandTag - res, err = tx.Exec(ctx, g.deleteExpiredValuesQuery) - if err != nil { - return fmt.Errorf("failed to execute query: %w", err) - } - rowsAffected = res.RowsAffected() - } else { - var res sql.Result - res, err = txwc.ExecContext(ctx, g.deleteExpiredValuesQuery) - if err != nil { - return fmt.Errorf("failed to execute query: %w", err) - } - rowsAffected, err = res.RowsAffected() - if err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } + rowsAffected, err := tx.Exec(ctx, g.deleteExpiredValuesQuery) + if err != nil { + return fmt.Errorf("failed to execute query: %w", err) } // Commit - if tx != nil { - err = tx.Commit(ctx) - } else { - err = txwc.Commit() - } + err = tx.Commit(ctx) if err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } @@ -223,31 +163,13 @@ func (g *gc) CleanupExpired() error { // updateLastCleanup sets the 'last-cleanup' value, but only if the previous cleanup happened more than cleanupInterval ago. // Returns true if the row was updated, which means that the cleanup can proceed.
func (g *gc) updateLastCleanup(ctx context.Context) (bool, error) { - var n int64 // Query parameter: interval in ms // Subtract 100ms for some buffer - var param any = (g.cleanupInterval.Milliseconds() - 100) + query, param := g.updateLastCleanupQuery(g.cleanupInterval.Milliseconds() - 100) - if g.dbPgx != nil { - res, err := g.dbPgx.Exec(ctx, g.updateLastCleanupQuery, param) - if err != nil { - return false, fmt.Errorf("error updating last cleanup time: %w", err) - } - n = res.RowsAffected() - } else { - // Use named parameters if we need to - if g.ulcqParamName != "" { - param = sql.Named(g.ulcqParamName, param) - } - res, err := g.dbSQL.ExecContext(ctx, g.updateLastCleanupQuery, param) - if err != nil { - return false, fmt.Errorf("error updating last cleanup time: %w", err) - } - - n, err = res.RowsAffected() - if err != nil { - return false, fmt.Errorf("failed to retrieve affected row count: %w", err) - } + n, err := g.db.Exec(ctx, query, param) + if err != nil { + return false, fmt.Errorf("error updating last cleanup time: %w", err) } return n > 0, nil diff --git a/internal/component/sql/migrations.go b/internal/component/sql/migrations.go new file mode 100644 index 0000000000..161b9ae660 --- /dev/null +++ b/internal/component/sql/migrations.go @@ -0,0 +1,100 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sql + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/dapr/kit/logger" +) + +// MigrationOptions contains options for the Migrate function. +type MigrationOptions struct { + // Logger + Logger logger.Logger + + // List of migrations to execute. + // Each item is a function that receives a context and can execute queries; the database connection is captured by the closure. + Migrations []MigrationFn + + // EnsureMetadataTable ensures that the metadata table exists. + EnsureMetadataTable func(ctx context.Context) error + + // GetVersionQuery is the query to execute to load the latest migration version. + GetVersionQuery string + + // UpdateVersionQuery is a function that returns the query to update the migration version, and the arg. + UpdateVersionQuery func(version string) (string, any) +} + +type ( + MigrationFn = func(ctx context.Context) error + MigrationTeardownFn = func() error +) + +// Migrate performs database migrations. +func Migrate(ctx context.Context, db DatabaseConn, opts MigrationOptions) error { + opts.Logger.Debug("Migrate: start") + + // Ensure that the metadata table exists + opts.Logger.Debug("Migrate: ensure metadata table exists") + err := opts.EnsureMetadataTable(ctx) + if err != nil { + return fmt.Errorf("failed to ensure metadata table exists: %w", err) + } + + // Select the migration level + opts.Logger.Debug("Migrate: load current migration level") + var ( + migrationLevelStr string + migrationLevel int + ) + queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + err = db.QueryRow(queryCtx, opts.GetVersionQuery).Scan(&migrationLevelStr) + cancel() + if db.IsNoRowsError(err) { + // If there's no row, start from migration level 0
+ migrationLevel = 0 + } else if err != nil { + return fmt.Errorf("failed to read migration level: %w", err) + } else { + migrationLevel, err = strconv.Atoi(migrationLevelStr) + if err != nil || migrationLevel < 0 { + return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr) + } + } + opts.Logger.Debugf("Migrate: current migration level: %d", migrationLevel) + + // Perform the migrations + for i := migrationLevel; i < len(opts.Migrations); i++ { + opts.Logger.Infof("Performing migration %d", i+1) + err = opts.Migrations[i](ctx) + if err != nil { + return fmt.Errorf("failed to perform migration %d: %w", i+1, err) + } + + query, arg := opts.UpdateVersionQuery(strconv.Itoa(i + 1)) + queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) + _, err = db.Exec(queryCtx, query, arg) + cancel() + if err != nil { + return fmt.Errorf("failed to update migration level in metadata table: %w", err) + } + } + + return nil +} diff --git a/internal/component/sql/migrations/postgres/postgres_migrations.go b/internal/component/sql/migrations/postgres/postgres_migrations.go new file mode 100644 index 0000000000..3eec1428ab --- /dev/null +++ b/internal/component/sql/migrations/postgres/postgres_migrations.go @@ -0,0 +1,165 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package pgmigrations + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" + sqlinternal "github.com/dapr/components-contrib/internal/component/sql" + "github.com/dapr/kit/logger" +) + +// Migrations performs migrations for the database schema +type Migrations struct { + DB pginterfaces.PGXPoolConn + Logger logger.Logger + MetadataTableName string + MetadataKey string +} + +// Perform the required migrations +func (m Migrations) Perform(ctx context.Context, migrationFns []sqlinternal.MigrationFn) error { + // Use an advisory lock (with an arbitrary number) to ensure that no one else is performing migrations at the same time + // This is the only way to also ensure we are not running multiple "CREATE TABLE IF NOT EXISTS" at the exact same time + // See: https://www.postgresql.org/message-id/CA+TgmoZAdYVtwBfp1FL2sMZbiHCWT4UPrzRLNnX1Nb30Ku3-gg@mail.gmail.com + const lockID = 42 + + // Long timeout here as this query may block + m.Logger.Debug("Acquiring advisory lock pre-migration") + queryCtx, cancel := context.WithTimeout(ctx, time.Minute) + _, err := m.DB.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID) + cancel() + if err != nil { + return fmt.Errorf("failed to acquire advisory lock: %w", err) + } + m.Logger.Debug("Successfully acquired advisory lock") + + // Release the lock + defer func() { + m.Logger.Debug("Releasing advisory lock") + queryCtx, cancel = context.WithTimeout(ctx, time.Minute) + _, err = m.DB.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID) + cancel() + if err != nil { + // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around + m.Logger.Fatalf("Failed to release advisory lock: %v", err) + } + }() + + return sqlinternal.Migrate(ctx, sqlinternal.AdaptPgxConn(m.DB), sqlinternal.MigrationOptions{ + Logger: m.Logger, + // Yes, we are using fmt.Sprintf for adding a value in a query. + // This comes from a constant hardcoded at development-time, and cannot be influenced by users. So, no risk of SQL injections here.
+ GetVersionQuery: fmt.Sprintf(`SELECT value FROM %s WHERE key = '%s'`, m.MetadataTableName, m.MetadataKey), + UpdateVersionQuery: func(version string) (string, any) { + return fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('%s', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.MetadataTableName, m.MetadataKey), + version + }, + EnsureMetadataTable: func(ctx context.Context) error { + // Check if the metadata table exists, which we also use to store the migration level + queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + exists, _, _, err := m.TableExists(queryCtx, m.MetadataTableName) + cancel() + if err != nil { + return err + } + + // If the table doesn't exist, create it + if !exists { + queryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + err = m.CreateMetadataTable(queryCtx) + cancel() + if err != nil { + return err + } + } + + return nil + }, + Migrations: migrationFns, + }) +} + +func (m Migrations) CreateMetadataTable(ctx context.Context) error { + m.Logger.Infof("Creating metadata table '%s'", m.MetadataTableName) + // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time + // We already hold the advisory lock at this point, so there won't be issues with concurrency + _, err := m.DB.Exec(ctx, fmt.Sprintf( + `CREATE TABLE IF NOT EXISTS %s ( + key text NOT NULL PRIMARY KEY, + value text NOT NULL + )`, + m.MetadataTableName, + )) + if err != nil { + return fmt.Errorf("failed to create metadata table: %w", err) + } + return nil +} + +// TableExists checks if the table exists; if so, it returns true along with the actual table and schema names. +func (m Migrations) TableExists(ctx context.Context, tableName string) (exists bool, schema string, table string, err error) { + table, schema, err = m.TableSchemaName(tableName) + if err != nil { + return false, "", "", err + } + + if schema == "" { + err = m.DB.QueryRow( + ctx, + `SELECT table_name, table_schema + FROM information_schema.tables + WHERE table_name = $1`, + table, + ). + Scan(&table, &schema) + } else { + err = m.DB.QueryRow( + ctx, + `SELECT table_name, table_schema + FROM information_schema.tables + WHERE table_schema = $1 AND table_name = $2`, + schema, table, + ). + Scan(&table, &schema) + } + + if err != nil && errors.Is(err, pgx.ErrNoRows) { + return false, "", "", nil + } else if err != nil { + return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err) + } + return true, schema, table, nil +} + +// TableSchemaName parses the table name. +// If the table name includes a schema (e.g. `schema.table`), returns the two parts separately. +func (m Migrations) TableSchemaName(tableName string) (table string, schema string, err error) { + parts := strings.Split(tableName, ".") + switch len(parts) { + case 1: + return parts[0], "", nil + case 2: + return parts[1], parts[0], nil + default: + return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'") + } +} diff --git a/internal/component/sql/migrations/sqlite/sqlite_migrations.go b/internal/component/sql/migrations/sqlite/sqlite_migrations.go new file mode 100644 index 0000000000..f98c87b260 --- /dev/null +++ b/internal/component/sql/migrations/sqlite/sqlite_migrations.go @@ -0,0 +1,156 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlitemigrations + +import ( + "context" + "database/sql" + "fmt" + "time" + + sqlinternal "github.com/dapr/components-contrib/internal/component/sql" + "github.com/dapr/kit/logger" +) + +// Migrations performs migrations for the database schema +type Migrations struct { + Pool *sql.DB + Logger logger.Logger + MetadataTableName string + MetadataKey string + + conn *sql.Conn +} + +// Perform the required migrations
func (m *Migrations) Perform(ctx context.Context, migrationFns []sqlinternal.MigrationFn) (err error) { + // Get a connection so we can create a transaction + m.conn, err = m.Pool.Conn(ctx) + if err != nil { + return fmt.Errorf("failed to get a connection from the pool: %w", err) + } + defer m.conn.Close() + + // Begin an exclusive transaction + // We can't use Begin because that doesn't allow us to make the transaction EXCLUSIVE + queryCtx, cancel := context.WithTimeout(ctx, time.Minute) + _, err = m.conn.ExecContext(queryCtx, "BEGIN EXCLUSIVE TRANSACTION") + cancel() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + // Rollback the transaction in a deferred statement to catch errors + success := false + defer func() { + if success { + return + } + queryCtx, cancel = context.WithTimeout(ctx, time.Minute) + _, err = m.conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION") + cancel() + if err != nil { + // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around + m.Logger.Fatalf("Failed to rollback transaction: %v", err) + } + }() + + // Perform the migrations + err = sqlinternal.Migrate(ctx, sqlinternal.AdaptDatabaseSQLConn(m.conn), sqlinternal.MigrationOptions{ + Logger: m.Logger, + // Yes, we are using fmt.Sprintf for adding a value in a query. + // This comes from a constant hardcoded at development-time, and cannot be influenced by users. So, no risk of SQL injections here.
+ GetVersionQuery: fmt.Sprintf(`SELECT value FROM %s WHERE key = '%s'`, m.MetadataTableName, m.MetadataKey), + UpdateVersionQuery: func(version string) (string, any) { + return fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('%s', ?)`, m.MetadataTableName, m.MetadataKey), + version + }, + EnsureMetadataTable: func(ctx context.Context) error { + // Check if the metadata table exists, which we also use to store the migration level + queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) + var exists bool + exists, err = m.tableExists(queryCtx, m.conn, m.MetadataTableName) + cancel() + if err != nil { + return fmt.Errorf("failed to check if the metadata table exists: %w", err) + } + + // If the table doesn't exist, create it + if !exists { + queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) + err = m.createMetadataTable(queryCtx, m.conn) + cancel() + if err != nil { + return fmt.Errorf("failed to create metadata table: %w", err) + } + } + + return nil + }, + Migrations: migrationFns, + }) + if err != nil { + return err + } + + // Commit the transaction + queryCtx, cancel = context.WithTimeout(ctx, time.Minute) + _, err = m.conn.ExecContext(queryCtx, "COMMIT TRANSACTION") + cancel() + if err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + // Set success to true so we don't also run a rollback + success = true + + return nil +} + +// GetConn returns the active connection. +func (m *Migrations) GetConn() *sql.Conn { + return m.conn +} + +// Returns true if a table exists +func (m Migrations) tableExists(parentCtx context.Context, db sqlinternal.DatabaseSQLConn, tableName string) (bool, error) { + ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) + defer cancel() + + var exists string + // Returns 1 or 0 as a string if the table exists or not. + const q = `SELECT EXISTS ( + SELECT name FROM sqlite_master WHERE type='table' AND name = ? + ) AS 'exists'` + err := db.QueryRowContext(ctx, q, tableName).
+ Scan(&exists) + return exists == "1", err +} + +func (m Migrations) createMetadataTable(ctx context.Context, db sqlinternal.DatabaseSQLConn) error { + m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName) + // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time + // The exclusive transaction we already hold ensures there won't be issues with concurrency + _, err := db.ExecContext(ctx, fmt.Sprintf( + `CREATE TABLE IF NOT EXISTS %s ( + key text NOT NULL PRIMARY KEY, + value text NOT NULL + )`, + m.MetadataTableName, + )) + if err != nil { + return fmt.Errorf("failed to create metadata table: %w", err) + } + return nil +} diff --git a/state/cockroachdb/cockroachdb.go b/state/cockroachdb/cockroachdb.go index 4cd6cfa09e..6c7291c2bf 100644 --- a/state/cockroachdb/cockroachdb.go +++ b/state/cockroachdb/cockroachdb.go @@ -18,6 +18,7 @@ import ( "fmt" "github.com/dapr/components-contrib/internal/component/postgresql" + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" ) @@ -68,7 +69,7 @@ WHERE }) } -func ensureTables(ctx context.Context, db postgresql.PGXPoolConn, opts postgresql.MigrateOptions) error { +func ensureTables(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgresql.MigrateOptions) error { exists, err := tableExists(ctx, db, opts.StateTableName) if err != nil { return err @@ -122,7 +123,7 @@ func ensureTables(ctx context.Context, db postgresql.PGXPoolConn, opts postgresq return nil } -func tableExists(ctx context.Context, db postgresql.PGXPoolConn, tableName string) (bool, error) { +func tableExists(ctx context.Context, db pginterfaces.PGXPoolConn, tableName string) (bool, error) { exists := false err := db.QueryRow(ctx, "SELECT EXISTS (SELECT * FROM pg_tables where tablename = $1)", tableName).Scan(&exists) return exists, err diff --git a/state/cockroachdb/cockroachdb_integration_test.go b/state/cockroachdb/cockroachdb_integration_test.go index 7fb82365d5..3c08e2f745 100644 --- a/state/cockroachdb/cockroachdb_integration_test.go +++ b/state/cockroachdb/cockroachdb_integration_test.go @@ -72,10 +72,7 @@ func TestCockroachDBIntegration(t *testing.T) { t.Run("Create table succeeds", func(t *testing.T) { t.Parallel() - dbAccess, ok := pgs.GetDBAccess().(*postgresql.PostgresDBAccess) - assert.True(t, ok) - - testCreateTable(t, dbAccess) + testCreateTable(t, pgs.GetDB()) }) t.Run("Get Set Delete one item", func(t *testing.T) { @@ -180,20 +177,17 @@ func setGetUpdateDeleteOneItem(t *testing.T, pgs *postgresql.PostgreSQL) { } // testCreateTable tests the ability to create the state table. -func testCreateTable(t *testing.T, dba *postgresql.PostgresDBAccess) { +func testCreateTable(t *testing.T, db *pgxpool.Pool) { t.Helper() + const tableName = "test_state" ctx := context.Background() - tableName := "test_state" - - db := dba.GetDB() - // Drop the table if it already exists. exists, err := tableExists(ctx, db, tableName) require.NoError(t, err) if exists { - dropTable(t, db, tableName) + dropTable(t, db, tableName) } // Create the state table and test for its existence.
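For context, this is roughly how a component wires the reworked garbage collector after this change. A minimal sketch, assuming a pgx pool and placeholder `metadata`/`state` table names and placeholder SQL (the real components build their queries from configured table names):

package example // hypothetical

import (
	"time"

	"github.com/jackc/pgx/v5/pgxpool"

	internalsql "github.com/dapr/components-contrib/internal/component/sql"
	"github.com/dapr/kit/logger"
)

func scheduleGC(pool *pgxpool.Pool, log logger.Logger) (internalsql.GarbageCollector, error) {
	return internalsql.ScheduleGarbageCollector(internalsql.GCOptions{
		Logger: log,
		// The query and its bound argument now travel together, so each driver
		// can wrap the argument as needed (e.g. sql.Named for SQL Server).
		UpdateLastCleanupQuery: func(arg any) (string, any) {
			// Placeholder query: update 'last-cleanup' only if enough time has passed.
			return `UPDATE metadata
				SET value = CURRENT_TIMESTAMP
				WHERE key = 'last-cleanup'
					AND (EXTRACT(EPOCH FROM CURRENT_TIMESTAMP - value::timestamptz) * 1000) > $1`, arg
		},
		// Placeholder query: delete all expired rows.
		DeleteExpiredValuesQuery: `DELETE FROM state WHERE expiredate IS NOT NULL AND expiredate <= CURRENT_TIMESTAMP`,
		CleanupInterval:          time.Hour,
		// The single DB field replaces the old DBPgx/DBSql pair; the adapters
		// normalize both drivers behind DatabaseConn.
		DB: internalsql.AdaptPgxConn(pool),
	})
}

A database/sql-based component would pass internalsql.AdaptDatabaseSQLConn(db) instead; nothing else in the wiring changes.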
diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 4554a5415e..7773a97d3b 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -27,7 +27,7 @@ import ( "github.com/google/uuid" - sqlCleanup "github.com/dapr/components-contrib/internal/component/sql" + internalsql "github.com/dapr/components-contrib/internal/component/sql" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state/utils" @@ -89,7 +89,7 @@ type MySQL struct { logger logger.Logger factory iMySQLFactory - gc sqlCleanup.GarbageCollector + gc internalsql.GarbageCollector } type mySQLMetadata struct { @@ -274,19 +274,21 @@ func (m *MySQL) finishInit(ctx context.Context, db *sql.DB) error { } if m.cleanupInterval != nil { - gc, err := sqlCleanup.ScheduleGarbageCollector(sqlCleanup.GCOptions{ + gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{ Logger: m.logger, - UpdateLastCleanupQuery: fmt.Sprintf(`INSERT INTO %[1]s (id, value) - VALUES ('last-cleanup', CURRENT_TIMESTAMP) - ON DUPLICATE KEY UPDATE - value = IF(CURRENT_TIMESTAMP > DATE_ADD(value, INTERVAL ?*1000 MICROSECOND), CURRENT_TIMESTAMP, value)`, - m.metadataTableName), + UpdateLastCleanupQuery: func(arg any) (string, any) { + return fmt.Sprintf(`INSERT INTO %[1]s (id, value) + VALUES ('last-cleanup', CURRENT_TIMESTAMP) + ON DUPLICATE KEY UPDATE + value = IF(CURRENT_TIMESTAMP > DATE_ADD(value, INTERVAL ?*1000 MICROSECOND), CURRENT_TIMESTAMP, value)`, + m.metadataTableName), arg + }, DeleteExpiredValuesQuery: fmt.Sprintf( `DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate <= CURRENT_TIMESTAMP`, m.tableName, ), CleanupInterval: *m.cleanupInterval, - DBSql: m.db, + DB: internalsql.AdaptDatabaseSQLConn(m.db), }) if err != nil { return err diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index af5cbd31e8..f2dc0515f4 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -159,7 +159,7 @@ func TestClosingDatabaseTwiceReturnsNil(t *testing.T) { assert.Nil(t, err, "error returned") } -func TestExecuteMultiCannotBeginTransaction(t *testing.T) { +func TestMultiCannotBeginTransaction(t *testing.T) { // Arrange m, _ := mockDatabase(t) defer m.mySQL.Close() @@ -174,7 +174,7 @@ func TestExecuteMultiCannotBeginTransaction(t *testing.T) { assert.Equal(t, "beginError", err.Error(), "wrong error returned") } -func TestExecuteMultiCommitSetsAndDeletes(t *testing.T) { +func TestMultiCommitSetsAndDeletes(t *testing.T) { // Arrange m, _ := mockDatabase(t) defer m.mySQL.Close() diff --git a/state/postgresql/migrations.go b/state/postgresql/migrations.go index 0fef56acc4..ca9ed58bef 100644 --- a/state/postgresql/migrations.go +++ b/state/postgresql/migrations.go @@ -15,218 +15,59 @@ package postgresql import ( "context" - "errors" "fmt" - "strconv" - "strings" - "time" - - "github.com/jackc/pgx/v5" "github.com/dapr/components-contrib/internal/component/postgresql" - "github.com/dapr/kit/logger" + pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces" + sqlinternal "github.com/dapr/components-contrib/internal/component/sql" + pgmigrations "github.com/dapr/components-contrib/internal/component/sql/migrations/postgres" ) -// Performs migrations for the database schema -type migrations struct { - logger logger.Logger - stateTableName string - metadataTableName string -} - -// performMigration the required migrations -func performMigration(ctx context.Context, db postgresql.PGXPoolConn, opts 
postgresql.MigrateOptions) error { - m := &migrations{ - logger: opts.Logger, - stateTableName: opts.StateTableName, - metadataTableName: opts.MetadataTableName, - } - - // Use an advisory lock (with an arbitrary number) to ensure that no one else is performing migrations at the same time - // This is the only way to also ensure we are not running multiple "CREATE TABLE IF NOT EXISTS" at the exact same time - // See: https://www.postgresql.org/message-id/CA+TgmoZAdYVtwBfp1FL2sMZbiHCWT4UPrzRLNnX1Nb30Ku3-gg@mail.gmail.com - const lockID = 42 - - // Long timeout here as this query may block - queryCtx, cancel := context.WithTimeout(ctx, time.Minute) - _, err := db.Exec(queryCtx, "SELECT pg_advisory_lock($1)", lockID) - cancel() - if err != nil { - return fmt.Errorf("faild to acquire advisory lock: %w", err) - } - - // Release the lock - defer func() { - queryCtx, cancel = context.WithTimeout(ctx, time.Minute) - _, err = db.Exec(queryCtx, "SELECT pg_advisory_unlock($1)", lockID) - cancel() - if err != nil { - // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around - m.logger.Fatalf("Failed to release advisory lock: %v", err) - } - }() - - // Check if the metadata table exists, which we also use to store the migration level - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - exists, _, _, err := m.tableExists(queryCtx, db, m.metadataTableName) - cancel() - if err != nil { - return err +// Performs the required migrations +func performMigrations(ctx context.Context, db pginterfaces.PGXPoolConn, opts postgresql.MigrateOptions) error { + m := pgmigrations.Migrations{ + DB: db, + Logger: opts.Logger, + MetadataTableName: opts.MetadataTableName, + MetadataKey: "migrations", } - // If the table doesn't exist, create it - if !exists { - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - err = m.createMetadataTable(queryCtx, db) - cancel() - if err != nil { - return err - } - } - - // Select the migration level - var ( - migrationLevelStr string - migrationLevel int - ) - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - err = db.QueryRow(queryCtx, - fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.metadataTableName), - ).Scan(&migrationLevelStr) - cancel() - if errors.Is(err, pgx.ErrNoRows) { - // If there's no row... 
- migrationLevel = 0 - } else if err != nil { - return fmt.Errorf("failed to read migration level: %w", err) - } else { - migrationLevel, err = strconv.Atoi(migrationLevelStr) - if err != nil || migrationLevel < 0 { - return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr) - } - } - - // Perform the migrations - for i := migrationLevel; i < len(allMigrations); i++ { - m.logger.Infof("Performing migration %d", i) - err = allMigrations[i](ctx, db, m) - if err != nil { - return fmt.Errorf("failed to perform migration %d: %w", i, err) - } - - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - _, err = db.Exec(queryCtx, - fmt.Sprintf(`INSERT INTO %s (key, value) VALUES ('migrations', $1) ON CONFLICT (key) DO UPDATE SET value = $1`, m.metadataTableName), - strconv.Itoa(i+1), - ) - cancel() - if err != nil { - return fmt.Errorf("failed to update migration level in metadata table: %w", err) - } - } - - return nil -} - -func (m migrations) createMetadataTable(ctx context.Context, db postgresql.PGXPoolConn) error { - m.logger.Infof("Creating metadata table '%s'", m.metadataTableName) - // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time - // In the next step we'll acquire a lock so there won't be issues with concurrency - _, err := db.Exec(ctx, fmt.Sprintf( - `CREATE TABLE IF NOT EXISTS %s ( - key text NOT NULL PRIMARY KEY, - value text NOT NULL - )`, - m.metadataTableName, - )) - if err != nil { - return fmt.Errorf("failed to create metadata table: %w", err) - } - return nil -} - -// If the table exists, returns true and the name of the table and schema -func (m migrations) tableExists(ctx context.Context, db postgresql.PGXPoolConn, tableName string) (exists bool, schema string, table string, err error) { - table, schema, err = m.tableSchemaName(tableName) - if err != nil { - return false, "", "", err - } - - if schema == "" { - err = db.QueryRow( - ctx, - `SELECT table_name, table_schema - FROM information_schema.tables - WHERE table_name = $1`, - table, - ). - Scan(&table, &schema) - } else { - err = db.QueryRow( - ctx, - `SELECT table_name, table_schema - FROM information_schema.tables - WHERE table_schema = $1 AND table_name = $2`, - schema, table, - ). - Scan(&table, &schema) - } - - if err != nil && errors.Is(err, pgx.ErrNoRows) { - return false, "", "", nil - } else if err != nil { - return false, "", "", fmt.Errorf("failed to check if table '%s' exists: %w", tableName, err) - } - return true, schema, table, nil -} - -// If the table name includes a schema (e.g. 
`schema.table`, returns the two parts separately) -func (m migrations) tableSchemaName(tableName string) (table string, schema string, err error) { - parts := strings.Split(tableName, ".") - switch len(parts) { - case 1: - return parts[0], "", nil - case 2: - return parts[1], parts[0], nil - default: - return "", "", errors.New("invalid table name: must be in the format 'table' or 'schema.table'") - } -} - -var allMigrations = [2]func(ctx context.Context, db postgresql.PGXPoolConn, m *migrations) error{ - // Migration 0: create the state table - func(ctx context.Context, db postgresql.PGXPoolConn, m *migrations) error { - // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table - m.logger.Infof("Creating state table '%s'", m.stateTableName) - _, err := db.Exec( - ctx, - fmt.Sprintf( - `CREATE TABLE IF NOT EXISTS %s ( - key text NOT NULL PRIMARY KEY, - value jsonb NOT NULL, - isbinary boolean NOT NULL, - insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), - updatedate TIMESTAMP WITH TIME ZONE NULL - )`, - m.stateTableName, - ), - ) - if err != nil { - return fmt.Errorf("failed to create state table: %w", err) - } - return nil - }, - - // Migration 1: add the "expiredate" column - func(ctx context.Context, db postgresql.PGXPoolConn, m *migrations) error { - m.logger.Infof("Adding expiredate column to state table '%s'", m.stateTableName) - _, err := db.Exec(ctx, fmt.Sprintf( - `ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, - m.stateTableName, - )) - if err != nil { - return fmt.Errorf("failed to update state table: %w", err) - } - return nil + return m.Perform(ctx, []sqlinternal.MigrationFn{ + // Migration 0: create the state table + func(ctx context.Context) error { + // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table + opts.Logger.Infof("Creating state table '%s'", opts.StateTableName) + _, err := db.Exec( + ctx, + fmt.Sprintf( + `CREATE TABLE IF NOT EXISTS %s ( + key text NOT NULL PRIMARY KEY, + value jsonb NOT NULL, + isbinary boolean NOT NULL, + insertdate TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + updatedate TIMESTAMP WITH TIME ZONE NULL + )`, + opts.StateTableName, + ), + ) + if err != nil { + return fmt.Errorf("failed to create state table: %w", err) + } + return nil + }, + + // Migration 1: add the "expiredate" column + func(ctx context.Context) error { + opts.Logger.Infof("Adding expiredate column to state table '%s'", opts.StateTableName) + _, err := db.Exec(ctx, fmt.Sprintf( + `ALTER TABLE %s ADD expiredate TIMESTAMP WITH TIME ZONE`, + opts.StateTableName, + )) + if err != nil { + return fmt.Errorf("failed to update state table: %w", err) + } + return nil + }, }, + ) } diff --git a/state/postgresql/postgresql.go b/state/postgresql/postgresql.go index 84563242ff..0d533ab8bf 100644 --- a/state/postgresql/postgresql.go +++ b/state/postgresql/postgresql.go @@ -24,7 +24,7 @@ func NewPostgreSQLStateStore(logger logger.Logger) state.Store { return postgresql.NewPostgreSQLStateStore(logger, postgresql.Options{ ETagColumn: "xmin", EnableAzureAD: true, - MigrateFn: performMigration, + MigrateFn: performMigrations, SetQueryFn: func(req *state.SetRequest, opts postgresql.SetQueryOptions) string { // Sprintf is required for table name because the driver does not substitute parameters for table names. 
if !req.HasETag() { diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index d1b0b852ac..e05b5c5032 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -97,26 +97,25 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error { } // Performs migrations - migrate := &migrations{ - Logger: a.logger, - Conn: a.db, - MetadataTableName: a.metadata.MetadataTableName, + err = performMigrations(ctx, a.db, a.logger, migrationOptions{ StateTableName: a.metadata.TableName, - } - err = migrate.Perform(ctx) + MetadataTableName: a.metadata.MetadataTableName, + }) if err != nil { return fmt.Errorf("failed to perform migrations: %w", err) } gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{ Logger: a.logger, - UpdateLastCleanupQuery: fmt.Sprintf(`INSERT INTO %s (key, value) - VALUES ('last-cleanup', CURRENT_TIMESTAMP) - ON CONFLICT (key) - DO UPDATE SET value = CURRENT_TIMESTAMP - WHERE (unixepoch(CURRENT_TIMESTAMP) - unixepoch(value)) * 1000 > ?;`, - a.metadata.MetadataTableName, - ), + UpdateLastCleanupQuery: func(arg any) (string, any) { + return fmt.Sprintf(`INSERT INTO %s (key, value) + VALUES ('last-cleanup', CURRENT_TIMESTAMP) + ON CONFLICT (key) + DO UPDATE SET value = CURRENT_TIMESTAMP + WHERE (unixepoch(CURRENT_TIMESTAMP) - unixepoch(value)) * 1000 > ?;`, + a.metadata.MetadataTableName, + ), arg + }, DeleteExpiredValuesQuery: fmt.Sprintf(`DELETE FROM %s WHERE expiration_time IS NOT NULL @@ -124,7 +123,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error { a.metadata.TableName, ), CleanupInterval: a.metadata.CleanupInterval, - DBSql: a.db, + DB: internalsql.AdaptDatabaseSQLConn(a.db), }) if err != nil { return err diff --git a/state/sqlite/sqlite_migrations.go b/state/sqlite/sqlite_migrations.go index ef0e10a7e7..b0b2b76ede 100644 --- a/state/sqlite/sqlite_migrations.go +++ b/state/sqlite/sqlite_migrations.go @@ -16,175 +16,50 @@ package sqlite import ( "context" "database/sql" - "errors" "fmt" - "strconv" - "time" + sqlinternal "github.com/dapr/components-contrib/internal/component/sql" + sqlitemigrations "github.com/dapr/components-contrib/internal/component/sql/migrations/sqlite" "github.com/dapr/kit/logger" ) -// Performs migrations for the database schema -type migrations struct { - Logger logger.Logger - Conn *sql.DB +type migrationOptions struct { StateTableName string MetadataTableName string } // Perform the required migrations -func (m *migrations) Perform(ctx context.Context) error { - // Begin an exclusive transaction - // We can't use Begin because that doesn't allow us setting the level of transaction - queryCtx, cancel := context.WithTimeout(ctx, time.Minute) - _, err := m.Conn.ExecContext(queryCtx, "BEGIN EXCLUSIVE TRANSACTION") - cancel() - if err != nil { - return fmt.Errorf("faild to begin transaction: %w", err) +func performMigrations(ctx context.Context, db *sql.DB, logger logger.Logger, opts migrationOptions) error { + m := sqlitemigrations.Migrations{ + Pool: db, + Logger: logger, + MetadataTableName: opts.MetadataTableName, + MetadataKey: "migrations", } - // Rollback the transaction in a deferred statement to catch errors - success := false - defer func() { - if success { - return - } - queryCtx, cancel = context.WithTimeout(ctx, time.Minute) - _, err = m.Conn.ExecContext(queryCtx, "ROLLBACK TRANSACTION") - cancel() - if err != nil { - // Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around 
- m.Logger.Fatalf("Failed to rollback transaction: %v", err) - } - }() - - // Check if the metadata table exists, which we also use to store the migration level - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - exists, err := m.tableExists(queryCtx, m.Conn, m.MetadataTableName) - cancel() - if err != nil { - return fmt.Errorf("failed to check if the metadata table exists: %w", err) - } - - // If the table doesn't exist, create it - if !exists { - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - err = m.createMetadataTable(queryCtx, m.Conn) - cancel() - if err != nil { - return fmt.Errorf("failed to create metadata table: %w", err) - } - } - - // Select the migration level - var ( - migrationLevelStr string - migrationLevel int - ) - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - err = m.Conn.QueryRowContext(queryCtx, - fmt.Sprintf(`SELECT value FROM %s WHERE key = 'migrations'`, m.MetadataTableName), - ).Scan(&migrationLevelStr) - cancel() - if errors.Is(err, sql.ErrNoRows) { - // If there's no row... - migrationLevel = 0 - } else if err != nil { - return fmt.Errorf("failed to read migration level: %w", err) - } else { - migrationLevel, err = strconv.Atoi(migrationLevelStr) - if err != nil || migrationLevel < 0 { - return fmt.Errorf("invalid migration level found in metadata table: %s", migrationLevelStr) - } - } - - // Perform the migrations - for i := migrationLevel; i < len(allMigrations); i++ { - m.Logger.Infof("Performing migration %d", i+1) - err = allMigrations[i](ctx, m.Conn, m) - if err != nil { - return fmt.Errorf("failed to perform migration %d: %w", i+1, err) - } - - queryCtx, cancel = context.WithTimeout(ctx, 30*time.Second) - _, err = m.Conn.ExecContext(queryCtx, - fmt.Sprintf(`REPLACE INTO %s (key, value) VALUES ('migrations', ?)`, m.MetadataTableName), - strconv.Itoa(i+1), - ) - cancel() - if err != nil { - return fmt.Errorf("failed to update migration level in metadata table: %w", err) - } - } - - // Commit the transaction - queryCtx, cancel = context.WithTimeout(ctx, time.Minute) - _, err = m.Conn.ExecContext(queryCtx, "COMMIT TRANSACTION") - cancel() - if err != nil { - return fmt.Errorf("failed to commit transaction") - } - - // Set success to true so we don't also run a rollback - success = true - - return nil -} - -// Returns true if a table exists -func (m migrations) tableExists(parentCtx context.Context, db querier, tableName string) (bool, error) { - ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second) - defer cancel() - - var exists string - // Returns 1 or 0 as a string if the table exists or not. - const q = `SELECT EXISTS ( - SELECT name FROM sqlite_master WHERE type='table' AND name = ? - ) AS 'exists'` - err := db.QueryRowContext(ctx, q, m.MetadataTableName). 
- Scan(&exists) - return exists == "1", err -} - -func (m migrations) createMetadataTable(ctx context.Context, db querier) error { - m.Logger.Infof("Creating metadata table '%s' if it doesn't exist", m.MetadataTableName) - // Add an "IF NOT EXISTS" in case another Dapr sidecar is creating the same table at the same time - // In the next step we'll acquire a lock so there won't be issues with concurrency - _, err := db.ExecContext(ctx, fmt.Sprintf( - `CREATE TABLE IF NOT EXISTS %s ( - key text NOT NULL PRIMARY KEY, - value text NOT NULL - )`, - m.MetadataTableName, - )) - if err != nil { - return fmt.Errorf("failed to create metadata table: %w", err) - } - return nil -} - -var allMigrations = [1]func(ctx context.Context, db querier, m *migrations) error{ - // Migration 0: create the state table - func(ctx context.Context, db querier, m *migrations) error { - // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table - m.Logger.Infof("Creating state table '%s'", m.StateTableName) - _, err := db.ExecContext( - ctx, - fmt.Sprintf( - `CREATE TABLE IF NOT EXISTS %s ( - key TEXT NOT NULL PRIMARY KEY, - value TEXT NOT NULL, - is_binary BOOLEAN NOT NULL, - etag TEXT NOT NULL, - expiration_time TIMESTAMP DEFAULT NULL, - update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP - )`, - m.StateTableName, - ), - ) - if err != nil { - return fmt.Errorf("failed to create state table: %w", err) - } - return nil - }, + return m.Perform(ctx, []sqlinternal.MigrationFn{ + // Migration 0: create the state table + func(ctx context.Context) error { + // We need to add an "IF NOT EXISTS" because we may be migrating from when we did not use a metadata table + logger.Infof("Creating state table '%s'", opts.StateTableName) + _, err := m.GetConn().ExecContext( + ctx, + fmt.Sprintf( + `CREATE TABLE IF NOT EXISTS %s ( + key TEXT NOT NULL PRIMARY KEY, + value TEXT NOT NULL, + is_binary BOOLEAN NOT NULL, + etag TEXT NOT NULL, + expiration_time TIMESTAMP DEFAULT NULL, + update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )`, + opts.StateTableName, + ), + ) + if err != nil { + return fmt.Errorf("failed to create state table: %w", err) + } + return nil + }, + }) } diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 77623d6152..1cce506040 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -143,21 +143,22 @@ func (s *SQLServer) Init(ctx context.Context, metadata state.Metadata) error { func (s *SQLServer) startGC() error { gc, err := internalsql.ScheduleGarbageCollector(internalsql.GCOptions{ Logger: s.logger, - UpdateLastCleanupQuery: fmt.Sprintf(`BEGIN TRANSACTION; + UpdateLastCleanupQuery: func(arg any) (string, any) { + return fmt.Sprintf(`BEGIN TRANSACTION; BEGIN TRY INSERT INTO [%[1]s].[%[2]s] ([Key], [Value]) VALUES ('last-cleanup', CONVERT(nvarchar(MAX), GETDATE(), 21)); END TRY BEGIN CATCH UPDATE [%[1]s].[%[2]s] SET [Value] = CONVERT(nvarchar(MAX), GETDATE(), 21) WHERE [Key] = 'last-cleanup' AND Datediff_big(MS, [Value], GETUTCDATE()) > @Interval END CATCH -COMMIT TRANSACTION;`, s.metadata.Schema, s.metadata.MetadataTableName), - UpdateLastCleanupQueryParameterName: "Interval", +COMMIT TRANSACTION;`, s.metadata.Schema, s.metadata.MetadataTableName), sql.Named("Interval", arg) + }, DeleteExpiredValuesQuery: fmt.Sprintf( `DELETE FROM [%s].[%s] WHERE [ExpireDate] IS NOT NULL AND [ExpireDate] < GETDATE()`, s.metadata.Schema, s.metadata.TableName, ), CleanupInterval: *s.metadata.CleanupInterval, - DBSql: s.db, + 
DB: internalsql.AdaptDatabaseSQLConn(s.db), }) if err != nil { return err diff --git a/tests/certification/state/cockroachdb/cockroachdb_test.go b/tests/certification/state/cockroachdb/cockroachdb_test.go index ab85d2c3c4..31f67ec8a3 100644 --- a/tests/certification/state/cockroachdb/cockroachdb_test.go +++ b/tests/certification/state/cockroachdb/cockroachdb_test.go @@ -284,10 +284,7 @@ func TestCockroach(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*postgresql.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() require.NotNil(t, cleanupInterval) assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval) }) @@ -301,10 +298,7 @@ func TestCockroach(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*postgresql.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() require.NotNil(t, cleanupInterval) assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval) }) @@ -318,10 +312,7 @@ func TestCockroach(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*postgresql.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() _ = assert.Nil(t, cleanupInterval) }) }) @@ -408,9 +399,7 @@ func TestCockroach(t *testing.T) { require.NotEmpty(t, lastCleanupValueOrig) // Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s - dbAccess := storeObj.GetDBAccess().(*postgresql.PostgresDBAccess) - require.NotNil(t, dbAccess) - err = dbAccess.CleanupExpired() + err = storeObj.CleanupExpired() require.NoError(t, err, "CleanupExpired returned an error") // Validate that 20 records are still present diff --git a/tests/certification/state/postgresql/postgresql_test.go b/tests/certification/state/postgresql/postgresql_test.go index c0d8345cdf..d62cae85cd 100644 --- a/tests/certification/state/postgresql/postgresql_test.go +++ b/tests/certification/state/postgresql/postgresql_test.go @@ -498,10 +498,7 @@ func TestPostgreSQL(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() _ = assert.NotNil(t, cleanupInterval) && assert.Equal(t, time.Duration(1*time.Hour), *cleanupInterval) }) @@ -515,10 +512,7 @@ func TestPostgreSQL(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() _ = assert.NotNil(t, cleanupInterval) && assert.Equal(t, time.Duration(10*time.Second), *cleanupInterval) }) @@ -532,10 +526,7 @@ func TestPostgreSQL(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess) - require.NotNil(t, dbAccess) - - cleanupInterval := dbAccess.GetCleanupInterval() + cleanupInterval := storeObj.GetCleanupInterval() _ = assert.Nil(t, 
cleanupInterval) }) @@ -599,9 +590,6 @@ func TestPostgreSQL(t *testing.T) { require.NoError(t, err, "failed to init") defer storeObj.Close() - dbAccess := storeObj.GetDBAccess().(*state_postgres.PostgresDBAccess) - require.NotNil(t, dbAccess) - // Seed the database with some records err = populateTTLRecords(ctx, dbClient) require.NoError(t, err, "failed to seed records") @@ -624,7 +612,7 @@ func TestPostgreSQL(t *testing.T) { require.NotEmpty(t, lastCleanupValueOrig) // Trigger the background cleanup, which should do nothing because the last cleanup was < 3600s - err = dbAccess.CleanupExpired() + err = storeObj.CleanupExpired() require.NoError(t, err, "CleanupExpired returned an error") // Validate that 20 records are still present
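To illustrate the resulting migrations API end to end: a minimal sketch of how a hypothetical Postgres-backed component would run its migrations through the shared helpers introduced above (the `metadata` table name, `example_state` table, and the migration body are placeholders, not code from this change):

package example // hypothetical

import (
	"context"
	"fmt"

	pginterfaces "github.com/dapr/components-contrib/internal/component/postgresql/interfaces"
	sqlinternal "github.com/dapr/components-contrib/internal/component/sql"
	pgmigrations "github.com/dapr/components-contrib/internal/component/sql/migrations/postgres"
	"github.com/dapr/kit/logger"
)

func migrate(ctx context.Context, db pginterfaces.PGXPoolConn, log logger.Logger) error {
	m := pgmigrations.Migrations{
		DB:                db,
		Logger:            log,
		MetadataTableName: "metadata", // placeholder
		MetadataKey:       "migrations",
	}
	// Each function runs at most once; Perform records the reached level in the
	// metadata table under MetadataKey, so later calls skip applied migrations.
	return m.Perform(ctx, []sqlinternal.MigrationFn{
		// Migration 0: create the (placeholder) state table
		func(ctx context.Context) error {
			_, err := db.Exec(ctx, `CREATE TABLE IF NOT EXISTS example_state (
				key text NOT NULL PRIMARY KEY,
				value text NOT NULL
			)`)
			if err != nil {
				return fmt.Errorf("failed to create example_state table: %w", err)
			}
			return nil
		},
	})
}

The SQLite equivalent follows the same shape with sqlitemigrations.Migrations, which wraps the migrations in an exclusive transaction instead of a Postgres advisory lock.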