diff --git a/internal/authentication/sqlite/metadata.go b/internal/authentication/sqlite/metadata.go new file mode 100644 index 0000000000..8628cf728b --- /dev/null +++ b/internal/authentication/sqlite/metadata.go @@ -0,0 +1,180 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlite + +import ( + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/dapr/kit/logger" +) + +const ( + DefaultTimeout = 20 * time.Second // Default timeout for database requests, in seconds + DefaultBusyTimeout = 2 * time.Second +) + +// SqliteAuthMetadata contains the auth metadata for a SQLite component. +type SqliteAuthMetadata struct { + ConnectionString string `mapstructure:"connectionString" mapstructurealiases:"url"` + Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"` + BusyTimeout time.Duration `mapstructure:"busyTimeout"` + DisableWAL bool `mapstructure:"disableWAL"` // Disable WAL journaling. You should not use WAL if the database is stored on a network filesystem (or data corruption may happen). This is ignored if the database is in-memory. +} + +// Reset the object +func (m *SqliteAuthMetadata) Reset() { + m.ConnectionString = "" + m.Timeout = DefaultTimeout + m.BusyTimeout = DefaultBusyTimeout + m.DisableWAL = false +} + +func (m *SqliteAuthMetadata) Validate() error { + // Validate and sanitize input + if m.ConnectionString == "" { + return errors.New("missing connection string") + } + if m.Timeout < time.Second { + return errors.New("invalid value for 'timeout': must be greater than 1s") + } + + // Busy timeout + // Truncate values to milliseconds. Values <= 0 do not set any timeout + m.BusyTimeout = m.BusyTimeout.Truncate(time.Millisecond) + + return nil +} + +func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) { + // Check if we're using the in-memory database + lc := strings.ToLower(m.ConnectionString) + isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:") + + // Get the "query string" from the connection string if present + idx := strings.IndexRune(m.ConnectionString, '?') + var qs url.Values + if idx > 0 { + qs, _ = url.ParseQuery(m.ConnectionString[(idx + 1):]) + } + if len(qs) == 0 { + qs = make(url.Values, 2) + } + + // If the database is in-memory, we must ensure that cache=shared is set + if isMemoryDB { + qs["cache"] = []string{"shared"} + } + + // Check if the database is read-only or immutable + isReadOnly := false + if len(qs["mode"]) > 0 { + // Keep the first value only + qs["mode"] = []string{ + qs["mode"][0], + } + if qs["mode"][0] == "ro" { + isReadOnly = true + } + } + if len(qs["immutable"]) > 0 { + // Keep the first value only + qs["immutable"] = []string{ + qs["immutable"][0], + } + if qs["immutable"][0] == "1" { + isReadOnly = true + } + } + + // We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate" + if len(qs["_txlock"]) > 0 { + // Keep the first value only + qs["_txlock"] = []string{ + strings.ToLower(qs["_txlock"][0]), + } + if qs["_txlock"][0] != "immediate" { + log.Warn("Database connection is being created with a _txlock different from the recommended value 'immediate'") + } + } else { + qs["_txlock"] = []string{"immediate"} + } + + // Add pragma values + if len(qs["_pragma"]) == 0 { + qs["_pragma"] = make([]string, 0, 2) + } else { + for _, p := range qs["_pragma"] { + p = strings.ToLower(p) + if strings.HasPrefix(p, "busy_timeout") { + log.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead") + return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string") + } else if strings.HasPrefix(p, "journal_mode") { + log.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead") + return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string") + } + } + } + if m.BusyTimeout > 0 { + qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", m.BusyTimeout.Milliseconds())) + } + if isMemoryDB { + // For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective) + qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)") + } else if m.DisableWAL || isReadOnly { + // Set the journaling mode to "DELETE" (the default) if WAL is disabled or if the database is read-only + qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)") + } else { + // Enable WAL + qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)") + } + + // Build the final connection string + connString := m.ConnectionString + if idx > 0 { + connString = connString[:idx] + } + connString += "?" + qs.Encode() + + // If the connection string doesn't begin with "file:", add the prefix + if !strings.HasPrefix(lc, "file:") { + log.Debug("prefix 'file:' added to the connection string") + connString = "file:" + connString + } + + return connString, nil +} + +// Validates an identifier, such as table or DB name. +func ValidIdentifier(v string) bool { + if v == "" { + return false + } + + // Loop through the string as byte slice as we only care about ASCII characters + b := []byte(v) + for i := 0; i < len(b); i++ { + if (b[i] >= '0' && b[i] <= '9') || + (b[i] >= 'a' && b[i] <= 'z') || + (b[i] >= 'A' && b[i] <= 'Z') || + b[i] == '_' { + continue + } + return false + } + return true +} diff --git a/internal/authentication/sqlite/metadata_test.go b/internal/authentication/sqlite/metadata_test.go new file mode 100644 index 0000000000..290801a9a8 --- /dev/null +++ b/internal/authentication/sqlite/metadata_test.go @@ -0,0 +1,95 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlite + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dapr/components-contrib/metadata" + "github.com/dapr/components-contrib/state" +) + +func TestSqliteMetadata(t *testing.T) { + stateMetadata := func(props map[string]string) state.Metadata { + return state.Metadata{Base: metadata.Base{Properties: props}} + } + + t.Run("default options", func(t *testing.T) { + md := &SqliteAuthMetadata{} + md.Reset() + + err := metadata.DecodeMetadata(stateMetadata(map[string]string{ + "connectionString": "file:data.db", + }), &md) + require.NoError(t, err) + + err = md.Validate() + + require.NoError(t, err) + assert.Equal(t, "file:data.db", md.ConnectionString) + assert.Equal(t, DefaultTimeout, md.Timeout) + assert.Equal(t, DefaultBusyTimeout, md.BusyTimeout) + assert.False(t, md.DisableWAL) + }) + + t.Run("empty connection string", func(t *testing.T) { + md := &SqliteAuthMetadata{} + md.Reset() + + err := metadata.DecodeMetadata(stateMetadata(map[string]string{}), &md) + require.NoError(t, err) + + err = md.Validate() + + require.Error(t, err) + assert.ErrorContains(t, err, "missing connection string") + }) + + t.Run("invalid timeout", func(t *testing.T) { + md := &SqliteAuthMetadata{} + md.Reset() + + err := metadata.DecodeMetadata(stateMetadata(map[string]string{ + "connectionString": "file:data.db", + "timeout": "500ms", + }), &md) + require.NoError(t, err) + + err = md.Validate() + + require.Error(t, err) + assert.ErrorContains(t, err, "timeout") + }) + + t.Run("aliases", func(t *testing.T) { + md := &SqliteAuthMetadata{} + md.Reset() + + err := metadata.DecodeMetadata(stateMetadata(map[string]string{ + "url": "file:data.db", + "timeoutinseconds": "1200", + }), &md) + require.NoError(t, err) + + err = md.Validate() + + require.NoError(t, err) + assert.Equal(t, "file:data.db", md.ConnectionString) + assert.Equal(t, 20*time.Minute, md.Timeout) + }) +} diff --git a/metadata/utils.go b/metadata/utils.go index 73ffa15d96..3ac7945497 100644 --- a/metadata/utils.go +++ b/metadata/utils.go @@ -158,7 +158,7 @@ func DecodeMetadata(input any, result any) error { } // Handle aliases - err = resolveAliases(inputMap, result) + err = resolveAliases(inputMap, reflect.TypeOf(result)) if err != nil { return fmt.Errorf("failed to resolve aliases: %w", err) } @@ -183,7 +183,7 @@ func DecodeMetadata(input any, result any) error { return err } -func resolveAliases(md map[string]string, result any) error { +func resolveAliases(md map[string]string, t reflect.Type) error { // Get the list of all keys in the map keys := make(map[string]string, len(md)) for k := range md { @@ -199,7 +199,6 @@ func resolveAliases(md map[string]string, result any) error { } // Error if result is not pointer to struct, or pointer to pointer to struct - t := reflect.TypeOf(result) if t.Kind() != reflect.Pointer { return fmt.Errorf("not a pointer: %s", t.Kind().String()) } @@ -211,7 +210,14 @@ func resolveAliases(md map[string]string, result any) error { return fmt.Errorf("not a struct: %s", t.Kind().String()) } - // Iterate through all the properties of result to see if anyone has the "mapstructurealiases" property + // Iterate through all the properties, possibly recursively + resolveAliasesInType(md, keys, t) + + return nil +} + +func resolveAliasesInType(md map[string]string, keys map[string]string, t reflect.Type) { + // Iterate through all the properties of the type to see if anyone has the "mapstructurealiases" property for i := 0; i < t.NumField(); i++ { currentField := t.Field(i) @@ -221,6 +227,12 @@ func resolveAliases(md map[string]string, result any) error { continue } + // Check if this is an embedded struct + if mapstructureTag == ",squash" { + resolveAliasesInType(md, keys, currentField.Type) + continue + } + // If the current property has a value in the metadata, then we don't need to handle aliases _, ok := keys[strings.ToLower(mapstructureTag)] if ok { @@ -246,8 +258,6 @@ func resolveAliases(md map[string]string, result any) error { break } } - - return nil } func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc { diff --git a/metadata/utils_test.go b/metadata/utils_test.go index 2ca80db4c4..9cbfd9eca8 100644 --- a/metadata/utils_test.go +++ b/metadata/utils_test.go @@ -98,7 +98,13 @@ func TestTryGetContentType(t *testing.T) { func TestMetadataDecode(t *testing.T) { t.Run("Test metadata decoding", func(t *testing.T) { + type TestEmbedded struct { + MyEmbedded string `mapstructure:"embedded"` + MyEmbeddedAliased string `mapstructure:"embalias" mapstructurealiases:"embalias2"` + } type testMetadata struct { + TestEmbedded `mapstructure:",squash"` + Mystring string `mapstructure:"mystring"` Myduration Duration `mapstructure:"myduration"` Myinteger int `mapstructure:"myinteger"` @@ -139,6 +145,8 @@ func TestMetadataDecode(t *testing.T) { "aliasA2": "hello", "aliasB1": "ciao", "aliasB2": "bonjour", + "embedded": "hi", + "embalias2": "ciao", } err := DecodeMetadata(testData, &m) @@ -159,6 +167,8 @@ func TestMetadataDecode(t *testing.T) { assert.Equal(t, []time.Duration{}, *m.MyDurationArrayPointerEmpty) assert.Equal(t, "hello", m.AliasedFieldA) assert.Equal(t, "ciao", m.AliasedFieldB) + assert.Equal(t, "hi", m.TestEmbedded.MyEmbedded) + assert.Equal(t, "ciao", m.TestEmbedded.MyEmbeddedAliased) }) t.Run("Test metadata decode hook for truthy values", func(t *testing.T) { @@ -346,6 +356,10 @@ func TestMetadataStructToStringMap(t *testing.T) { } func TestResolveAliases(t *testing.T) { + type Embedded struct { + Hello string `mapstructure:"hello" mapstructurealiases:"ciao"` + } + tests := []struct { name string md map[string]string @@ -497,11 +511,27 @@ func TestResolveAliases(t *testing.T) { "bonjour": "monde", }, }, + { + name: "aliases in embedded struct", + md: map[string]string{ + "ciao": "mondo", + "bonjour": "monde", + }, + result: &struct { + Embedded `mapstructure:",squash"` + Bonjour string `mapstructure:"bonjour"` + }{}, + wantMd: map[string]string{ + "bonjour": "monde", + "ciao": "mondo", + "hello": "mondo", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { md := maps.Clone(tt.md) - err := resolveAliases(md, tt.result) + err := resolveAliases(md, reflect.TypeOf(tt.result)) if tt.wantErr { require.Error(t, err) diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index e05b5c5032..d8a2785dfb 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "fmt" - "net/url" "strconv" "strings" "time" @@ -78,7 +77,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error { return err } - connString, err := a.getConnectionString() + connString, err := a.metadata.GetConnectionString(a.logger) if err != nil { // Already logged return err @@ -137,107 +136,8 @@ func (a *sqliteDBAccess) CleanupExpired() error { return a.gc.CleanupExpired() } -func (a *sqliteDBAccess) getConnectionString() (string, error) { - // Check if we're using the in-memory database - lc := strings.ToLower(a.metadata.ConnectionString) - isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:") - - // Get the "query string" from the connection string if present - idx := strings.IndexRune(a.metadata.ConnectionString, '?') - var qs url.Values - if idx > 0 { - qs, _ = url.ParseQuery(a.metadata.ConnectionString[(idx + 1):]) - } - if len(qs) == 0 { - qs = make(url.Values, 2) - } - - // If the database is in-memory, we must ensure that cache=shared is set - if isMemoryDB { - qs["cache"] = []string{"shared"} - } - - // Check if the database is read-only or immutable - isReadOnly := false - if len(qs["mode"]) > 0 { - // Keep the first value only - qs["mode"] = []string{ - qs["mode"][0], - } - if qs["mode"][0] == "ro" { - isReadOnly = true - } - } - if len(qs["immutable"]) > 0 { - // Keep the first value only - qs["immutable"] = []string{ - qs["immutable"][0], - } - if qs["immutable"][0] == "1" { - isReadOnly = true - } - } - - // We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate" - if len(qs["_txlock"]) > 0 { - // Keep the first value only - qs["_txlock"] = []string{ - strings.ToLower(qs["_txlock"][0]), - } - if qs["_txlock"][0] != "immediate" { - a.logger.Warn("Database connection is being created with a _txlock different from the recommended value 'immediate'") - } - } else { - qs["_txlock"] = []string{"immediate"} - } - - // Add pragma values - if len(qs["_pragma"]) == 0 { - qs["_pragma"] = make([]string, 0, 2) - } else { - for _, p := range qs["_pragma"] { - p = strings.ToLower(p) - if strings.HasPrefix(p, "busy_timeout") { - a.logger.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead") - return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string") - } else if strings.HasPrefix(p, "journal_mode") { - a.logger.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead") - return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string") - } - } - } - if a.metadata.BusyTimeout > 0 { - qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", a.metadata.BusyTimeout.Milliseconds())) - } - if isMemoryDB { - // For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective) - qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)") - } else if a.metadata.DisableWAL || isReadOnly { - // Set the journaling mode to "DELETE" (the default) if WAL is disabled or if the database is read-only - qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)") - } else { - // Enable WAL - qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)") - } - - // Build the final connection string - connString := a.metadata.ConnectionString - if idx > 0 { - connString = connString[:idx] - } - connString += "?" + qs.Encode() - - // If the connection string doesn't begin with "file:", add the prefix - if !strings.HasPrefix(lc, "file:") { - a.logger.Debug("prefix 'file:' added to the connection string") - connString = "file:" + connString - } - - return connString, nil -} - func (a *sqliteDBAccess) Ping(parentCtx context.Context) error { - ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) + ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout) err := a.db.PingContext(ctx) cancel() return err @@ -253,7 +153,7 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) ( WHERE key = ? AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` - ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) + ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout) defer cancel() row := a.db.QueryRowContext(ctx, stmt, req.Key) _, value, etag, expireTime, err := readRow(row) @@ -296,7 +196,7 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque WHERE key IN (` + inClause + `) AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` - ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) + ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout) defer cancel() rows, err := a.db.QueryContext(ctx, stmt, params...) if err != nil { @@ -475,7 +375,7 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state stmt = "INSERT OR REPLACE INTO " + a.metadata.TableName + ` (key, value, is_binary, etag, update_time, expiration_time) VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, ` + expiration + `)` - ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout) + ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout) defer cancel() res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key) } else { @@ -489,7 +389,7 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state key = ? AND etag = ? AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` - ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout) + ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout) defer cancel() res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag) } @@ -573,7 +473,7 @@ func (a *sqliteDBAccess) doDelete(parentCtx context.Context, db querier, req *st return fmt.Errorf("missing key in delete operation") } - ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) + ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout) defer cancel() var result sql.Result if !req.HasETag() { diff --git a/state/sqlite/sqlite_integration_test.go b/state/sqlite/sqlite_integration_test.go index 4400ca4cfd..c14415a954 100644 --- a/state/sqlite/sqlite_integration_test.go +++ b/state/sqlite/sqlite_integration_test.go @@ -678,7 +678,7 @@ func testInitConfiguration(t *testing.T) { { name: "Empty", props: map[string]string{}, - expectedErr: errMissingConnectionString, + expectedErr: "missing connection string", }, { name: "Valid connection string", @@ -703,10 +703,10 @@ func testInitConfiguration(t *testing.T) { err := p.Init(context.Background(), metadata) if tt.expectedErr == "" { - assert.NoError(t, err) + require.NoError(t, err) } else { - assert.Error(t, err) - assert.Equal(t, err.Error(), tt.expectedErr) + require.Error(t, err) + assert.ErrorContains(t, err, tt.expectedErr) } }) } diff --git a/state/sqlite/sqlite_metadata.go b/state/sqlite/sqlite_metadata.go index 857e5f996d..bbddb6be96 100644 --- a/state/sqlite/sqlite_metadata.go +++ b/state/sqlite/sqlite_metadata.go @@ -14,11 +14,10 @@ limitations under the License. package sqlite import ( - "errors" "fmt" - "strconv" "time" + authSqlite "github.com/dapr/components-contrib/internal/authentication/sqlite" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" ) @@ -27,30 +26,18 @@ const ( defaultTableName = "state" defaultMetadataTableName = "metadata" defaultCleanupInternal = time.Duration(0) // Disabled by default - defaultTimeout = 20 * time.Second // Default timeout for database requests, in seconds - defaultBusyTimeout = 2 * time.Second - - errMissingConnectionString = "missing connection string" - errInvalidIdentifier = "invalid identifier: %s" // specify identifier type, e.g. "table name" ) type sqliteMetadataStruct struct { - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - TableName string `json:"tableName" mapstructure:"tableName"` - MetadataTableName string `json:"metadataTableName" mapstructure:"metadataTableName"` - TimeoutInSeconds string `json:"timeoutInSeconds" mapstructure:"timeoutInSeconds"` - CleanupInterval time.Duration `json:"cleanupInterval" mapstructure:"cleanupInterval"` - BusyTimeout time.Duration `json:"busyTimeout" mapstructure:"busyTimeout"` - DisableWAL bool `json:"disableWAL" mapstructure:"disableWAL"` // Disable WAL journaling. You should not use WAL if the database is stored on a network filesystem (or data corruption may happen). This is ignored if the database is in-memory. - - // Deprecated properties, maintained for backwards-compatibility - CleanupIntervalInSeconds string `json:"cleanupIntervalInSeconds" mapstructure:"cleanupIntervalInSeconds"` + authSqlite.SqliteAuthMetadata `mapstructure:",squash"` - // Internal properties - timeout time.Duration + TableName string `mapstructure:"tableName"` + MetadataTableName string `mapstructure:"metadataTableName"` + CleanupInterval time.Duration `mapstructure:"cleanupInterval" mapstructurealiases:"cleanupIntervalInSeconds"` } func (m *sqliteMetadataStruct) InitWithMetadata(meta state.Metadata) error { + // Reset the object m.reset() // Decode the metadata @@ -60,78 +47,25 @@ func (m *sqliteMetadataStruct) InitWithMetadata(meta state.Metadata) error { } // Validate and sanitize input - if m.ConnectionString == "" { - return errors.New(errMissingConnectionString) + err = m.SqliteAuthMetadata.Validate() + if err != nil { + return err } - if !validIdentifier(m.TableName) { - return fmt.Errorf(errInvalidIdentifier, m.TableName) + if !authSqlite.ValidIdentifier(m.TableName) { + return fmt.Errorf("invalid identifier: %s", m.TableName) } - - // Timeout - if m.TimeoutInSeconds != "" { - timeoutInSec, err := strconv.ParseInt(m.TimeoutInSeconds, 10, 0) - if err != nil { - return fmt.Errorf("invalid value for 'timeoutInSeconds': %s", m.TimeoutInSeconds) - } - if timeoutInSec < 1 { - return errors.New("invalid value for 'timeoutInSeconds': must be greater than 0") - } - - m.timeout = time.Duration(timeoutInSec) * time.Second - } - - // Legacy "CleanupIntervalInSeconds" property - // Non-positive duration means never clean up expired data - if v := meta.Properties["cleanupInterval"]; v == "" && m.CleanupIntervalInSeconds != "" { - cleanupIntervalInSec, err := strconv.ParseInt(m.CleanupIntervalInSeconds, 10, 0) - if err != nil { - return fmt.Errorf("invalid value for 'cleanupIntervalInSeconds': %s", m.CleanupIntervalInSeconds) - } - - // Non-positive value from meta means disable auto cleanup. - if cleanupIntervalInSec > 0 { - m.CleanupInterval = time.Duration(cleanupIntervalInSec) * time.Second - } + if !authSqlite.ValidIdentifier(m.MetadataTableName) { + return fmt.Errorf("invalid identifier: %s", m.MetadataTableName) } - // Busy timeout - // Truncate values to milliseconds. Values <= 0 do not set any timeout - m.BusyTimeout = m.BusyTimeout.Truncate(time.Millisecond) - return nil } // Reset the object func (m *sqliteMetadataStruct) reset() { - m.ConnectionString = "" + m.SqliteAuthMetadata.Reset() + m.TableName = defaultTableName m.MetadataTableName = defaultMetadataTableName - m.TimeoutInSeconds = "" m.CleanupInterval = defaultCleanupInternal - m.BusyTimeout = defaultBusyTimeout - m.DisableWAL = false - - m.CleanupIntervalInSeconds = "" - - m.timeout = defaultTimeout -} - -// Validates an identifier, such as table or DB name. -func validIdentifier(v string) bool { - if v == "" { - return false - } - - // Loop through the string as byte slice as we only care about ASCII characters - b := []byte(v) - for i := 0; i < len(b); i++ { - if (b[i] >= '0' && b[i] <= '9') || - (b[i] >= 'a' && b[i] <= 'z') || - (b[i] >= 'A' && b[i] <= 'Z') || - b[i] == '_' { - continue - } - return false - } - return true } diff --git a/state/sqlite/sqlite_metadata_test.go b/state/sqlite/sqlite_metadata_test.go new file mode 100644 index 0000000000..10923ce0c0 --- /dev/null +++ b/state/sqlite/sqlite_metadata_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlite + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + authSqlite "github.com/dapr/components-contrib/internal/authentication/sqlite" + "github.com/dapr/components-contrib/metadata" + "github.com/dapr/components-contrib/state" +) + +func TestSqliteMetadata(t *testing.T) { + stateMetadata := func(props map[string]string) state.Metadata { + return state.Metadata{Base: metadata.Base{Properties: props}} + } + + t.Run("default options", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{ + "connectionString": "file:data.db", + })) + + require.NoError(t, err) + assert.Equal(t, "file:data.db", md.ConnectionString) + assert.Equal(t, defaultTableName, md.TableName) + assert.Equal(t, defaultMetadataTableName, md.MetadataTableName) + assert.Equal(t, authSqlite.DefaultTimeout, md.Timeout) + assert.Equal(t, defaultCleanupInternal, md.CleanupInterval) + assert.Equal(t, authSqlite.DefaultBusyTimeout, md.BusyTimeout) + assert.False(t, md.DisableWAL) + }) + + t.Run("empty connection string", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{})) + + require.Error(t, err) + assert.ErrorContains(t, err, "missing connection string") + }) + + t.Run("invalid state table name", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{ + "connectionstring": "file:data.db", + "tablename": "not.valid", + })) + + require.Error(t, err) + assert.ErrorContains(t, err, "invalid identifier") + }) + + t.Run("invalid metadata table name", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{ + "connectionstring": "file:data.db", + "metadatatablename": "not.valid", + })) + + require.Error(t, err) + assert.ErrorContains(t, err, "invalid identifier") + }) + + t.Run("invalid timeout", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{ + "connectionString": "file:data.db", + "timeout": "500ms", + })) + + require.Error(t, err) + assert.ErrorContains(t, err, "timeout") + }) + + t.Run("aliases", func(t *testing.T) { + md := &sqliteMetadataStruct{} + err := md.InitWithMetadata(stateMetadata(map[string]string{ + "url": "file:data.db", + "timeoutinseconds": "1200", + "cleanupintervalinseconds": "22", + })) + + require.NoError(t, err) + assert.Equal(t, "file:data.db", md.ConnectionString) + assert.Equal(t, 20*time.Minute, md.Timeout) + assert.Equal(t, 22*time.Second, md.CleanupInterval) + }) +} diff --git a/state/sqlite/sqlite_test.go b/state/sqlite/sqlite_test.go index 0af7c4c2c5..13b048e1a7 100644 --- a/state/sqlite/sqlite_test.go +++ b/state/sqlite/sqlite_test.go @@ -48,7 +48,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -64,7 +64,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "test.db" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -82,7 +82,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = ":memory:" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -103,7 +103,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?_txlock=immediate" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -121,7 +121,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?_txlock=deferred" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -141,7 +141,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?_pragma=busy_timeout(50)" - _, err := db.getConnectionString() + _, err := db.metadata.GetConnectionString(log) require.Error(t, err) assert.ErrorContains(t, err, "found forbidden option '_pragma=busy_timeout' in the connection string") }) @@ -150,7 +150,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?_pragma=journal_mode(WAL)" - _, err := db.getConnectionString() + _, err := db.metadata.GetConnectionString(log) require.Error(t, err) assert.ErrorContains(t, err, "found forbidden option '_pragma=journal_mode' in the connection string") }) @@ -162,7 +162,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.ConnectionString = "file:test.db" db.metadata.BusyTimeout = time.Second - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -179,7 +179,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.ConnectionString = "file:test.db" db.metadata.DisableWAL = false - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -195,7 +195,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.ConnectionString = "file:test.db" db.metadata.DisableWAL = true - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -210,7 +210,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file::memory:" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -226,7 +226,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?mode=ro" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{ @@ -242,7 +242,7 @@ func TestGetConnectionString(t *testing.T) { db.metadata.reset() db.metadata.ConnectionString = "file:test.db?immutable=1" - connString, err := db.getConnectionString() + connString, err := db.metadata.GetConnectionString(log) require.NoError(t, err) values := url.Values{