Skip to content

Commit

Permalink
SQLite: allow enabling foreign keys in GetConnectionString (#3253)
Browse files Browse the repository at this point in the history
Signed-off-by: ItalyPaleAle <[email protected]>
  • Loading branch information
ItalyPaleAle authored Dec 4, 2023
1 parent ca00355 commit 79adc56
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
21 changes: 17 additions & 4 deletions common/authentication/sqlite/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ func (m SqliteAuthMetadata) IsInMemoryDB() bool {
return strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
}

// GetConnectionStringOpts contains options for GetConnectionString
type GetConnectionStringOpts struct {
// Enabled foreign keys
EnableForeignKeys bool
}

// GetConnectionString returns the parsed connection string.
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) {
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger, opts GetConnectionStringOpts) (string, error) {
// Check if we're using the in-memory database
isMemoryDB := m.IsInMemoryDB()

Expand Down Expand Up @@ -126,16 +132,20 @@ func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, err

// Add pragma values
if len(qs["_pragma"]) == 0 {
qs["_pragma"] = make([]string, 0, 2)
qs["_pragma"] = make([]string, 0, 3)
} else {
for _, p := range qs["_pragma"] {
p = strings.ToLower(p)
if strings.HasPrefix(p, "busy_timeout") {
switch {
case strings.HasPrefix(p, "busy_timeout"):
log.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead")
return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string")
} else if strings.HasPrefix(p, "journal_mode") {
case strings.HasPrefix(p, "journal_mode"):
log.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead")
return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string")
case strings.HasPrefix(p, "foreign_keys"):
log.Error("Cannot set `_pragma=foreign_keys` option in the connection string")
return "", errors.New("found forbidden option '_pragma=foreign_keys' in the connection string")
}
}
}
Expand All @@ -152,6 +162,9 @@ func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, err
// Enable WAL
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
}
if opts.EnableForeignKeys {
qs["_pragma"] = append(qs["_pragma"], "foreign_keys(1)")
}

// Build the final connection string
connString := m.ConnectionString
Expand Down
3 changes: 2 additions & 1 deletion nameresolution/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"

"github.com/dapr/components-contrib/common/authentication/sqlite"
commonsql "github.com/dapr/components-contrib/common/component/sql"
"github.com/dapr/components-contrib/nameresolution"
"github.com/dapr/kit/logger"
Expand Down Expand Up @@ -67,7 +68,7 @@ func (s *resolver) Init(ctx context.Context, md nameresolution.Metadata) error {
return err
}

connString, err := s.metadata.GetConnectionString(s.logger)
connString, err := s.metadata.GetConnectionString(s.logger, sqlite.GetConnectionStringOpts{})
if err != nil {
// Already logged
return err
Expand Down
3 changes: 2 additions & 1 deletion state/sqlite/sqlite_dbaccess.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
// Blank import for the underlying SQLite Driver.
_ "modernc.org/sqlite"

"github.com/dapr/components-contrib/common/authentication/sqlite"
commonsql "github.com/dapr/components-contrib/common/component/sql"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
Expand Down Expand Up @@ -77,7 +78,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error {
return err
}

connString, err := a.metadata.GetConnectionString(a.logger)
connString, err := a.metadata.GetConnectionString(a.logger, sqlite.GetConnectionStringOpts{})
if err != nil {
// Already logged
return err
Expand Down
27 changes: 14 additions & 13 deletions state/sqlite/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dapr/components-contrib/common/authentication/sqlite"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger"
Expand All @@ -48,7 +49,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -64,7 +65,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "test.db"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -82,7 +83,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = ":memory:"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -103,7 +104,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?_txlock=immediate"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -121,7 +122,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?_txlock=deferred"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -141,7 +142,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?_pragma=busy_timeout(50)"

_, err := db.metadata.GetConnectionString(log)
_, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.Error(t, err)
require.ErrorContains(t, err, "found forbidden option '_pragma=busy_timeout' in the connection string")
})
Expand All @@ -150,7 +151,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?_pragma=journal_mode(WAL)"

_, err := db.metadata.GetConnectionString(log)
_, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.Error(t, err)
require.ErrorContains(t, err, "found forbidden option '_pragma=journal_mode' in the connection string")
})
Expand All @@ -162,7 +163,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.ConnectionString = "file:test.db"
db.metadata.BusyTimeout = time.Second

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -179,7 +180,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.ConnectionString = "file:test.db"
db.metadata.DisableWAL = false

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -195,7 +196,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.ConnectionString = "file:test.db"
db.metadata.DisableWAL = true

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -210,7 +211,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file::memory:"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -226,7 +227,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?mode=ro"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand All @@ -242,7 +243,7 @@ func TestGetConnectionString(t *testing.T) {
db.metadata.reset()
db.metadata.ConnectionString = "file:test.db?immutable=1"

connString, err := db.metadata.GetConnectionString(log)
connString, err := db.metadata.GetConnectionString(log, sqlite.GetConnectionStringOpts{})
require.NoError(t, err)

values := url.Values{
Expand Down

0 comments on commit 79adc56

Please sign in to comment.