diff --git a/mocks/authmocks/plugin.go b/mocks/authmocks/plugin.go index 4ff0804..85e7814 100644 --- a/mocks/authmocks/plugin.go +++ b/mocks/authmocks/plugin.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package authmocks diff --git a/mocks/crudmocks/crud.go b/mocks/crudmocks/crud.go index 5961966..3352a41 100644 --- a/mocks/crudmocks/crud.go +++ b/mocks/crudmocks/crud.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package crudmocks diff --git a/mocks/dbmigratemocks/driver.go b/mocks/dbmigratemocks/driver.go index a61c374..d5a25f2 100644 --- a/mocks/dbmigratemocks/driver.go +++ b/mocks/dbmigratemocks/driver.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package dbmigratemocks diff --git a/mocks/httpservermocks/go_http_server.go b/mocks/httpservermocks/go_http_server.go index f4a73df..31e13d3 100644 --- a/mocks/httpservermocks/go_http_server.go +++ b/mocks/httpservermocks/go_http_server.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package httpservermocks diff --git a/mocks/wsservermocks/protocol.go b/mocks/wsservermocks/protocol.go index 10863ae..999545d 100644 --- a/mocks/wsservermocks/protocol.go +++ b/mocks/wsservermocks/protocol.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package wsservermocks diff --git a/mocks/wsservermocks/web_socket_server.go b/mocks/wsservermocks/web_socket_server.go index 45c6562..a1183c6 100644 --- a/mocks/wsservermocks/web_socket_server.go +++ b/mocks/wsservermocks/web_socket_server.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.40.1. DO NOT EDIT. +// Code generated by mockery v2.40.2. DO NOT EDIT. package wsservermocks diff --git a/pkg/config/cobracmd.go b/pkg/config/cobracmd.go index fa2d81d..595373e 100644 --- a/pkg/config/cobracmd.go +++ b/pkg/config/cobracmd.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -27,7 +27,7 @@ func ShowConfigCommand(initConf func() error) *cobra.Command { Use: "showconfig", Aliases: []string{"showconf"}, Short: "List out the configuration options", - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { if err := initConf(); err != nil { return err } diff --git a/pkg/dbsql/crud.go b/pkg/dbsql/crud.go index 2c951d3..2e885d5 100644 --- a/pkg/dbsql/crud.go +++ b/pkg/dbsql/crud.go @@ -19,6 +19,7 @@ package dbsql import ( "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "strings" @@ -50,6 +51,7 @@ const ( UpsertOptimizationSkip UpsertOptimization = iota UpsertOptimizationNew UpsertOptimizationExisting + UpsertOptimizationDB // only supported if the DB layer support ON CONFLICT semantics ) type GetOption int @@ -129,6 +131,7 @@ type CrudBase[T Resource] struct { TimesDisabled bool // no management of the time columns PatchDisabled bool // allows non-pointer fields, but prevents UpdateSparse function ImmutableColumns []string + IDField string // override default ID field NameField string // If supporting name semantics QueryFactory ffapi.QueryFactory // Must be set when name is set DefaultSort func() []interface{} // optionally override the default sort - array of *ffapi.SortField or string @@ -144,6 +147,7 @@ type CrudBase[T Resource] struct { ReadTableAlias string ReadOnlyColumns []string ReadQueryModifier QueryModifier + AfterLoad func(ctx context.Context, inst T) error // perform final validation/formatting after an instance is loaded from db } func (c *CrudBase[T]) Scoped(scope sq.Eq) CRUD[T] { @@ -159,6 +163,13 @@ func (c *CrudBase[T]) TableAlias() string { return c.Table } +func (c *CrudBase[T]) GetIDField() string { + if c.IDField != "" { + return c.IDField + } + return ColumnID +} + func (c *CrudBase[T]) GetQueryFactory() ffapi.QueryFactory { return c.QueryFactory } @@ -212,7 +223,7 @@ func (c *CrudBase[T]) Validate() { ptrs := map[string]interface{}{} fieldMap := map[string]bool{ // Mandatory column checks - ColumnID: false, + c.GetIDField(): false, } if !c.TimesDisabled { fieldMap[ColumnCreated] = false @@ -260,25 +271,30 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq { filter = sq.Eq{} } if c.ReadTableAlias != "" { - filter[fmt.Sprintf("%s.id", c.ReadTableAlias)] = id + filter[fmt.Sprintf("%s.%s", c.ReadTableAlias, c.GetIDField())] = id } else { filter["id"] = id } return filter } +func (c *CrudBase[T]) isImmutable(col string) bool { + for _, immutable := range append(c.ImmutableColumns, c.GetIDField(), ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) { + if col == immutable { + return true + } + } + return false +} + func (c *CrudBase[T]) buildUpdateList(_ context.Context, update sq.UpdateBuilder, inst T, includeNil bool) sq.UpdateBuilder { -colLoop: for _, col := range c.Columns { - for _, immutable := range append(c.ImmutableColumns, ColumnID, ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) { - if col == immutable { - continue colLoop + if !c.isImmutable(col) { + value := c.getFieldValue(inst, col) + if includeNil || !isNil(value) { + update = update.Set(col, value) } } - value := c.getFieldValue(inst, col) - if includeNil || !isNil(value) { - update = update.Set(col, value) - } } if !c.TimesDisabled { update = update.Set(ColumnUpdated, fftypes.Now()) @@ -323,9 +339,8 @@ func (c *CrudBase[T]) getFieldValue(inst T, col string) interface{} { return val } -func (c *CrudBase[T]) setInsertTimestamps(inst T) { +func (c *CrudBase[T]) setInsertTimestamps(inst T, now *fftypes.FFTime) { if !c.TimesDisabled { - now := fftypes.Now() inst.SetCreated(now) inst.SetUpdated(now) } @@ -344,7 +359,7 @@ func (c *CrudBase[T]) attemptInsert(ctx context.Context, tx *TXWrapper, inst T, } } - c.setInsertTimestamps(inst) + c.setInsertTimestamps(inst, fftypes.Now()) insert := sq.Insert(c.Table).Columns(c.Columns...) values := make([]interface{}, len(c.Columns)) for i, col := range c.Columns { @@ -374,11 +389,18 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt // The expectation is that the optimization will hit almost all of the time, // as only recovery paths require us to go down the un-optimized route. optimized := false - if optimization == UpsertOptimizationNew { + switch { + case optimization == UpsertOptimizationDB && c.DB.features.DBOptimizedUpsertBuilder != nil: + optimized = true // the DB does the work here, so any failure is a straight failure + created, err = c.dbOptimizedUpsert(ctx, tx, inst) + if err != nil { + return false, err + } + case optimization == UpsertOptimizationNew: opErr := c.attemptInsert(ctx, tx, inst, true /* we want a failure here we can progress past */) optimized = opErr == nil created = optimized - } else if optimization == UpsertOptimizationExisting { + default: // UpsertOptimizationExisting, or fallback if DB optimization requested rowsAffected, opErr := c.updateFromInstance(ctx, tx, inst, true /* full replace */) optimized = opErr == nil && rowsAffected == 1 } @@ -417,6 +439,47 @@ func (c *CrudBase[T]) Upsert(ctx context.Context, inst T, optimization UpsertOpt return created, c.DB.CommitTx(ctx, tx, autoCommit) } +func (c *CrudBase[T]) dbOptimizedUpsert(ctx context.Context, tx *TXWrapper, inst T) (created bool, err error) { + + // Caller responsible for checking this is available before driving this path + optimizedInsertBuilder := c.DB.provider.Features().DBOptimizedUpsertBuilder + + if c.IDValidator != nil { + if err := c.IDValidator(ctx, inst.GetID()); err != nil { + return false, err + } + } + now := fftypes.Now() + c.setInsertTimestamps(inst, now) + + values := make(map[string]driver.Value) + updateCols := make([]string, 0, len(c.Columns)) + for _, col := range c.Columns { + values[col] = c.getFieldValue(inst, col) + if !c.isImmutable(col) { + updateCols = append(updateCols, col) + } + } + var rows *sql.Rows + query, err := optimizedInsertBuilder(ctx, c.Table, c.GetIDField(), c.Columns, updateCols, ColumnCreated, values) + if err == nil { + rows, _, err = c.DB.RunAsQueryTx(ctx, c.Table, tx, query.PlaceholderFormat(c.DB.features.PlaceholderFormat)) + } + if err != nil { + return false, err + } + defer rows.Close() + if rows.Next() { + var createTime fftypes.FFTime + if err = rows.Scan(&createTime); err != nil { + return false, i18n.NewError(ctx, i18n.MsgDBReadInsertTSFailed, err) + } + created = !createTime.Time().Before(*now.Time()) + } + return created, nil + +} + func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartialSuccess bool, hooks ...PostCompletionHook) (err error) { ctx, tx, autoCommit, err := c.DB.BeginOrUseTx(ctx) @@ -427,7 +490,7 @@ func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartia if c.DB.Features().MultiRowInsert { insert := sq.Insert(c.Table).Columns(c.Columns...) for _, inst := range instances { - c.setInsertTimestamps(inst) + c.setInsertTimestamps(inst, fftypes.Now()) values := make([]interface{}, len(c.Columns)) for i, col := range c.Columns { values[i] = c.getFieldValue(inst, col) @@ -626,6 +689,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOpti if err != nil { return c.NilValue(), err } + if c.AfterLoad != nil { + return inst, c.AfterLoad(ctx, inst) + } return inst, nil } @@ -728,6 +794,12 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f if err != nil { return nil, nil, err } + if c.AfterLoad != nil { + err = c.AfterLoad(ctx, inst) + if err != nil { + return nil, nil, err + } + } instances = append(instances, inst) } log.L(ctx).Debugf("SQL<- GetMany(%s): %d", c.Table, len(instances)) diff --git a/pkg/dbsql/crud_test.go b/pkg/dbsql/crud_test.go index a07998f..43ed423 100644 --- a/pkg/dbsql/crud_test.go +++ b/pkg/dbsql/crud_test.go @@ -23,6 +23,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/Masterminds/squirrel" @@ -377,10 +378,19 @@ func TestCRUDWithDBEnd2End(t *testing.T) { assert.Equal(t, Created, collection.events[0]) collection.events = nil + // Install an AfterLoad handler + afterLoadCalled := false + collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error { + afterLoadCalled = true + return nil + } + // Check we get it back c1copy, err := iCrud.GetByID(ctx, c1.ID.String()) assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) + assert.True(t, afterLoadCalled) + collection.AfterLoad = nil // Check we get it back by name c1copy, err = iCrud.GetByName(ctx, *c1.Name) @@ -412,6 +422,14 @@ func TestCRUDWithDBEnd2End(t *testing.T) { assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) + // Check AfterLoad error behavior + collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error { + return fmt.Errorf("pop") + } + _, _, err = iCrud.GetMany(ctx, CRUDableQueryFactory.NewFilter(ctx).And()) + assert.EqualError(t, err, "pop") + collection.AfterLoad = nil + // Upsert the existing row optimized c1copy.Field1 = ptrTo("hello again - 1") created, err := iCrud.Upsert(ctx, c1copy, UpsertOptimizationExisting) @@ -759,6 +777,96 @@ func TestUpsertFailUpdate(t *testing.T) { assert.NoError(t, mock.ExpectationsWereMet()) } +func TestUpsertPSQLOptimizedCreated(t *testing.T) { + after := (fftypes.FFTime)(fftypes.Now().Time().Add(1 * time.Hour)) + mp := NewMockProvider() + mp.FakePSQLUpsertOptimization = true + db, mock := mp.UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + mock.ExpectBegin() + mock.ExpectQuery("INSERT INTO crudables.*ON CONFLICT .* DO UPDATE SET.*RETURNING created").WillReturnRows( + sqlmock.NewRows([]string{"created"}).AddRow(after.String()), + ) + mock.ExpectCommit() + created, err := tc.Upsert(context.Background(), &TestCRUDable{ + ResourceBase: ResourceBase{ + ID: fftypes.NewUUID(), + }, + }, UpsertOptimizationDB) + assert.NoError(t, err) + assert.True(t, created) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpsertPSQLOptimizedUpdated(t *testing.T) { + before := fftypes.Now() + mp := NewMockProvider() + mp.FakePSQLUpsertOptimization = true + db, mock := mp.UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + mock.ExpectBegin() + mock.ExpectQuery("INSERT INTO crudables.*ON CONFLICT .* DO UPDATE SET.*RETURNING created").WillReturnRows( + sqlmock.NewRows([]string{"created"}).AddRow(before.String()), + ) + mock.ExpectCommit() + created, err := tc.Upsert(context.Background(), &TestCRUDable{ + ResourceBase: ResourceBase{ + ID: fftypes.NewUUID(), + }, + }, UpsertOptimizationDB) + assert.NoError(t, err) + assert.False(t, created) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpsertPSQLOptimizedBadID(t *testing.T) { + mp := NewMockProvider() + mp.FakePSQLUpsertOptimization = true + db, mock := mp.UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + mock.ExpectBegin() + mock.ExpectRollback() + _, err := tc.Upsert(context.Background(), &TestCRUDable{}, UpsertOptimizationDB) + assert.Regexp(t, "FF00138", err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpsertPSQLOptimizedQueryFail(t *testing.T) { + mp := NewMockProvider() + mp.FakePSQLUpsertOptimization = true + db, mock := mp.UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + mock.ExpectBegin() + mock.ExpectQuery("INSERT INTO crudables.*ON CONFLICT .* DO UPDATE SET.*RETURNING created").WillReturnError(fmt.Errorf("pop")) + mock.ExpectRollback() + _, err := tc.Upsert(context.Background(), &TestCRUDable{ + ResourceBase: ResourceBase{ + ID: fftypes.NewUUID(), + }, + }, UpsertOptimizationDB) + assert.Regexp(t, "FF00176.*pop", err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpsertPSQLOptimizedBadTimeReturn(t *testing.T) { + mp := NewMockProvider() + mp.FakePSQLUpsertOptimization = true + db, mock := mp.UTInit() + tc := newCRUDCollection(&db.Database, "ns1") + mock.ExpectBegin() + mock.ExpectQuery("INSERT INTO crudables.*ON CONFLICT .* DO UPDATE SET.*RETURNING created").WillReturnRows( + sqlmock.NewRows([]string{"created"}).AddRow("!!!this is not a time!!!"), + ) + mock.ExpectRollback() + _, err := tc.Upsert(context.Background(), &TestCRUDable{ + ResourceBase: ResourceBase{ + ID: fftypes.NewUUID(), + }, + }, UpsertOptimizationDB) + assert.Regexp(t, "FF00248.*FF00136", err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestInsertManyBeginFail(t *testing.T) { db, mock := NewMockProvider().UTInit() tc := newCRUDCollection(&db.Database, "ns1") @@ -1294,3 +1402,24 @@ func TestValidateNameSemanticsWithoutQueryFactory(t *testing.T) { tc.Validate() }) } + +func TestCustomIDColumn(t *testing.T) { + db, _ := NewMockProvider().UTInit() + tc := &CrudBase[*TestCRUDable]{ + DB: &db.Database, + NewInstance: func() *TestCRUDable { return &TestCRUDable{} }, + NilValue: func() *TestCRUDable { return nil }, + IDField: "f1", + Columns: []string{"f1"}, + TimesDisabled: true, + PatchDisabled: true, + GetFieldPtr: func(inst *TestCRUDable, col string) interface{} { + if col == "id" { + var t *string + return &t + } + return nil + }, + } + tc.Validate() +} diff --git a/pkg/dbsql/database.go b/pkg/dbsql/database.go index 549832f..df0434d 100644 --- a/pkg/dbsql/database.go +++ b/pkg/dbsql/database.go @@ -34,6 +34,8 @@ import ( _ "github.com/golang-migrate/migrate/v4/source/file" ) +type QueryModifier = func(sq.SelectBuilder) (sq.SelectBuilder, error) + type Database struct { db *sql.DB provider Provider @@ -42,8 +44,6 @@ type Database struct { sequenceColumn string } -type QueryModifier = func(sq.SelectBuilder) (sq.SelectBuilder, error) - // PreCommitAccumulator is a structure that can accumulate state during // the transaction, then has a function that is called just before commit. type PreCommitAccumulator interface { @@ -202,12 +202,17 @@ func (s *Database) QueryTx(ctx context.Context, table string, tx *TXWrapper, q s // in the read operations (read after insert for example). tx = GetTXFromContext(ctx) } + return s.RunAsQueryTx(ctx, table, tx, q.PlaceholderFormat(s.features.PlaceholderFormat)) +} + +func (s *Database) RunAsQueryTx(ctx context.Context, table string, tx *TXWrapper, q sq.Sqlizer) (*sql.Rows, *TXWrapper, error) { l := log.L(ctx) - sqlQuery, args, err := q.PlaceholderFormat(s.features.PlaceholderFormat).ToSql() + sqlQuery, args, err := q.ToSql() if err != nil { return nil, tx, i18n.WrapError(ctx, err, i18n.MsgDBQueryBuildFailed) } + before := time.Now() l.Tracef(`SQL-> query: %s (args: %+v)`, sqlQuery, args) var rows *sql.Rows diff --git a/pkg/dbsql/filter_sql.go b/pkg/dbsql/filter_sql.go index ca80d50..1730675 100644 --- a/pkg/dbsql/filter_sql.go +++ b/pkg/dbsql/filter_sql.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -206,7 +206,7 @@ func (s *Database) mapFieldName(tableName, fieldName string, tm map[string]strin field = mf } } - if tableName != "" { + if tableName != "" && !strings.Contains(field, ".") { field = fmt.Sprintf("%s.%s", tableName, field) } return field diff --git a/pkg/dbsql/mock_provider.go b/pkg/dbsql/mock_provider.go index b01c75a..1aeacdd 100644 --- a/pkg/dbsql/mock_provider.go +++ b/pkg/dbsql/mock_provider.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -41,11 +41,12 @@ type MockProvider struct { } type MockProviderConfig struct { - FakePSQLInsert bool - OpenError error - GetMigrationDriverError error - IndividualSort bool - MultiRowInsert bool + FakePSQLInsert bool + OpenError error + GetMigrationDriverError error + IndividualSort bool + MultiRowInsert bool + FakePSQLUpsertOptimization bool } func NewMockProvider() *MockProvider { @@ -87,6 +88,9 @@ func (mp *MockProvider) Features() SQLFeatures { return fmt.Sprintf(``, lockName) } features.MultiRowInsert = mp.MultiRowInsert + if mp.FakePSQLUpsertOptimization { + features.DBOptimizedUpsertBuilder = BuildPostgreSQLOptimizedUpsert + } return features } diff --git a/pkg/dbsql/postgres_helpers.go b/pkg/dbsql/postgres_helpers.go new file mode 100644 index 0000000..e9085ee --- /dev/null +++ b/pkg/dbsql/postgres_helpers.go @@ -0,0 +1,47 @@ +// Copyright © 2024 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbsql + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" + + sq "github.com/Masterminds/squirrel" + "github.com/hyperledger/firefly-common/pkg/i18n" +) + +// PostgreSQL helper to avoid implementing this lots of times in child packages +func BuildPostgreSQLOptimizedUpsert(ctx context.Context, table string, idColumn string, insertCols, updateCols []string, returnCol string, values map[string]driver.Value) (insert sq.InsertBuilder, err error) { + insertValues := make([]interface{}, 0, len(insertCols)) + for _, c := range insertCols { + insertValues = append(insertValues, values[c]) + } + insert = sq.Insert(table).Columns(insertCols...).Values(insertValues...) + update := sq.Update("REMOVED_BEFORE_RUNNING" /* cheat to avoid table name */) + for _, c := range updateCols { + update = update.Set(c, values[c]) + } + updateSQL, updateValues, err := update.ToSql() + updateSQL, ok := strings.CutPrefix(updateSQL, "UPDATE REMOVED_BEFORE_RUNNING ") + if err != nil || !ok { + return insert, i18n.NewError(ctx, i18n.MsgDBErrorBuildingStatement, err) + } + return insert.Suffix(fmt.Sprintf("ON CONFLICT (%s) DO UPDATE", idColumn)).SuffixExpr(sq.Expr(updateSQL, updateValues...)).Suffix("RETURNING " + returnCol), nil + +} diff --git a/pkg/dbsql/postgres_helpers_test.go b/pkg/dbsql/postgres_helpers_test.go new file mode 100644 index 0000000..0e84c15 --- /dev/null +++ b/pkg/dbsql/postgres_helpers_test.go @@ -0,0 +1,62 @@ +// Copyright © 2024 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dbsql + +import ( + "context" + "database/sql/driver" + "testing" + + "github.com/hyperledger/firefly-common/pkg/fftypes" + "github.com/stretchr/testify/assert" +) + +func TestBuildPostgreSQLOptimizedUpsert(t *testing.T) { + + now := fftypes.Now() + q, err := BuildPostgreSQLOptimizedUpsert(context.Background(), "table1", "id", []string{ + "created", + "updated", + "mutable_col", + "immutable_col", + }, []string{ + "updated", + "mutable_col", + }, "created", map[string]driver.Value{ + "created": now, + "updated": now, + "mutable_col": "value1", + "immutable_col": "value2", + }) + assert.NoError(t, err) + + queryStr, values, err := q.ToSql() + assert.NoError(t, err) + assert.Equal(t, "INSERT INTO table1 (created,updated,mutable_col,immutable_col) VALUES (?,?,?,?) ON CONFLICT (id) DO UPDATE SET updated = ?, mutable_col = ? RETURNING created", queryStr) + assert.Equal(t, []interface{}{ + now, now, "value1", "value2", + now, "value1", + }, values) + +} + +func TestBuildPostgreSQLOptimizedUpsertFail(t *testing.T) { + + _, err := BuildPostgreSQLOptimizedUpsert(context.Background(), "", "", []string{}, []string{}, "", map[string]driver.Value{}) + assert.Regexp(t, "FF00247", err) + +} diff --git a/pkg/dbsql/provider.go b/pkg/dbsql/provider.go index 758b175..3598f8f 100644 --- a/pkg/dbsql/provider.go +++ b/pkg/dbsql/provider.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -17,7 +17,9 @@ package dbsql import ( + "context" "database/sql" + "database/sql/driver" sq "github.com/Masterminds/squirrel" migratedb "github.com/golang-migrate/migrate/v4/database" @@ -28,6 +30,9 @@ type SQLFeatures struct { MultiRowInsert bool PlaceholderFormat sq.PlaceholderFormat AcquireLock func(lockName string) string + // DB specific query builder for RDBMS-side optimized upsert, returning the requested column from the query + // (the CRUD layer will request the create time column to detect if the record was new or not) + DBOptimizedUpsertBuilder func(ctx context.Context, table string, idColumn string, insertCols, updateCols []string, returnCol string, values map[string]driver.Value) (sq.InsertBuilder, error) } func DefaultSQLProviderFeatures() SQLFeatures { diff --git a/pkg/dbsql/provider_sqlitego.go b/pkg/dbsql/provider_sqlitego.go index 19deec0..4f91381 100644 --- a/pkg/dbsql/provider_sqlitego.go +++ b/pkg/dbsql/provider_sqlitego.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // diff --git a/pkg/eventstreams/activestream.go b/pkg/eventstreams/activestream.go index 3bedad3..060078b 100644 --- a/pkg/eventstreams/activestream.go +++ b/pkg/eventstreams/activestream.go @@ -72,7 +72,7 @@ func (as *activeStream[CT, DT]) runEventLoop() { checkpointSequenceID, err := as.loadCheckpoint() if err == nil { // Run the inner source read loop until it exits - err = as.retry.Do(as.ctx, "source run loop", func(attempt int) (retry bool, err error) { + err = as.retry.Do(as.ctx, "source run loop", func(_ int) (retry bool, err error) { if err = as.runSourceLoop(checkpointSequenceID); err != nil { log.L(as.ctx).Errorf("source loop error: %s", err) return true, err @@ -88,7 +88,7 @@ func (as *activeStream[CT, DT]) runEventLoop() { } func (as *activeStream[CT, DT]) loadCheckpoint() (sequencedID string, err error) { - err = as.retry.Do(as.ctx, "load checkpoint", func(attempt int) (retry bool, err error) { + err = as.retry.Do(as.ctx, "load checkpoint", func(_ int) (retry bool, err error) { log.L(as.ctx).Debugf("Loading checkpoint: %s", as.spec.GetID()) cp, err := as.persistence.Checkpoints().GetByID(as.ctx, as.spec.GetID()) if err != nil { @@ -238,7 +238,7 @@ func (as *activeStream[CT, DT]) checkpointRoutine() { if checkpointSequenceID == "" { return // We're done } - err := as.retry.Do(as.ctx, "checkpoint", func(attempt int) (retry bool, err error) { + err := as.retry.Do(as.ctx, "checkpoint", func(_ int) (retry bool, err error) { log.L(as.bgCtx).Debugf("Writing checkpoint id=%s sequenceID=%s", as.spec.GetID(), checkpointSequenceID) _, err = as.esm.persistence.Checkpoints().Upsert(as.ctx, &EventStreamCheckpoint{ ID: ptrTo(as.spec.GetID()), // the ID of the stream is the ID of the checkpoint diff --git a/pkg/ffapi/apiserver.go b/pkg/ffapi/apiserver.go index 2d43f02..d1d639a 100644 --- a/pkg/ffapi/apiserver.go +++ b/pkg/ffapi/apiserver.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "io" "net" "net/http" "time" @@ -86,6 +87,7 @@ type APIServerOptions[T any] struct { type APIServerRouteExt[T any] struct { JSONHandler func(*APIRequest, T) (output interface{}, err error) UploadHandler func(*APIRequest, T) (output interface{}, err error) + StreamHandler func(*APIRequest, T) (output io.ReadCloser, err error) } // NewAPIServer makes a new server, with the specified configuration, and @@ -201,13 +203,25 @@ func (as *apiServer[T]) routeHandler(hf *HandlerFactory, route *Route) http.Hand // We extend the base ffapi functionality, with standardized DB filter support for all core resources. // We also pass the Orchestrator context through ext := route.Extensions.(*APIServerRouteExt[T]) - route.JSONHandler = func(r *APIRequest) (output interface{}, err error) { - er, err := as.EnrichRequest(r) - if err != nil { - return nil, err + switch { + case ext.StreamHandler != nil: + route.StreamHandler = func(r *APIRequest) (output io.ReadCloser, err error) { + er, err := as.EnrichRequest(r) + if err != nil { + return nil, err + } + return ext.StreamHandler(r, er) + } + case ext.JSONHandler != nil: + route.JSONHandler = func(r *APIRequest) (output interface{}, err error) { + er, err := as.EnrichRequest(r) + if err != nil { + return nil, err + } + return ext.JSONHandler(r, er) } - return ext.JSONHandler(r, er) } + return hf.RouteHandler(route) } @@ -247,7 +261,7 @@ func (as *apiServer[T]) createMuxRouter(ctx context.Context) *mux.Router { return ce.UploadHandler(r, er) } } - if ce.JSONHandler != nil || ce.UploadHandler != nil { + if ce.JSONHandler != nil || ce.UploadHandler != nil || ce.StreamHandler != nil { r.HandleFunc(fmt.Sprintf("/api/v1/%s", route.Path), as.routeHandler(hf, route)). Methods(route.Method) } diff --git a/pkg/ffapi/apiserver_test.go b/pkg/ffapi/apiserver_test.go index 8abad5d..f91b5bd 100644 --- a/pkg/ffapi/apiserver_test.go +++ b/pkg/ffapi/apiserver_test.go @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "github.com/getkin/kin-openapi/openapi3" "io" "net/http" "strings" @@ -38,6 +39,7 @@ type utManager struct { mockEnrichErr error calledJSONHandler string calledUploadHandler string + calledStreamHandler string } type sampleInput struct { @@ -80,6 +82,35 @@ var utAPIRoute1 = &Route{ }, } +var utAPIRoute2 = &Route{ + Name: "utAPIRoute2", + Path: "ut/utresource/{resourceid}/getit", + Method: http.MethodGet, + Description: "random GET stream route for testing", + PathParams: []*PathParam{ + {Name: "resourceid", Description: "My resource"}, + }, + FormParams: nil, + JSONInputValue: nil, + JSONOutputValue: nil, + JSONOutputCodes: nil, + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "application/octet-stream": {}, + }, + }, + }, + }, + Extensions: &APIServerRouteExt[*utManager]{ + StreamHandler: func(r *APIRequest, um *utManager) (output io.ReadCloser, err error) { + um.calledStreamHandler = r.PP["resourceid"] + return io.NopCloser(strings.NewReader("a stream!")), nil + }, + }, +} + func initUTConfig() (config.Section, config.Section, config.Section) { config.RootConfigReset() apiConfig := config.RootSection("ut.api") @@ -97,7 +128,7 @@ func newTestAPIServer(t *testing.T, start bool) (*utManager, *apiServer[*utManag um := &utManager{t: t} as := NewAPIServer(ctx, APIServerOptions[*utManager]{ MetricsRegistry: metric.NewPrometheusMetricsRegistry("ut"), - Routes: []*Route{utAPIRoute1}, + Routes: []*Route{utAPIRoute1, utAPIRoute2}, EnrichRequest: func(r *APIRequest) (*utManager, error) { // This could be some dynamic object based on extra processing in the request, // but the most common case is you just have a "manager" that you inject into each @@ -125,6 +156,24 @@ func newTestAPIServer(t *testing.T, start bool) (*utManager, *apiServer[*utManag } } +func TestAPIServerInvokeAPIRouteStream(t *testing.T) { + um, as, done := newTestAPIServer(t, true) + defer done() + + <-as.Started() + + var o sampleOutput + res, err := resty.New().R(). + SetBody(nil). + SetResult(&o). + Get(fmt.Sprintf("%s/api/v1/ut/utresource/id12345/getit", as.APIPublicURL())) + assert.NoError(t, err) + assert.Equal(t, 200, res.StatusCode()) + assert.Equal(t, "application/octet-stream", res.Header().Get("Content-Type")) + assert.Equal(t, "id12345", um.calledStreamHandler) + assert.Equal(t, "a stream!", string(res.Body())) +} + func TestAPIServerInvokeAPIRouteJSON(t *testing.T) { um, as, done := newTestAPIServer(t, true) defer done() diff --git a/pkg/ffapi/handler.go b/pkg/ffapi/handler.go index 4a7229b..40e5487 100644 --- a/pkg/ffapi/handler.go +++ b/pkg/ffapi/handler.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -189,7 +189,7 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc { } } - var status = 400 // if fail parsing input + status := 400 // if fail parsing input var output interface{} if err == nil { queryParams, pathParams, queryArrayParams = hs.getParams(req, route) @@ -202,24 +202,29 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc { if err == nil { r := &APIRequest{ - Req: req, - PP: pathParams, - QP: queryParams, - QAP: queryArrayParams, - Filter: filter, - Input: jsonInput, - SuccessStatus: http.StatusOK, + Req: req, + PP: pathParams, + QP: queryParams, + QAP: queryArrayParams, + Filter: filter, + Input: jsonInput, + SuccessStatus: http.StatusOK, + AlwaysPaginate: hs.AlwaysPaginate, + + // res.Header() returns a map which is a ref type so handler header edits are persisted ResponseHeaders: res.Header(), - AlwaysPaginate: hs.AlwaysPaginate, } if len(route.JSONOutputCodes) > 0 { r.SuccessStatus = route.JSONOutputCodes[0] } - if multipart != nil { + switch { + case multipart != nil: r.FP = multipart.formParams r.Part = multipart.part output, err = route.FormUploadHandler(r) - } else { + case route.StreamHandler != nil: + output, err = route.StreamHandler(r) + default: output, err = route.JSONHandler(r) } status = r.SuccessStatus // Can be updated by the route @@ -259,7 +264,9 @@ func (hs *HandlerFactory) handleOutput(ctx context.Context, res http.ResponseWri res.WriteHeader(204) case reader != nil: defer reader.Close() - res.Header().Add("Content-Type", "application/octet-stream") + if res.Header().Get("Content-Type") == "" { + res.Header().Add("Content-Type", "application/octet-stream") + } res.WriteHeader(status) _, marshalErr = io.Copy(res, reader) default: diff --git a/pkg/ffapi/handler_test.go b/pkg/ffapi/handler_test.go index 3f8375b..b8ffd21 100644 --- a/pkg/ffapi/handler_test.go +++ b/pkg/ffapi/handler_test.go @@ -21,6 +21,8 @@ import ( "context" "encoding/json" "fmt" + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" "io" "mime/multipart" "net/http" @@ -156,6 +158,72 @@ func TestJSONHTTPNilResponseNon204(t *testing.T) { assert.Regexp(t, "FF00164", resJSON["error"]) } +func TestStreamHttpResponsePlainText200(t *testing.T) { + text := ` +some stream +of +text +!!! +` + s, _, done := newTestServer(t, []*Route{{ + Name: "testRoute", + Path: "/test", + Method: "GET", + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "text/plain": {}, + }, + }, + }, + }, + StreamHandler: func(r *APIRequest) (output io.ReadCloser, err error) { + r.ResponseHeaders.Add("Content-Type", "text/plain") + return io.NopCloser(strings.NewReader(text)), nil + }, + }}, "", nil) + defer done() + + res, err := http.Get(fmt.Sprintf("http://%s/test", s.Addr())) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "text/plain", res.Header.Get("Content-Type")) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, text, string(b)) +} + +func TestStreamHttpResponseBinary200(t *testing.T) { + randomBytes := []byte{3, 255, 192, 201, 33, 50} + s, _, done := newTestServer(t, []*Route{{ + Name: "testRoute", + Path: "/test", + Method: "GET", + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "application/octet-stream": &openapi3.MediaType{}, + }, + }, + }, + }, + StreamHandler: func(r *APIRequest) (output io.ReadCloser, err error) { + return io.NopCloser(bytes.NewReader(randomBytes)), nil + }, + }}, "", nil) + defer done() + + res, err := http.Get(fmt.Sprintf("http://%s/test", s.Addr())) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "application/octet-stream", res.Header.Get("Content-Type")) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, randomBytes, b) +} + func TestJSONHTTPDefault500Error(t *testing.T) { s, _, done := newTestServer(t, []*Route{{ Name: "testRoute", diff --git a/pkg/ffapi/openapi3.go b/pkg/ffapi/openapi3.go index c67868b..262d179 100644 --- a/pkg/ffapi/openapi3.go +++ b/pkg/ffapi/openapi3.go @@ -97,13 +97,13 @@ func (sg *SwaggerGen) Generate(ctx context.Context, routes []*Route) *openapi3.T Schemas: make(openapi3.Schemas), }, } - opIds := make(map[string]bool) + opIDs := make(map[string]bool) for _, route := range routes { - if route.Name == "" || opIds[route.Name] { + if route.Name == "" || opIDs[route.Name] { log.Panicf("Duplicate/invalid name (used as operation ID in swagger): %s", route.Name) } sg.addRoute(ctx, doc, route) - opIds[route.Name] = true + opIDs[route.Name] = true } return doc } @@ -313,6 +313,12 @@ func (sg *SwaggerGen) addOutput(ctx context.Context, doc *openapi3.T, route *Rou }, }) } + for code, res := range route.CustomResponseRefs { + if res.Value != nil && res.Value.Description == nil { + res.Value.Description = &s + } + op.Responses.Set(code, res) + } } func (sg *SwaggerGen) AddParam(ctx context.Context, op *openapi3.Operation, in, name, def, example string, description i18n.MessageKey, deprecated bool, msgArgs ...interface{}) { diff --git a/pkg/ffapi/openapi3_test.go b/pkg/ffapi/openapi3_test.go index ee1eb98..f44911a 100644 --- a/pkg/ffapi/openapi3_test.go +++ b/pkg/ffapi/openapi3_test.go @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "github.com/stretchr/testify/require" "net/http" "testing" @@ -298,6 +299,36 @@ func TestFFExcludeTag(t *testing.T) { assert.Regexp(t, "no schema", err) } +func TestCustomResponseRefs(t *testing.T) { + routes := []*Route{ + { + Name: "CustomResponseRefTest", + Path: "/test", + Method: http.MethodGet, + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "text/plain": &openapi3.MediaType{}, + }, + }, + }, + }, + }, + } + swagger := NewSwaggerGen(&SwaggerGenOptions{ + Title: "UnitTest", + Version: "1.0", + BaseURL: "http://localhost:12345/api/v1", + }).Generate(context.Background(), routes) + assert.Nil(t, swagger.Paths.Find("/test").Get.RequestBody) + require.NotEmpty(t, swagger.Paths.Find("/test").Get.Responses) + require.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200")) + require.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200").Value) + assert.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200").Value.Content.Get("text/plain")) + assert.Nil(t, swagger.Paths.Find("/test").Get.Responses.Value("201")) +} + func TestPanicOnMissingDescription(t *testing.T) { routes := []*Route{ { diff --git a/pkg/ffapi/routes.go b/pkg/ffapi/routes.go index 98b9203..12578ef 100644 --- a/pkg/ffapi/routes.go +++ b/pkg/ffapi/routes.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,6 +18,7 @@ package ffapi import ( "context" + "io" "github.com/getkin/kin-openapi/openapi3" "github.com/hyperledger/firefly-common/pkg/config" @@ -61,12 +62,16 @@ type Route struct { JSONOutputSchema func(ctx context.Context, schemaGen SchemaGenerator) (*openapi3.SchemaRef, error) // JSONOutputValue is a function that returns a pointer to a structure to take JSON output JSONOutputValue func() interface{} - // JSONOutputCodes is the success response code + // JSONOutputCodes is the success response codes that could be returned by the API. Error codes are explicitly not supported by the framework since they could be subject to change by the errors thrown or how errors are handled. JSONOutputCodes []int - // JSONHandler is a function for handling JSON content type input. Input/Ouptut objects are returned by JSONInputValue/JSONOutputValue funcs + // JSONHandler is a function for handling JSON content type input. Input/Output objects are returned by JSONInputValue/JSONOutputValue funcs JSONHandler func(r *APIRequest) (output interface{}, err error) // FormUploadHandler takes a single file upload, and returns a JSON object FormUploadHandler func(r *APIRequest) (output interface{}, err error) + // StreamHandler allows for custom request handling with explicit stream (io.ReadCloser) responses + StreamHandler func(r *APIRequest) (output io.ReadCloser, err error) + // CustomResponseRefs allows for specifying custom responses for a route + CustomResponseRefs map[string]*openapi3.ResponseRef // Deprecated whether this route is deprecated Deprecated bool // Tag a category identifier for this route in the generated OpenAPI spec diff --git a/pkg/ffresty/ffresty.go b/pkg/ffresty/ffresty.go index 656f889..db717e4 100644 --- a/pkg/ffresty/ffresty.go +++ b/pkg/ffresty/ffresty.go @@ -33,6 +33,7 @@ import ( "github.com/hyperledger/firefly-common/pkg/ffapi" "github.com/hyperledger/firefly-common/pkg/fftypes" "github.com/hyperledger/firefly-common/pkg/i18n" + "github.com/hyperledger/firefly-common/pkg/metric" "github.com/hyperledger/firefly-common/pkg/log" "github.com/sirupsen/logrus" ) @@ -50,6 +51,13 @@ type Config struct { HTTPConfig } +var ( + metricsManager metric.MetricsManager + onErrorHooks []func(*resty.Request, error) + onSuccessHooks []func(*resty.Client, *resty.Response) +) + + // HTTPConfig is all the optional configuration separate to the URL you wish to invoke. // This is JSON serializable with docs, so you can embed it into API objects. type HTTPConfig struct { @@ -77,6 +85,43 @@ type HTTPConfig struct { OnBeforeRequest func(req *resty.Request) error `json:"-"` // called before each request, even retry } +func EnableClientMetrics(ctx context.Context, metricsRegistry metric.MetricsRegistry) error { + //create a metrics manager + mm, err := metricsRegistry.NewMetricsManagerForSubsystem(ctx, "ffresty") + if err != nil { + return err + } + + metricsManager := mm + metricsManager.NewCounterMetricWithLabels(ctx, "http_response", "HTTP response", []string{"status","error"}, false) + metricsManager.NewCounterMetricWithLabels(ctx, "network_error", "Network error", []string{}, false) + + //create hooks + onErrorMetricsHook := func(req *resty.Request, err error){ + if v, ok := err.(*resty.ResponseError); ok { + code := v.Response.StatusCode + metricsManager.IncCounterMetricWithLabels(ctx, "http_response", map[string]string{"status": fmt.Sprintf("%d",code), "error": "true"}, nil) + } + metricsManager.IncCounterMetricWithLabels(ctx, "network_error", map[string]string{}, nil) + } + RegisterGlobalOnError(onErrorMetricsHook) + + onSuccessMetricsHook := func(c *resty.Client, resp *resty.Response){ + code := resp.StatusCode + metricsManager.IncCounterMetricWithLabels(ctx, "http_response", map[string]string{"status": fmt.Sprintf("%d",code), "error": "false"}, nil) + } + RegisterGlobalOnSuccess(onSuccessMetricsHook) + return nil +} + +func RegisterGlobalOnError(onError func(req *resty.Request, err error)) { + onErrorHooks = append(onErrorHooks, onError) +} + +func RegisterGlobalOnSuccess(onSuccess func(c *resty.Client, resp *resty.Response)) { + onSuccessHooks = append(onSuccessHooks, onSuccess) +} + // OnAfterResponse when using SetDoNotParseResponse(true) for streaming binary replies, // the caller should invoke ffresty.OnAfterResponse on the response manually. // The middleware is disabled on this path :-( @@ -96,6 +141,19 @@ func OnAfterResponse(c *resty.Client, resp *resty.Response) { log.L(rCtx).Logf(level, "<== %s %s [%d] (%.2fms)", resp.Request.Method, resp.Request.URL, status, elapsed) } + +func OnError(req *resty.Request, err error) { + for _, hook := range onErrorHooks { + hook(req,err) + } +} + +func OnSuccess(c *resty.Client, resp *resty.Response) { + for _, hook := range onSuccessHooks { + hook(c,resp) + } +} + // New creates a new Resty client, using static configuration (from the config file) // from a given section in the static configuration // @@ -160,7 +218,7 @@ func NewWithConfig(ctx context.Context, ffrestyConfig Config) (client *resty.Cli client.SetTimeout(time.Duration(ffrestyConfig.HTTPRequestTimeout)) - client.OnBeforeRequest(func(c *resty.Client, req *resty.Request) error { + client.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error { rCtx := req.Context() rc := rCtx.Value(retryCtxKey{}) if rc == nil { @@ -171,8 +229,7 @@ func NewWithConfig(ctx context.Context, ffrestyConfig Config) (client *resty.Cli } rCtx = context.WithValue(rCtx, retryCtxKey{}, r) // Create a request logger from the root logger passed into the client - l := log.L(ctx).WithField("breq", r.id) - rCtx = log.WithLogger(rCtx, l) + rCtx = log.WithLogField(rCtx, "breq", r.id) req.SetContext(rCtx) } @@ -205,9 +262,11 @@ func NewWithConfig(ctx context.Context, ffrestyConfig Config) (client *resty.Cli }) // Note that callers using SetNotParseResponse will need to invoke this themselves - client.OnAfterResponse(func(c *resty.Client, r *resty.Response) error { OnAfterResponse(c, r); return nil }) + client.OnError( func(req *resty.Request, e error) { OnError(req, e); return }) + client.OnSuccess(func(c *resty.Client, r *resty.Response) { OnSuccess(c, r); return }) + for k, v := range ffrestyConfig.HTTPHeaders { if vs, ok := v.(string); ok { client.SetHeader(k, vs) diff --git a/pkg/fftls/config.go b/pkg/fftls/config.go index dc143ae..9e0dbdb 100644 --- a/pkg/fftls/config.go +++ b/pkg/fftls/config.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -23,14 +23,21 @@ import ( const ( // HTTPConfTLSCAFile the TLS certificate authority file for the HTTP server HTTPConfTLSCAFile = "caFile" + // HTTPConfTLSCA the TLS certificate authority in PEM format, this option is ignored if HTTPConfTLSCAFile is also set + HTTPConfTLSCA = "ca" // HTTPConfTLSCertFile the TLS certificate file for the HTTP server HTTPConfTLSCertFile = "certFile" + // HTTPConfTLSCert the TLS certificate in PEM format, this option is ignored if HTTPConfTLSCertFile is also set + HTTPConfTLSCert = "cert" // HTTPConfTLSClientAuth whether the HTTP server requires a mutual TLS connection HTTPConfTLSClientAuth = "clientAuth" // HTTPConfTLSEnabled whether TLS is enabled for the HTTP server HTTPConfTLSEnabled = "enabled" // HTTPConfTLSKeyFile the private key file for TLS on the server HTTPConfTLSKeyFile = "keyFile" + // HTTPConfTLSKey the TLS certificate key in PEM format, this option is ignored if HTTPConfTLSKeyFile is also set + HTTPConfTLSKey = "key" + // HTTPConfTLSInsecureSkipHostVerify disables host verification - insecure (for dev only) HTTPConfTLSInsecureSkipHostVerify = "insecureSkipHostVerify" @@ -44,8 +51,11 @@ type Config struct { Enabled bool `ffstruct:"tlsconfig" json:"enabled"` ClientAuth bool `ffstruct:"tlsconfig" json:"clientAuth,omitempty"` CAFile string `ffstruct:"tlsconfig" json:"caFile,omitempty"` + CA string `ffstruct:"tlsconfig" json:"ca,omitempty"` CertFile string `ffstruct:"tlsconfig" json:"certFile,omitempty"` + Cert string `ffstruct:"tlsconfig" json:"cert,omitempty"` KeyFile string `ffstruct:"tlsconfig" json:"keyFile,omitempty"` + Key string `ffstruct:"tlsconfig" json:"key,omitempty"` InsecureSkipHostVerify bool `ffstruct:"tlsconfig" json:"insecureSkipHostVerify"` RequiredDNAttributes map[string]interface{} `ffstruct:"tlsconfig" json:"requiredDNAttributes,omitempty"` } @@ -53,9 +63,12 @@ type Config struct { func InitTLSConfig(conf config.Section) { conf.AddKnownKey(HTTPConfTLSEnabled, defaultHTTPTLSEnabled) conf.AddKnownKey(HTTPConfTLSCAFile) + conf.AddKnownKey(HTTPConfTLSCA) conf.AddKnownKey(HTTPConfTLSClientAuth) conf.AddKnownKey(HTTPConfTLSCertFile) + conf.AddKnownKey(HTTPConfTLSCert) conf.AddKnownKey(HTTPConfTLSKeyFile) + conf.AddKnownKey(HTTPConfTLSKey) conf.AddKnownKey(HTTPConfTLSRequiredDNAttributes) conf.AddKnownKey(HTTPConfTLSInsecureSkipHostVerify) } @@ -65,8 +78,11 @@ func GenerateConfig(conf config.Section) *Config { Enabled: conf.GetBool(HTTPConfTLSEnabled), ClientAuth: conf.GetBool(HTTPConfTLSClientAuth), CAFile: conf.GetString(HTTPConfTLSCAFile), + CA: conf.GetString(HTTPConfTLSCA), CertFile: conf.GetString(HTTPConfTLSCertFile), + Cert: conf.GetString(HTTPConfTLSCert), KeyFile: conf.GetString(HTTPConfTLSKeyFile), + Key: conf.GetString(HTTPConfTLSKey), InsecureSkipHostVerify: conf.GetBool(HTTPConfTLSInsecureSkipHostVerify), RequiredDNAttributes: conf.GetObject(HTTPConfTLSRequiredDNAttributes), } diff --git a/pkg/fftls/fftls.go b/pkg/fftls/fftls.go index 61280d6..70f9bc1 100644 --- a/pkg/fftls/fftls.go +++ b/pkg/fftls/fftls.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -48,7 +48,7 @@ func NewTLSConfig(ctx context.Context, config *Config, tlsType TLSType) (*tls.Co tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + VerifyPeerCertificate: func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { if len(verifiedChains) > 0 && len(verifiedChains[0]) > 0 { cert := verifiedChains[0][0] log.L(ctx).Debugf("Client certificate provided Subject=%s Issuer=%s Expiry=%s", cert.Subject, cert.Issuer, cert.NotAfter) @@ -62,7 +62,8 @@ func NewTLSConfig(ctx context.Context, config *Config, tlsType TLSType) (*tls.Co var err error // Support custom CA file var rootCAs *x509.CertPool - if config.CAFile != "" { + switch { + case config.CAFile != "": rootCAs = x509.NewCertPool() var caBytes []byte caBytes, err = os.ReadFile(config.CAFile) @@ -72,7 +73,13 @@ func NewTLSConfig(ctx context.Context, config *Config, tlsType TLSType) (*tls.Co err = i18n.NewError(ctx, i18n.MsgInvalidCAFile) } } - } else { + case config.CA != "": + rootCAs = x509.NewCertPool() + ok := rootCAs.AppendCertsFromPEM([]byte(config.CA)) + if !ok { + err = i18n.NewError(ctx, i18n.MsgInvalidCAFile) + } + default: rootCAs, err = x509.SystemCertPool() } @@ -89,7 +96,12 @@ func NewTLSConfig(ctx context.Context, config *Config, tlsType TLSType) (*tls.Co if err != nil { return nil, i18n.WrapError(ctx, err, i18n.MsgInvalidKeyPairFiles) } - + tlsConfig.Certificates = []tls.Certificate{cert} + } else if config.Cert != "" && config.Key != "" { + cert, err := tls.X509KeyPair([]byte(config.Cert), []byte(config.Key)) + if err != nil { + return nil, i18n.WrapError(ctx, err, i18n.MsgInvalidKeyPairFiles) + } tlsConfig.Certificates = []tls.Certificate{cert} } @@ -174,7 +186,7 @@ func buildDNValidator(ctx context.Context, requiredDNAttributes map[string]inter } validators[attr] = validator } - return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { if len(verifiedChains) == 0 { log.L(ctx).Errorf("Failed TLS DN check: Nil cert chain") return i18n.NewError(ctx, i18n.MsgInvalidTLSDnChain) diff --git a/pkg/fftls/fftls_test.go b/pkg/fftls/fftls_test.go index 5d41dc5..c218bca 100644 --- a/pkg/fftls/fftls_test.go +++ b/pkg/fftls/fftls_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -24,9 +24,11 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "github.com/stretchr/testify/require" "math/big" "net" "os" + "strings" "testing" "time" @@ -35,6 +37,34 @@ import ( ) func buildSelfSignedTLSKeyPair(t *testing.T, subject pkix.Name) (string, string) { + // Create an X509 certificate pair + privatekey, _ := rsa.GenerateKey(rand.Reader, 2048) + publickey := &privatekey.PublicKey + var privateKeyBytes []byte = x509.MarshalPKCS1PrivateKey(privatekey) + privateKeyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes} + privateKeyPEM := &strings.Builder{} + err := pem.Encode(privateKeyPEM, privateKeyBlock) + require.NoError(t, err) + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + x509Template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: time.Now(), + NotAfter: time.Now().Add(100 * time.Second), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + } + require.NoError(t, err) + derBytes, err := x509.CreateCertificate(rand.Reader, x509Template, x509Template, publickey, privatekey) + require.NoError(t, err) + publicKeyPEM := &strings.Builder{} + err = pem.Encode(publicKeyPEM, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + require.NoError(t, err) + return publicKeyPEM.String(), privateKeyPEM.String() +} + +func buildSelfSignedTLSKeyPairFiles(t *testing.T, subject pkix.Name) (string, string) { // Create an X509 certificate pair privatekey, _ := rsa.GenerateKey(rand.Reader, 2048) publickey := &privatekey.PublicKey @@ -129,7 +159,7 @@ func TestTLSDefault(t *testing.T) { func TestErrInvalidCAFile(t *testing.T) { config.RootConfigReset() - _, notTheCAFileTheKey := buildSelfSignedTLSKeyPair(t, pkix.Name{ + _, notTheCAFileTheKey := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) @@ -140,13 +170,28 @@ func TestErrInvalidCAFile(t *testing.T) { _, err := ConstructTLSConfig(context.Background(), conf, ClientType) assert.Regexp(t, "FF00152", err) +} + +func TestErrInvalidCA(t *testing.T) { + + config.RootConfigReset() + _, notTheCATheKey := buildSelfSignedTLSKeyPair(t, pkix.Name{ + CommonName: "server.example.com", + }) + + conf := config.RootSection("fftls_server") + InitTLSConfig(conf) + conf.Set(HTTPConfTLSEnabled, true) + conf.Set(HTTPConfTLSCA, notTheCATheKey) + _, err := ConstructTLSConfig(context.Background(), conf, ClientType) + assert.Regexp(t, "FF00152", err) } func TestErrInvalidKeyPairFile(t *testing.T) { config.RootConfigReset() - notTheKeyFile, notTheCertFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + notTheKeyFile, notTheCertFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) @@ -161,12 +206,29 @@ func TestErrInvalidKeyPairFile(t *testing.T) { } -func TestMTLSOk(t *testing.T) { +func TestErrInvalidKeyPair(t *testing.T) { - serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + config.RootConfigReset() + notTheKey, notTheCert := buildSelfSignedTLSKeyPair(t, pkix.Name{ CommonName: "server.example.com", }) - clientPublicKeyFile, clientKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + + conf := config.RootSection("fftls_server") + InitTLSConfig(conf) + conf.Set(HTTPConfTLSEnabled, true) + conf.Set(HTTPConfTLSKey, notTheKey) + conf.Set(HTTPConfTLSCert, notTheCert) + + _, err := ConstructTLSConfig(context.Background(), conf, ClientType) + assert.Regexp(t, "FF00206", err) + +} + +func TestMTLSOk(t *testing.T) { + serverPublicKey, serverKey := buildSelfSignedTLSKeyPair(t, pkix.Name{ + CommonName: "server.example.com", + }) + clientPublicKey, clientKey := buildSelfSignedTLSKeyPair(t, pkix.Name{ CommonName: "client.example.com", }) @@ -175,9 +237,9 @@ func TestMTLSOk(t *testing.T) { serverConf := config.RootSection("fftls_server") InitTLSConfig(serverConf) serverConf.Set(HTTPConfTLSEnabled, true) - serverConf.Set(HTTPConfTLSCAFile, clientPublicKeyFile) - serverConf.Set(HTTPConfTLSCertFile, serverPublicKeyFile) - serverConf.Set(HTTPConfTLSKeyFile, serverKeyFile) + serverConf.Set(HTTPConfTLSCA, clientPublicKey) + serverConf.Set(HTTPConfTLSCert, serverPublicKey) + serverConf.Set(HTTPConfTLSKey, serverKey) serverConf.Set(HTTPConfTLSClientAuth, true) addr, done := buildTLSListener(t, serverConf, ServerType) @@ -186,9 +248,9 @@ func TestMTLSOk(t *testing.T) { clientConf := config.RootSection("fftls_client") InitTLSConfig(clientConf) clientConf.Set(HTTPConfTLSEnabled, true) - clientConf.Set(HTTPConfTLSCAFile, serverPublicKeyFile) - clientConf.Set(HTTPConfTLSCertFile, clientPublicKeyFile) - clientConf.Set(HTTPConfTLSKeyFile, clientKeyFile) + clientConf.Set(HTTPConfTLSCA, serverPublicKey) + clientConf.Set(HTTPConfTLSCert, clientPublicKey) + clientConf.Set(HTTPConfTLSKey, clientKey) tlsConfig, err := ConstructTLSConfig(context.Background(), clientConf, ClientType) assert.NoError(t, err) @@ -208,7 +270,7 @@ func TestMTLSOk(t *testing.T) { func TestMTLSMissingClientCert(t *testing.T) { - serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) @@ -243,10 +305,10 @@ func TestMTLSMissingClientCert(t *testing.T) { func TestMTLSMatchFullSubject(t *testing.T) { - serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) - clientPublicKeyFile, clientKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + clientPublicKeyFile, clientKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "client.example.com", Country: []string{"GB"}, Organization: []string{"hyperledger"}, @@ -306,10 +368,10 @@ func TestMTLSMatchFullSubject(t *testing.T) { func TestMTLSMismatchSubject(t *testing.T) { - serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) - clientPublicKeyFile, clientKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + clientPublicKeyFile, clientKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "wrong.example.com", }) @@ -429,7 +491,7 @@ func TestMTLSDNValidatorEmptyChain(t *testing.T) { func TestConnectSkipVerification(t *testing.T) { - serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPair(t, pkix.Name{ + serverPublicKeyFile, serverKeyFile := buildSelfSignedTLSKeyPairFiles(t, pkix.Name{ CommonName: "server.example.com", }) diff --git a/pkg/fftypes/jsonobject.go b/pkg/fftypes/jsonobject.go index 682da86..4598dbf 100644 --- a/pkg/fftypes/jsonobject.go +++ b/pkg/fftypes/jsonobject.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -133,8 +133,8 @@ func (jd JSONObject) GetObject(key string) JSONObject { } func (jd JSONObject) GetObjectOk(key string) (JSONObject, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { + vInterface, ok := jd[key] + if ok && vInterface != nil { vInterface := jd[key] switch vMap := vInterface.(type) { case map[string]interface{}: @@ -142,7 +142,7 @@ func (jd JSONObject) GetObjectOk(key string) (JSONObject, bool) { case JSONObject: return vMap, true default: - log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterace, key) + log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterface, key) return JSONObject{}, false // Ensures a non-nil return } } @@ -187,11 +187,10 @@ func (jd JSONObject) GetObjectArray(key string) JSONObjectArray { } func (jd JSONObject) GetObjectArrayOk(key string) (JSONObjectArray, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { - return ToJSONObjectArray(vInterace) + vInterface, ok := jd[key] + if ok && vInterface != nil { + return ToJSONObjectArray(vInterface) } - log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterace, key) return JSONObjectArray{}, false // Ensures a non-nil return } @@ -201,11 +200,11 @@ func (jd JSONObject) GetStringArray(key string) []string { } func (jd JSONObject) GetStringArrayOk(key string) ([]string, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { - return ToStringArray(vInterace) + vInterface, ok := jd[key] + if ok && vInterface != nil { + return ToStringArray(vInterface) } - log.L(context.Background()).Errorf("Invalid string array value '%+v' for key '%s'", vInterace, key) + log.L(context.Background()).Errorf("Invalid string array value '%+v' for key '%s'", vInterface, key) return []string{}, false // Ensures a non-nil return } diff --git a/pkg/i18n/en_base_config_descriptions.go b/pkg/i18n/en_base_config_descriptions.go index c4b2e43..770ee70 100644 --- a/pkg/i18n/en_base_config_descriptions.go +++ b/pkg/i18n/en_base_config_descriptions.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -67,10 +67,13 @@ var ( ConfigGlobalWsWriteBufferSize = ffc("config.global.ws.writeBufferSize", "The size in bytes of the write buffer for the WebSocket connection", ByteSizeType) ConfigGlobalWsURL = ffc("config.global.ws.url", "URL to use for WebSocket - overrides url one level up (in the HTTP config)", StringType) + ConfigGlobalTLSCA = ffc("config.global.tls.ca", "The TLS certificate authority in PEM format (this option is ignored if caFile is also set)", StringType) ConfigGlobalTLSCaFile = ffc("config.global.tls.caFile", "The path to the CA file for TLS on this API", StringType) + ConfigGlobalTLSCert = ffc("config.global.tls.cert", "The TLS certificate in PEM format (this option is ignored if certFile is also set)", StringType) ConfigGlobalTLSCertFile = ffc("config.global.tls.certFile", "The path to the certificate file for TLS on this API", StringType) ConfigGlobalTLSClientAuth = ffc("config.global.tls.clientAuth", "Enables or disables client auth for TLS on this API", StringType) ConfigGlobalTLSEnabled = ffc("config.global.tls.enabled", "Enables or disables TLS on this API", BooleanType) + ConfigGlobalTLSKey = ffc("config.global.tls.key", "The TLS certificate key in PEM format (this option is ignored if keyFile is also set)", StringType) ConfigGlobalTLSKeyFile = ffc("config.global.tls.keyFile", "The path to the private key file for TLS on this API", StringType) ConfigGlobalTLSRequiredDNAttributes = ffc("config.global.tls.requiredDNAttributes", "A set of required subject DN attributes. Each entry is a regular expression, and the subject certificate must have a matching attribute of the specified type (CN, C, O, OU, ST, L, STREET, POSTALCODE, SERIALNUMBER are valid attributes)", MapStringStringType) ConfigGlobalTLSInsecureSkipHostVerify = ffc("config.global.tls.insecureSkipHostVerify", "When to true in unit test development environments to disable TLS verification. Use with extreme caution", BooleanType) diff --git a/pkg/i18n/en_base_error_messages.go b/pkg/i18n/en_base_error_messages.go index 2749cfc..6a5219f 100644 --- a/pkg/i18n/en_base_error_messages.go +++ b/pkg/i18n/en_base_error_messages.go @@ -181,4 +181,6 @@ var ( MsgWebSocketBatchInflight = ffe("FF00243", "Stream '%s' already has batch '%d' inflight on websocket connection '%s'") MsgWebSocketRoundTripTimeout = ffe("FF00244", "Timed out or cancelled waiting for acknowledgement") MsgDBExecFailed = ffe("FF00245", "Database update failed") + MsgDBErrorBuildingStatement = ffe("FF00247", "Error building statement: %s") + MsgDBReadInsertTSFailed = ffe("FF00248", "Failed to read timestamp from database optimized upsert: %s") ) diff --git a/pkg/version/version.go b/pkg/version/version.go index 000e4e3..8027706 100644 --- a/pkg/version/version.go +++ b/pkg/version/version.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -66,7 +66,7 @@ func NewInfo(buildDate, buildCommit, buildVersionOverride, license string) *Info Use: "version", Short: "Prints the version info", Long: "Prints the version info in plain, JSON or YAML formats", - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { if shortened { fmt.Println(info.Version) } else {