diff --git a/config/config.go b/config/config.go index 8656d4a83fb..732791fea90 100644 --- a/config/config.go +++ b/config/config.go @@ -850,20 +850,6 @@ func SetupViper(v *viper.Viper, filename string, bidderInfos BidderInfos) { v.SetDefault("stored_requests.filesystem.enabled", false) v.SetDefault("stored_requests.filesystem.directorypath", "./stored_requests/data/by_id") v.SetDefault("stored_requests.directorypath", "./stored_requests/data/by_id") - v.SetDefault("stored_requests.postgres.connection.dbname", "") - v.SetDefault("stored_requests.postgres.connection.host", "") - v.SetDefault("stored_requests.postgres.connection.port", 0) - v.SetDefault("stored_requests.postgres.connection.user", "") - v.SetDefault("stored_requests.postgres.connection.password", "") - v.SetDefault("stored_requests.postgres.fetcher.query", "") - v.SetDefault("stored_requests.postgres.fetcher.amp_query", "") - v.SetDefault("stored_requests.postgres.initialize_caches.timeout_ms", 0) - v.SetDefault("stored_requests.postgres.initialize_caches.query", "") - v.SetDefault("stored_requests.postgres.initialize_caches.amp_query", "") - v.SetDefault("stored_requests.postgres.poll_for_updates.refresh_rate_seconds", 0) - v.SetDefault("stored_requests.postgres.poll_for_updates.timeout_ms", 0) - v.SetDefault("stored_requests.postgres.poll_for_updates.query", "") - v.SetDefault("stored_requests.postgres.poll_for_updates.amp_query", "") v.SetDefault("stored_requests.http.endpoint", "") v.SetDefault("stored_requests.http.amp_endpoint", "") v.SetDefault("stored_requests.in_memory_cache.type", "none") @@ -880,17 +866,6 @@ func SetupViper(v *viper.Viper, filename string, bidderInfos BidderInfos) { // PBS is not in the business of storing video content beyond the normal prebid cache system. 
v.SetDefault("stored_video_req.filesystem.enabled", false) v.SetDefault("stored_video_req.filesystem.directorypath", "") - v.SetDefault("stored_video_req.postgres.connection.dbname", "") - v.SetDefault("stored_video_req.postgres.connection.host", "") - v.SetDefault("stored_video_req.postgres.connection.port", 0) - v.SetDefault("stored_video_req.postgres.connection.user", "") - v.SetDefault("stored_video_req.postgres.connection.password", "") - v.SetDefault("stored_video_req.postgres.fetcher.query", "") - v.SetDefault("stored_video_req.postgres.initialize_caches.timeout_ms", 0) - v.SetDefault("stored_video_req.postgres.initialize_caches.query", "") - v.SetDefault("stored_video_req.postgres.poll_for_updates.refresh_rate_seconds", 0) - v.SetDefault("stored_video_req.postgres.poll_for_updates.timeout_ms", 0) - v.SetDefault("stored_video_req.postgres.poll_for_updates.query", "") v.SetDefault("stored_video_req.http.endpoint", "") v.SetDefault("stored_video_req.in_memory_cache.type", "none") v.SetDefault("stored_video_req.in_memory_cache.ttl_seconds", 0) @@ -904,17 +879,6 @@ func SetupViper(v *viper.Viper, filename string, bidderInfos BidderInfos) { v.SetDefault("stored_video_req.http_events.timeout_ms", 0) v.SetDefault("stored_responses.filesystem.enabled", false) v.SetDefault("stored_responses.filesystem.directorypath", "") - v.SetDefault("stored_responses.postgres.connection.dbname", "") - v.SetDefault("stored_responses.postgres.connection.host", "") - v.SetDefault("stored_responses.postgres.connection.port", 0) - v.SetDefault("stored_responses.postgres.connection.user", "") - v.SetDefault("stored_responses.postgres.connection.password", "") - v.SetDefault("stored_responses.postgres.fetcher.query", "") - v.SetDefault("stored_responses.postgres.initialize_caches.timeout_ms", 0) - v.SetDefault("stored_responses.postgres.initialize_caches.query", "") - v.SetDefault("stored_responses.postgres.poll_for_updates.refresh_rate_seconds", 0) - 
v.SetDefault("stored_responses.postgres.poll_for_updates.timeout_ms", 0) - v.SetDefault("stored_responses.postgres.poll_for_updates.query", "") v.SetDefault("stored_responses.http.endpoint", "") v.SetDefault("stored_responses.in_memory_cache.type", "none") v.SetDefault("stored_responses.in_memory_cache.ttl_seconds", 0) @@ -1039,6 +1003,7 @@ func SetupViper(v *viper.Viper, filename string, bidderInfos BidderInfos) { migrateConfigPurposeOneTreatment(v) migrateConfigSpecialFeature1(v) migrateConfigTCF2PurposeFlags(v) + migrateConfigDatabaseConnection(v) // These defaults must be set after the migrate functions because those functions look for the presence of these // config fields and there isn't a way to detect presence of a config field using the viper package if a default @@ -1190,6 +1155,170 @@ func migrateConfigTCF2PurposeEnabledFlags(v *viper.Viper) { } } +func migrateConfigDatabaseConnection(v *viper.Viper) { + + type QueryParamMigration struct { + old string + new string + } + + type QueryMigration struct { + name string + params []QueryParamMigration + } + + type Migration struct { + old string + new string + fields []string + queryMigrations []QueryMigration + } + + queryParamMigrations := struct { + RequestIdList QueryParamMigration + ImpIdList QueryParamMigration + IdList QueryParamMigration + LastUpdated QueryParamMigration + }{ + RequestIdList: QueryParamMigration{ + old: "%REQUEST_ID_LIST%", + new: "$REQUEST_ID_LIST", + }, + ImpIdList: QueryParamMigration{ + old: "%IMP_ID_LIST%", + new: "$IMP_ID_LIST", + }, + IdList: QueryParamMigration{ + old: "%ID_LIST%", + new: "$ID_LIST", + }, + LastUpdated: QueryParamMigration{ + old: "$1", + new: "$LAST_UPDATED", + }, + } + + queryMigrations := []QueryMigration{ + { + name: "fetcher.query", + params: []QueryParamMigration{queryParamMigrations.RequestIdList, queryParamMigrations.ImpIdList, queryParamMigrations.IdList}, + }, + { + name: "fetcher.amp_query", + params: 
[]QueryParamMigration{queryParamMigrations.RequestIdList, queryParamMigrations.ImpIdList, queryParamMigrations.IdList}, + }, + { + name: "poll_for_updates.query", + params: []QueryParamMigration{queryParamMigrations.LastUpdated}, + }, + { + name: "poll_for_updates.amp_query", + params: []QueryParamMigration{queryParamMigrations.LastUpdated}, + }, + } + + migrations := []Migration{ + { + old: "stored_requests.postgres", + new: "stored_requests.database", + fields: []string{ + "connection.dbname", + "connection.host", + "connection.port", + "connection.user", + "connection.password", + "fetcher.query", + "fetcher.amp_query", + "initialize_caches.timeout_ms", + "initialize_caches.query", + "initialize_caches.amp_query", + "poll_for_updates.refresh_rate_seconds", + "poll_for_updates.timeout_ms", + "poll_for_updates.query", + "poll_for_updates.amp_query", + }, + queryMigrations: queryMigrations, + }, + { + old: "stored_video_req.postgres", + new: "stored_video_req.database", + fields: []string{ + "connection.dbname", + "connection.host", + "connection.port", + "connection.user", + "connection.password", + "fetcher.query", + "initialize_caches.timeout_ms", + "initialize_caches.query", + "poll_for_updates.refresh_rate_seconds", + "poll_for_updates.timeout_ms", + "poll_for_updates.query", + }, + queryMigrations: queryMigrations, + }, + { + old: "stored_responses.postgres", + new: "stored_responses.database", + fields: []string{ + "connection.dbname", + "connection.host", + "connection.port", + "connection.user", + "connection.password", + "fetcher.query", + "initialize_caches.timeout_ms", + "initialize_caches.query", + "poll_for_updates.refresh_rate_seconds", + "poll_for_updates.timeout_ms", + "poll_for_updates.query", + }, + queryMigrations: queryMigrations, + }, + } + + for _, migration := range migrations { + driverField := migration.new + ".connection.driver" + if !v.IsSet(migration.new) && v.IsSet(migration.old) { + glog.Warning(fmt.Sprintf("%s is deprecated and 
should be changed to %s", migration.old, migration.new)) + glog.Warning(fmt.Sprintf("%s is not set, using default (postgres)", driverField)) + v.Set(driverField, "postgres") + + for _, field := range migration.fields { + oldField := migration.old + "." + field + newField := migration.new + "." + field + if v.IsSet(oldField) { + glog.Warning(fmt.Sprintf("%s is deprecated and should be changed to %s", oldField, newField)) + v.Set(newField, v.Get(oldField)) + } + } + + for _, queryMigration := range migration.queryMigrations { + oldQueryField := migration.old + "." + queryMigration.name + newQueryField := migration.new + "." + queryMigration.name + queryString := v.GetString(oldQueryField) + for _, queryParam := range queryMigration.params { + if strings.Contains(queryString, queryParam.old) { + glog.Warning(fmt.Sprintf("Query param %s for %s is deprecated and should be changed to %s", queryParam.old, oldQueryField, queryParam.new)) + queryString = strings.ReplaceAll(queryString, queryParam.old, queryParam.new) + v.Set(newQueryField, queryString) + } + } + } + } else if v.IsSet(migration.new) && v.IsSet(migration.old) { + glog.Warning(fmt.Sprintf("using %s and ignoring deprecated %s", migration.new, migration.old)) + + for _, field := range migration.fields { + oldField := migration.old + "." + field + newField := migration.new + "." + field + if v.IsSet(oldField) { + glog.Warning(fmt.Sprintf("using %s and ignoring deprecated %s", newField, oldField)) + } + } + } + } +} + func setBidderDefaults(v *viper.Viper, bidder string) { adapterCfgPrefix := "adapters." 
+ bidder v.BindEnv(adapterCfgPrefix+".disabled", "") diff --git a/config/config_test.go b/config/config_test.go index e2c53c631a9..b854a8ed502 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1611,6 +1611,691 @@ func TestMigrateConfigTCF2EnforcePurposeFlags(t *testing.T) { } } +func TestMigrateConfigDatabaseConnection(t *testing.T) { + type configs struct { + old []byte + new []byte + both []byte + } + + // Stored Requests Config Migration + storedReqestsConfigs := configs{ + old: []byte(` + stored_requests: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + amp_query: "old_fetcher_amp_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + amp_query: "old_initialize_caches_amp_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + amp_query: "old_poll_for_updates_amp_query" + `), + new: []byte(` + stored_requests: + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + amp_query: "new_fetcher_amp_query" + initialize_caches: + timeout_ms: 2000 + query: "new_initialize_caches_query" + amp_query: "new_initialize_caches_amp_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + amp_query: "new_poll_for_updates_amp_query" + `), + both: []byte(` + stored_requests: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + amp_query: "old_fetcher_amp_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + 
amp_query: "old_initialize_caches_amp_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + amp_query: "old_poll_for_updates_amp_query" + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + amp_query: "new_fetcher_amp_query" + initialize_caches: + timeout_ms: 2000 + query: "new_initialize_caches_query" + amp_query: "new_initialize_caches_amp_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + amp_query: "new_poll_for_updates_amp_query" + `), + } + + storedRequestsTests := []struct { + description string + config []byte + + want_connection_dbname string + want_connection_host string + want_connection_port int + want_connection_user string + want_connection_password string + want_fetcher_query string + want_fetcher_amp_query string + want_initialize_caches_timeout_ms int + want_initialize_caches_query string + want_initialize_caches_amp_query string + want_poll_for_updates_refresh_rate_seconds int + want_poll_for_updates_timeout_ms int + want_poll_for_updates_query string + want_poll_for_updates_amp_query string + }{ + { + description: "New config and old config not set", + config: []byte{}, + }, + { + description: "New config not set, old config set", + config: storedReqestsConfigs.old, + + want_connection_dbname: "old_connection_dbname", + want_connection_host: "old_connection_host", + want_connection_port: 1000, + want_connection_user: "old_connection_user", + want_connection_password: "old_connection_password", + want_fetcher_query: "old_fetcher_query", + want_fetcher_amp_query: "old_fetcher_amp_query", + want_initialize_caches_timeout_ms: 1000, + want_initialize_caches_query: "old_initialize_caches_query", + want_initialize_caches_amp_query: "old_initialize_caches_amp_query", + 
want_poll_for_updates_refresh_rate_seconds: 1000, + want_poll_for_updates_timeout_ms: 1000, + want_poll_for_updates_query: "old_poll_for_updates_query", + want_poll_for_updates_amp_query: "old_poll_for_updates_amp_query", + }, + { + description: "New config set, old config not set", + config: storedReqestsConfigs.new, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_fetcher_amp_query: "new_fetcher_amp_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_initialize_caches_amp_query: "new_initialize_caches_amp_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + want_poll_for_updates_amp_query: "new_poll_for_updates_amp_query", + }, + { + description: "New config and old config set", + config: storedReqestsConfigs.both, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_fetcher_amp_query: "new_fetcher_amp_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_initialize_caches_amp_query: "new_initialize_caches_amp_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + want_poll_for_updates_amp_query: "new_poll_for_updates_amp_query", + }, + } + + for _, tt := range storedRequestsTests { + v := viper.New() + v.SetConfigType("yaml") + 
v.ReadConfig(bytes.NewBuffer(tt.config)) + + migrateConfigDatabaseConnection(v) + + if len(tt.config) > 0 { + assert.Equal(t, tt.want_connection_dbname, v.GetString("stored_requests.database.connection.dbname"), tt.description) + assert.Equal(t, tt.want_connection_host, v.GetString("stored_requests.database.connection.host"), tt.description) + assert.Equal(t, tt.want_connection_port, v.GetInt("stored_requests.database.connection.port"), tt.description) + assert.Equal(t, tt.want_connection_user, v.GetString("stored_requests.database.connection.user"), tt.description) + assert.Equal(t, tt.want_connection_password, v.GetString("stored_requests.database.connection.password"), tt.description) + assert.Equal(t, tt.want_fetcher_query, v.GetString("stored_requests.database.fetcher.query"), tt.description) + assert.Equal(t, tt.want_fetcher_amp_query, v.GetString("stored_requests.database.fetcher.amp_query"), tt.description) + assert.Equal(t, tt.want_initialize_caches_timeout_ms, v.GetInt("stored_requests.database.initialize_caches.timeout_ms"), tt.description) + assert.Equal(t, tt.want_initialize_caches_query, v.GetString("stored_requests.database.initialize_caches.query"), tt.description) + assert.Equal(t, tt.want_initialize_caches_amp_query, v.GetString("stored_requests.database.initialize_caches.amp_query"), tt.description) + assert.Equal(t, tt.want_poll_for_updates_refresh_rate_seconds, v.GetInt("stored_requests.database.poll_for_updates.refresh_rate_seconds"), tt.description) + assert.Equal(t, tt.want_poll_for_updates_timeout_ms, v.GetInt("stored_requests.database.poll_for_updates.timeout_ms"), tt.description) + assert.Equal(t, tt.want_poll_for_updates_query, v.GetString("stored_requests.database.poll_for_updates.query"), tt.description) + assert.Equal(t, tt.want_poll_for_updates_amp_query, v.GetString("stored_requests.database.poll_for_updates.amp_query"), tt.description) + } else { + assert.Nil(t, v.Get("stored_requests.database.connection.dbname"), tt.description) + 
assert.Nil(t, v.Get("stored_requests.database.connection.host"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.connection.port"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.connection.user"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.connection.password"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.fetcher.query"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.fetcher.amp_query"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.initialize_caches.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.initialize_caches.query"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.initialize_caches.amp_query"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.poll_for_updates.refresh_rate_seconds"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.poll_for_updates.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.poll_for_updates.query"), tt.description) + assert.Nil(t, v.Get("stored_requests.database.poll_for_updates.amp_query"), tt.description) + } + } + + // Stored Video Reqs Config Migration + storedVideoReqsConfigs := configs{ + old: []byte(` + stored_video_req: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + `), + new: []byte(` + stored_video_req: + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + initialize_caches: + timeout_ms: 2000 + query: 
"new_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + `), + both: []byte(` + stored_video_req: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + initialize_caches: + timeout_ms: 2000 + query: "new_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + `), + } + + storedVideoReqsTests := []struct { + description string + config []byte + + want_connection_dbname string + want_connection_host string + want_connection_port int + want_connection_user string + want_connection_password string + want_fetcher_query string + want_initialize_caches_timeout_ms int + want_initialize_caches_query string + want_poll_for_updates_refresh_rate_seconds int + want_poll_for_updates_timeout_ms int + want_poll_for_updates_query string + }{ + { + description: "New config and old config not set", + config: []byte{}, + }, + { + description: "New config not set, old config set", + config: storedVideoReqsConfigs.old, + + want_connection_dbname: "old_connection_dbname", + want_connection_host: "old_connection_host", + want_connection_port: 1000, + want_connection_user: "old_connection_user", + want_connection_password: "old_connection_password", + want_fetcher_query: "old_fetcher_query", + want_initialize_caches_timeout_ms: 1000, + want_initialize_caches_query: 
"old_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 1000, + want_poll_for_updates_timeout_ms: 1000, + want_poll_for_updates_query: "old_poll_for_updates_query", + }, + { + description: "New config set, old config not set", + config: storedVideoReqsConfigs.new, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + }, + { + description: "New config and old config set", + config: storedVideoReqsConfigs.both, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + }, + } + + for _, tt := range storedVideoReqsTests { + v := viper.New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(tt.config)) + + migrateConfigDatabaseConnection(v) + + if len(tt.config) > 0 { + assert.Equal(t, tt.want_connection_dbname, v.Get("stored_video_req.database.connection.dbname").(string), tt.description) + assert.Equal(t, tt.want_connection_host, v.Get("stored_video_req.database.connection.host").(string), tt.description) + assert.Equal(t, tt.want_connection_port, 
v.Get("stored_video_req.database.connection.port").(int), tt.description) + assert.Equal(t, tt.want_connection_user, v.Get("stored_video_req.database.connection.user").(string), tt.description) + assert.Equal(t, tt.want_connection_password, v.Get("stored_video_req.database.connection.password").(string), tt.description) + assert.Equal(t, tt.want_fetcher_query, v.Get("stored_video_req.database.fetcher.query").(string), tt.description) + assert.Equal(t, tt.want_initialize_caches_timeout_ms, v.Get("stored_video_req.database.initialize_caches.timeout_ms").(int), tt.description) + assert.Equal(t, tt.want_initialize_caches_query, v.Get("stored_video_req.database.initialize_caches.query").(string), tt.description) + assert.Equal(t, tt.want_poll_for_updates_refresh_rate_seconds, v.Get("stored_video_req.database.poll_for_updates.refresh_rate_seconds").(int), tt.description) + assert.Equal(t, tt.want_poll_for_updates_timeout_ms, v.Get("stored_video_req.database.poll_for_updates.timeout_ms").(int), tt.description) + assert.Equal(t, tt.want_poll_for_updates_query, v.Get("stored_video_req.database.poll_for_updates.query").(string), tt.description) + } else { + assert.Nil(t, v.Get("stored_video_req.database.connection.dbname"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.connection.host"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.connection.port"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.connection.user"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.connection.password"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.fetcher.query"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.initialize_caches.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.initialize_caches.query"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.poll_for_updates.refresh_rate_seconds"), tt.description) + assert.Nil(t, 
v.Get("stored_video_req.database.poll_for_updates.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_video_req.database.poll_for_updates.query"), tt.description) + } + } + + // Stored Responses Config Migration + storedResponsesConfigs := configs{ + old: []byte(` + stored_responses: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + `), + new: []byte(` + stored_responses: + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + initialize_caches: + timeout_ms: 2000 + query: "new_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + `), + both: []byte(` + stored_responses: + postgres: + connection: + dbname: "old_connection_dbname" + host: "old_connection_host" + port: 1000 + user: "old_connection_user" + password: "old_connection_password" + fetcher: + query: "old_fetcher_query" + initialize_caches: + timeout_ms: 1000 + query: "old_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 1000 + timeout_ms: 1000 + query: "old_poll_for_updates_query" + database: + connection: + dbname: "new_connection_dbname" + host: "new_connection_host" + port: 2000 + user: "new_connection_user" + password: "new_connection_password" + fetcher: + query: "new_fetcher_query" + initialize_caches: + timeout_ms: 2000 + query: "new_initialize_caches_query" + poll_for_updates: + refresh_rate_seconds: 2000 + timeout_ms: 2000 + query: "new_poll_for_updates_query" + `), + } + + storedResponsesTests := 
[]struct { + description string + config []byte + + want_connection_dbname string + want_connection_host string + want_connection_port int + want_connection_user string + want_connection_password string + want_fetcher_query string + want_initialize_caches_timeout_ms int + want_initialize_caches_query string + want_poll_for_updates_refresh_rate_seconds int + want_poll_for_updates_timeout_ms int + want_poll_for_updates_query string + }{ + { + description: "New config and old config not set", + config: []byte{}, + }, + { + description: "New config not set, old config set", + config: storedResponsesConfigs.old, + + want_connection_dbname: "old_connection_dbname", + want_connection_host: "old_connection_host", + want_connection_port: 1000, + want_connection_user: "old_connection_user", + want_connection_password: "old_connection_password", + want_fetcher_query: "old_fetcher_query", + want_initialize_caches_timeout_ms: 1000, + want_initialize_caches_query: "old_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 1000, + want_poll_for_updates_timeout_ms: 1000, + want_poll_for_updates_query: "old_poll_for_updates_query", + }, + { + description: "New config set, old config not set", + config: storedResponsesConfigs.new, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + }, + { + description: "New config and old config set", + config: storedResponsesConfigs.both, + + want_connection_dbname: "new_connection_dbname", + want_connection_host: "new_connection_host", + 
want_connection_port: 2000, + want_connection_user: "new_connection_user", + want_connection_password: "new_connection_password", + want_fetcher_query: "new_fetcher_query", + want_initialize_caches_timeout_ms: 2000, + want_initialize_caches_query: "new_initialize_caches_query", + want_poll_for_updates_refresh_rate_seconds: 2000, + want_poll_for_updates_timeout_ms: 2000, + want_poll_for_updates_query: "new_poll_for_updates_query", + }, + } + + for _, tt := range storedResponsesTests { + v := viper.New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(tt.config)) + + migrateConfigDatabaseConnection(v) + + if len(tt.config) > 0 { + assert.Equal(t, tt.want_connection_dbname, v.Get("stored_responses.database.connection.dbname").(string), tt.description) + assert.Equal(t, tt.want_connection_host, v.Get("stored_responses.database.connection.host").(string), tt.description) + assert.Equal(t, tt.want_connection_port, v.Get("stored_responses.database.connection.port").(int), tt.description) + assert.Equal(t, tt.want_connection_user, v.Get("stored_responses.database.connection.user").(string), tt.description) + assert.Equal(t, tt.want_connection_password, v.Get("stored_responses.database.connection.password").(string), tt.description) + assert.Equal(t, tt.want_fetcher_query, v.Get("stored_responses.database.fetcher.query").(string), tt.description) + assert.Equal(t, tt.want_initialize_caches_timeout_ms, v.Get("stored_responses.database.initialize_caches.timeout_ms").(int), tt.description) + assert.Equal(t, tt.want_initialize_caches_query, v.Get("stored_responses.database.initialize_caches.query").(string), tt.description) + assert.Equal(t, tt.want_poll_for_updates_refresh_rate_seconds, v.Get("stored_responses.database.poll_for_updates.refresh_rate_seconds").(int), tt.description) + assert.Equal(t, tt.want_poll_for_updates_timeout_ms, v.Get("stored_responses.database.poll_for_updates.timeout_ms").(int), tt.description) + assert.Equal(t, tt.want_poll_for_updates_query, 
v.Get("stored_responses.database.poll_for_updates.query").(string), tt.description) + } else { + assert.Nil(t, v.Get("stored_responses.database.connection.dbname"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.connection.host"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.connection.port"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.connection.user"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.connection.password"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.fetcher.query"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.initialize_caches.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.initialize_caches.query"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.poll_for_updates.refresh_rate_seconds"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.poll_for_updates.timeout_ms"), tt.description) + assert.Nil(t, v.Get("stored_responses.database.poll_for_updates.query"), tt.description) + } + } +} + +func TestMigrateConfigDatabaseQueryParams(t *testing.T) { + + config := []byte(` + stored_requests: + postgres: + fetcher: + query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in (%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + amp_query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in (%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + poll_for_updates: + query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + amp_query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + stored_video_req: + postgres: + fetcher: + query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in 
(%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + amp_query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in (%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + poll_for_updates: + query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + amp_query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + stored_responses: + postgres: + fetcher: + query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in (%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + amp_query: + SELECT * FROM Table1 WHERE id in (%REQUEST_ID_LIST%) + UNION ALL + SELECT * FROM Table2 WHERE id in (%IMP_ID_LIST%) + UNION ALL + SELECT * FROM Table3 WHERE id in (%ID_LIST%) + poll_for_updates: + query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + amp_query: "SELECT * FROM Table1 WHERE last_updated > $1 UNION ALL SELECT * FROM Table2 WHERE last_updated > $1" + `) + + want_queries := struct { + fetcher_query string + fetcher_amp_query string + poll_for_updates_query string + poll_for_updates_amp_query string + }{ + fetcher_query: "SELECT * FROM Table1 WHERE id in ($REQUEST_ID_LIST) " + + "UNION ALL " + + "SELECT * FROM Table2 WHERE id in ($IMP_ID_LIST) " + + "UNION ALL " + + "SELECT * FROM Table3 WHERE id in ($ID_LIST)", + fetcher_amp_query: "SELECT * FROM Table1 WHERE id in ($REQUEST_ID_LIST) " + + "UNION ALL " + + "SELECT * FROM Table2 WHERE id in ($IMP_ID_LIST) " + + "UNION ALL " + + "SELECT * FROM Table3 WHERE id in ($ID_LIST)", + poll_for_updates_query: "SELECT * FROM Table1 WHERE last_updated > $LAST_UPDATED UNION ALL SELECT * FROM Table2 WHERE last_updated > $LAST_UPDATED", + poll_for_updates_amp_query: "SELECT * FROM Table1 WHERE last_updated > $LAST_UPDATED 
UNION ALL SELECT * FROM Table2 WHERE last_updated > $LAST_UPDATED", + } + + v := viper.New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(config)) + + migrateConfigDatabaseConnection(v) + + // stored_requests queries + assert.Equal(t, want_queries.fetcher_query, v.GetString("stored_requests.database.fetcher.query")) + assert.Equal(t, want_queries.fetcher_amp_query, v.GetString("stored_requests.database.fetcher.amp_query")) + assert.Equal(t, want_queries.poll_for_updates_query, v.GetString("stored_requests.database.poll_for_updates.query")) + assert.Equal(t, want_queries.poll_for_updates_amp_query, v.GetString("stored_requests.database.poll_for_updates.amp_query")) + + // stored_video_req queries + assert.Equal(t, want_queries.fetcher_query, v.GetString("stored_video_req.database.fetcher.query")) + assert.Equal(t, want_queries.fetcher_amp_query, v.GetString("stored_video_req.database.fetcher.amp_query")) + assert.Equal(t, want_queries.poll_for_updates_query, v.GetString("stored_video_req.database.poll_for_updates.query")) + assert.Equal(t, want_queries.poll_for_updates_amp_query, v.GetString("stored_video_req.database.poll_for_updates.amp_query")) + + // stored_responses queries + assert.Equal(t, want_queries.fetcher_query, v.GetString("stored_responses.database.fetcher.query")) + assert.Equal(t, want_queries.fetcher_amp_query, v.GetString("stored_responses.database.fetcher.amp_query")) + assert.Equal(t, want_queries.poll_for_updates_query, v.GetString("stored_responses.database.poll_for_updates.query")) + assert.Equal(t, want_queries.poll_for_updates_amp_query, v.GetString("stored_responses.database.poll_for_updates.amp_query")) +} + func TestNegativeRequestSize(t *testing.T) { cfg, v := newDefaultConfig(t) cfg.MaxRequestSize = -1 @@ -1804,11 +2489,11 @@ func TestValidateAccountsConfigRestrictions(t *testing.T) { cfg, v := newDefaultConfig(t) cfg.Accounts.Files.Enabled = true cfg.Accounts.HTTP.Endpoint = "http://localhost" - 
cfg.Accounts.Postgres.ConnectionInfo.Database = "accounts" + cfg.Accounts.Database.ConnectionInfo.Database = "accounts" errs := cfg.validate(v) assert.Len(t, errs, 1) - assert.Contains(t, errs, errors.New("accounts.postgres: retrieving accounts via postgres not available, use accounts.files")) + assert.Contains(t, errs, errors.New("accounts.database: retrieving accounts via database not available, use accounts.files")) } func newDefaultConfig(t *testing.T) (*Configuration, *viper.Viper) { diff --git a/config/stored_requests.go b/config/stored_requests.go index c8fa5c28761..6bd8a44f411 100644 --- a/config/stored_requests.go +++ b/config/stored_requests.go @@ -1,9 +1,7 @@ package config import ( - "bytes" "fmt" - "strconv" "strings" "time" @@ -57,10 +55,10 @@ type StoredRequests struct { // Files should be used if Stored Requests should be loaded from the filesystem. // Fetchers are in stored_requests/backends/file_system/fetcher.go Files FileFetcherConfig `mapstructure:"filesystem"` - // Postgres configures Fetchers and EventProducers which read from a Postgres DB. - // Fetchers are in stored_requests/backends/db_fetcher/postgres.go - // EventProducers are in stored_requests/events/postgres - Postgres PostgresConfig `mapstructure:"postgres"` + // Database configures Fetchers and EventProducers which read from a Database. + // Fetchers are in stored_requests/backends/db_fetcher/fetcher.go + // EventProducers are in stored_requests/events/database + Database DatabaseConfig `mapstructure:"database"` // HTTP configures an instance of stored_requests/backends/http/http_fetcher.go. // If non-nil, Stored Requests will be fetched from the endpoint described there. 
HTTP HTTPFetcherConfig `mapstructure:"http"` @@ -122,9 +120,9 @@ func resolvedStoredRequestsConfig(cfg *Configuration) { // Amp uses the same config but some fields get replaced by Amp* version of similar fields cfg.StoredRequestsAMP = cfg.StoredRequests - amp.Postgres.FetcherQueries.QueryTemplate = sr.Postgres.FetcherQueries.AmpQueryTemplate - amp.Postgres.CacheInitialization.Query = sr.Postgres.CacheInitialization.AmpQuery - amp.Postgres.PollUpdates.Query = sr.Postgres.PollUpdates.AmpQuery + amp.Database.FetcherQueries.QueryTemplate = sr.Database.FetcherQueries.AmpQueryTemplate + amp.Database.CacheInitialization.Query = sr.Database.CacheInitialization.AmpQuery + amp.Database.PollUpdates.Query = sr.Database.PollUpdates.AmpQuery amp.HTTP.Endpoint = sr.HTTP.AmpEndpoint amp.CacheEvents.Endpoint = "/storedrequests/amp" amp.HTTPEvents.Endpoint = sr.HTTPEvents.AmpEndpoint @@ -139,10 +137,10 @@ func resolvedStoredRequestsConfig(cfg *Configuration) { } func (cfg *StoredRequests) validate(errs []error) []error { - if cfg.DataType() == AccountDataType && cfg.Postgres.ConnectionInfo.Database != "" { - errs = append(errs, fmt.Errorf("%s.postgres: retrieving accounts via postgres not available, use accounts.files", cfg.Section())) + if cfg.DataType() == AccountDataType && cfg.Database.ConnectionInfo.Database != "" { + errs = append(errs, fmt.Errorf("%s.database: retrieving accounts via database not available, use accounts.files", cfg.Section())) } else { - errs = cfg.Postgres.validate(cfg.DataType(), errs) + errs = cfg.Database.validate(cfg.DataType(), errs) } // Categories do not use cache so none of the following checks apply @@ -159,27 +157,27 @@ func (cfg *StoredRequests) validate(errs []error) []error { errs = append(errs, fmt.Errorf("%s: http_events.refresh_rate_seconds must be 0 if in_memory_cache=none", cfg.Section())) } - if cfg.Postgres.PollUpdates.Query != "" { - errs = append(errs, fmt.Errorf("%s: postgres.poll_for_updates.query must be empty if 
in_memory_cache=none", cfg.Section())) + if cfg.Database.PollUpdates.Query != "" { + errs = append(errs, fmt.Errorf("%s: database.poll_for_updates.query must be empty if in_memory_cache=none", cfg.Section())) } - if cfg.Postgres.CacheInitialization.Query != "" { - errs = append(errs, fmt.Errorf("%s: postgres.initialize_caches.query must be empty if in_memory_cache=none", cfg.Section())) + if cfg.Database.CacheInitialization.Query != "" { + errs = append(errs, fmt.Errorf("%s: database.initialize_caches.query must be empty if in_memory_cache=none", cfg.Section())) } } errs = cfg.InMemoryCache.validate(cfg.DataType(), errs) return errs } -// PostgresConfig configures the Stored Request ecosystem to use Postgres. This must include a Fetcher, +// DatabaseConfig configures the Stored Request ecosystem to use Database. This must include a Fetcher, // and may optionally include some EventProducers to populate and refresh the caches. -type PostgresConfig struct { - ConnectionInfo PostgresConnection `mapstructure:"connection"` - FetcherQueries PostgresFetcherQueries `mapstructure:"fetcher"` - CacheInitialization PostgresCacheInitializer `mapstructure:"initialize_caches"` - PollUpdates PostgresUpdatePolling `mapstructure:"poll_for_updates"` +type DatabaseConfig struct { + ConnectionInfo DatabaseConnection `mapstructure:"connection"` + FetcherQueries DatabaseFetcherQueries `mapstructure:"fetcher"` + CacheInitialization DatabaseCacheInitializer `mapstructure:"initialize_caches"` + PollUpdates DatabaseUpdatePolling `mapstructure:"poll_for_updates"` } -func (cfg *PostgresConfig) validate(dataType DataType, errs []error) []error { +func (cfg *DatabaseConfig) validate(dataType DataType, errs []error) []error { if cfg.ConnectionInfo.Database == "" { return errs } @@ -189,9 +187,10 @@ func (cfg *PostgresConfig) validate(dataType DataType, errs []error) []error { return errs } -// PostgresConnection has options which put types to the Postgres Connection string. 
See: +// DatabaseConnection has options which put types to the Database Connection string. See: // https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters -type PostgresConnection struct { +type DatabaseConnection struct { + Driver string `mapstructure:"driver"` Database string `mapstructure:"dbname"` Host string `mapstructure:"host"` Port int `mapstructure:"port"` @@ -199,55 +198,18 @@ type PostgresConnection struct { Password string `mapstructure:"password"` } -func (cfg *PostgresConnection) ConnString() string { - buffer := bytes.NewBuffer(nil) - - if cfg.Host != "" { - buffer.WriteString("host=") - buffer.WriteString(cfg.Host) - buffer.WriteString(" ") - } - - if cfg.Port > 0 { - buffer.WriteString("port=") - buffer.WriteString(strconv.Itoa(cfg.Port)) - buffer.WriteString(" ") - } - - if cfg.Username != "" { - buffer.WriteString("user=") - buffer.WriteString(cfg.Username) - buffer.WriteString(" ") - } - - if cfg.Password != "" { - buffer.WriteString("password=") - buffer.WriteString(cfg.Password) - buffer.WriteString(" ") - } - - if cfg.Database != "" { - buffer.WriteString("dbname=") - buffer.WriteString(cfg.Database) - buffer.WriteString(" ") - } - - buffer.WriteString("sslmode=disable") - return buffer.String() -} - -type PostgresFetcherQueries struct { - // QueryTemplate is the Postgres Query which can be used to fetch configs from the database. +type DatabaseFetcherQueries struct { + // QueryTemplate is the Database Query which can be used to fetch configs from the database. // It is a Template, rather than a full Query, because a single HTTP request may reference multiple Stored Requests. 
// // In the simplest case, this could be something like: // SELECT id, requestData, 'request' as type // FROM stored_requests - // WHERE id in %REQUEST_ID_LIST% + // WHERE id in $REQUEST_ID_LIST // UNION ALL // SELECT id, impData, 'imp' as type // FROM stored_imps - // WHERE id in %IMP_ID_LIST% + // WHERE id in $IMP_ID_LIST // // The MakeQuery function will transform this query into: // SELECT id, requestData, 'request' as type @@ -265,7 +227,7 @@ type PostgresFetcherQueries struct { AmpQueryTemplate string `mapstructure:"amp_query"` } -type PostgresCacheInitializer struct { +type DatabaseCacheInitializer struct { Timeout int `mapstructure:"timeout_ms"` // Query should be something like: // @@ -275,27 +237,27 @@ type PostgresCacheInitializer struct { // // This query will be run once on startup to fetch _all_ known Stored Request data from the database. // - // For more details on the expected format of requestData and impData, see stored_requests/events/postgres/polling.go + // For more details on the expected format of requestData and impData, see stored_requests/events/database/database.go Query string `mapstructure:"query"` // AmpQuery is just like Query, but for AMP Stored Requests AmpQuery string `mapstructure:"amp_query"` } -func (cfg *PostgresCacheInitializer) validate(dataType DataType, errs []error) []error { +func (cfg *DatabaseCacheInitializer) validate(dataType DataType, errs []error) []error { section := dataType.Section() if cfg.Query == "" { return errs } if cfg.Timeout <= 0 { - errs = append(errs, fmt.Errorf("%s: postgres.initialize_caches.timeout_ms must be positive", section)) + errs = append(errs, fmt.Errorf("%s: database.initialize_caches.timeout_ms must be positive", section)) } if strings.Contains(cfg.Query, "$") { - errs = append(errs, fmt.Errorf("%s: postgres.initialize_caches.query should not contain any wildcards (e.g. 
$1)", section)) + errs = append(errs, fmt.Errorf("%s: database.initialize_caches.query should not contain any wildcards denoted by $ (e.g. $LAST_UPDATED)", section)) } return errs } -type PostgresUpdatePolling struct { +type DatabaseUpdatePolling struct { // RefreshRate determines how frequently the Query and AmpQuery are run. RefreshRate int `mapstructure:"refresh_rate_seconds"` @@ -306,11 +268,11 @@ type PostgresUpdatePolling struct { // // SELECT id, requestData, 'request' AS type // FROM stored_requests - // WHERE last_updated > $1 + // WHERE last_updated > $LAST_UPDATED // UNION ALL - // SELECT id, requestData, 'imp' AS type + // SELECT id, impData, 'imp' AS type // FROM stored_imps - // WHERE last_updated > $1 + // WHERE last_updated > $LAST_UPDATED // // The code will be run periodically to fetch updates from the database. Query string `mapstructure:"query"` @@ -318,89 +280,24 @@ type PostgresUpdatePolling struct { AmpQuery string `mapstructure:"amp_query"` } -func (cfg *PostgresUpdatePolling) validate(dataType DataType, errs []error) []error { +func (cfg *DatabaseUpdatePolling) validate(dataType DataType, errs []error) []error { section := dataType.Section() if cfg.Query == "" { return errs } if cfg.RefreshRate <= 0 { - errs = append(errs, fmt.Errorf("%s: postgres.poll_for_updates.refresh_rate_seconds must be > 0", section)) + errs = append(errs, fmt.Errorf("%s: database.poll_for_updates.refresh_rate_seconds must be > 0", section)) } if cfg.Timeout <= 0 { - errs = append(errs, fmt.Errorf("%s: postgres.poll_for_updates.timeout_ms must be > 0", section)) + errs = append(errs, fmt.Errorf("%s: database.poll_for_updates.timeout_ms must be > 0", section)) } - - if !strings.Contains(cfg.Query, "$1") || strings.Contains(cfg.Query, "$2") { - errs = append(errs, fmt.Errorf("%s: postgres.poll_for_updates.query must contain exactly one wildcard", section)) + if !strings.Contains(cfg.Query, "$LAST_UPDATED") { + errs = append(errs, fmt.Errorf("%s: 
database.poll_for_updates.query must contain $LAST_UPDATED parameter", section)) } - return errs -} - -// MakeQuery builds a query which can fetch numReqs Stored Requests and numImps Stored Imps. -// See the docs on PostgresConfig.QueryTemplate for a description of how it works. -func (cfg *PostgresFetcherQueries) MakeQuery(numReqs int, numImps int) (query string) { - return resolve(cfg.QueryTemplate, numReqs, numImps) -} -func (cfg *PostgresFetcherQueries) MakeQueryResponses(numIds int) (query string) { - return resolveQueryResponses(cfg.QueryTemplate, numIds) -} - -func resolve(template string, numReqs int, numImps int) (query string) { - numReqs = ensureNonNegative("Request", numReqs) - numImps = ensureNonNegative("Imp", numImps) - - query = strings.Replace(template, "%REQUEST_ID_LIST%", makeIdList(0, numReqs), -1) - query = strings.Replace(query, "%IMP_ID_LIST%", makeIdList(numReqs, numImps), -1) - return -} - -func resolveQueryResponses(template string, numIds int) (query string) { - numIds = ensureNonNegative("Response", numIds) - - query = strings.Replace(template, "%ID_LIST%", makeIdList(0, numIds), -1) - return -} - -func ensureNonNegative(storedThing string, num int) int { - if num < 0 { - glog.Errorf("Can't build a SQL query for %d Stored %ss.", num, storedThing) - return 0 - } - return num -} - -func makeIdList(numSoFar int, numArgs int) string { - // Any empty list like "()" is illegal in Postgres. A (NULL) is the next best thing, - // though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set. 
- // - // The query plan also suggests that it's basically free: - // - // explain SELECT id, requestData FROM stored_requests WHERE id in %ID_LIST%; - // - // QUERY PLAN - // ------------------------------------------- - // Result (cost=0.00..0.00 rows=0 width=16) - // One-Time Filter: false - // (2 rows) - if numArgs == 0 { - return "(NULL)" - } - - final := bytes.NewBuffer(make([]byte, 0, 2+4*numArgs)) - final.WriteString("(") - for i := numSoFar + 1; i < numSoFar+numArgs; i++ { - final.WriteString("$") - final.WriteString(strconv.Itoa(i)) - final.WriteString(", ") - } - final.WriteString("$") - final.WriteString(strconv.Itoa(numSoFar + numArgs)) - final.WriteString(")") - - return final.String() + return errs } type InMemoryCache struct { diff --git a/config/stored_requests_test.go b/config/stored_requests_test.go index ac468d2a32b..6184d19bd38 100644 --- a/config/stored_requests_test.go +++ b/config/stored_requests_test.go @@ -2,119 +2,11 @@ package config import ( "errors" - "strconv" - "strings" "testing" "github.com/stretchr/testify/assert" ) -const sampleQueryTemplate = "SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in %REQUEST_ID_LIST% UNION ALL SELECT id, impData, 'imp' as type FROM stored_requests WHERE id in %IMP_ID_LIST%" -const sampleResponsesQueryTemplate = "SELECT id, responseData, 'response' as type FROM stored_responses WHERE id in %ID_LIST%" - -func TestNormalQueryMaker(t *testing.T) { - madeQuery := buildQuery(sampleQueryTemplate, 1, 3) - assertStringsEqual(t, madeQuery, "SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in ($1) UNION ALL SELECT id, impData, 'imp' as type FROM stored_requests WHERE id in ($2, $3, $4)") -} - -func TestQueryMakerManyImps(t *testing.T) { - madeQuery := buildQuery(sampleQueryTemplate, 1, 11) - assertStringsEqual(t, madeQuery, "SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in ($1) UNION ALL SELECT id, impData, 'imp' as type FROM 
stored_requests WHERE id in ($2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)") -} - -func TestQueryMakerNoRequests(t *testing.T) { - madeQuery := buildQuery(sampleQueryTemplate, 0, 3) - assertStringsEqual(t, madeQuery, "SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in (NULL) UNION ALL SELECT id, impData, 'imp' as type FROM stored_requests WHERE id in ($1, $2, $3)") -} - -func TestQueryMakerNoImps(t *testing.T) { - madeQuery := buildQuery(sampleQueryTemplate, 1, 0) - assertStringsEqual(t, madeQuery, "SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in ($1) UNION ALL SELECT id, impData, 'imp' as type FROM stored_requests WHERE id in (NULL)") -} - -func TestQueryMakerMultilists(t *testing.T) { - madeQuery := buildQuery("SELECT id, config FROM table WHERE id in %IMP_ID_LIST% UNION ALL SELECT id, config FROM other_table WHERE id in %IMP_ID_LIST%", 0, 3) - assertStringsEqual(t, madeQuery, "SELECT id, config FROM table WHERE id in ($1, $2, $3) UNION ALL SELECT id, config FROM other_table WHERE id in ($1, $2, $3)") -} - -func TestQueryMakerNegative(t *testing.T) { - query := buildQuery(sampleQueryTemplate, -1, -2) - expected := buildQuery(sampleQueryTemplate, 0, 0) - assertStringsEqual(t, query, expected) -} - -func TestResponseQueryMaker(t *testing.T) { - testCases := []struct { - description string - inputRespNumber int - expectedQuery string - }{ - { - description: "single response query maker", - inputRespNumber: 1, - expectedQuery: "SELECT id, responseData, 'response' as type FROM stored_responses WHERE id in ($1)", - }, - { - description: "many responses query maker", - inputRespNumber: 11, - expectedQuery: "SELECT id, responseData, 'response' as type FROM stored_responses WHERE id in ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)", - }, - { - description: "no responses query maker", - inputRespNumber: 0, - expectedQuery: "SELECT id, responseData, 'response' as type FROM stored_responses WHERE id in (NULL)", - }, - { - 
description: "no responses query maker", - inputRespNumber: -2, - expectedQuery: "SELECT id, responseData, 'response' as type FROM stored_responses WHERE id in (NULL)", - }, - } - - for _, test := range testCases { - cfg := PostgresFetcherQueries{QueryTemplate: sampleResponsesQueryTemplate} - query := cfg.MakeQueryResponses(test.inputRespNumber) - assertStringsEqual(t, query, test.expectedQuery) - } -} - -func TestPostgressConnString(t *testing.T) { - db := "TestDB" - host := "somehost.com" - port := 20 - username := "someuser" - password := "somepassword" - - cfg := PostgresConnection{ - Database: db, - Host: host, - Port: port, - Username: username, - Password: password, - } - - dataSourceName := cfg.ConnString() - paramList := strings.Split(dataSourceName, " ") - params := make(map[string]string, len(paramList)) - for _, param := range paramList { - keyVals := strings.Split(param, "=") - if len(keyVals) != 2 { - t.Fatalf(`param "%s" must only have one equals sign`, param) - } - if _, ok := params[keyVals[0]]; ok { - t.Fatalf("found duplicate param at key %s", keyVals[0]) - } - params[keyVals[0]] = keyVals[1] - } - - assertHasValue(t, params, "dbname", db) - assertHasValue(t, params, "host", host) - assertHasValue(t, params, "port", strconv.Itoa(port)) - assertHasValue(t, params, "user", username) - assertHasValue(t, params, "password", password) - assertHasValue(t, params, "sslmode", "disable") -} - func TestInMemoryCacheValidationStoredRequests(t *testing.T) { assertNoErrs(t, (&InMemoryCache{ Type: "unbounded", @@ -211,7 +103,7 @@ func TestInMemoryCacheValidationSingleCache(t *testing.T) { }).validate(AccountDataType, nil)) } -func TestPostgresConfigValidation(t *testing.T) { +func TestDatabaseConfigValidation(t *testing.T) { tests := []struct { description string connectionStr string @@ -247,21 +139,21 @@ func TestPostgresConfigValidation(t *testing.T) { { description: "Invalid cache init query contains wildcard", connectionStr: "some-connection-string", - 
cacheInitQuery: "SELECT * FROM table WHERE $1", + cacheInitQuery: "SELECT * FROM table WHERE $LAST_UPDATED", cacheInitTimeout: 1, wantErrorCount: 1, }, { description: "Valid cache update query with non-zero timeout and refresh rate", connectionStr: "some-connection-string", - cacheUpdateQuery: "SELECT * FROM table WHERE $1", + cacheUpdateQuery: "SELECT * FROM table WHERE $LAST_UPDATED", cacheUpdateRefreshRate: 1, cacheUpdateTimeout: 1, }, { description: "Valid cache update query with zero timeout and non-zero refresh rate", connectionStr: "some-connection-string", - cacheUpdateQuery: "SELECT * FROM table WHERE $1", + cacheUpdateQuery: "SELECT * FROM table WHERE $LAST_UPDATED", cacheUpdateRefreshRate: 1, cacheUpdateTimeout: 0, wantErrorCount: 1, @@ -269,7 +161,7 @@ func TestPostgresConfigValidation(t *testing.T) { { description: "Valid cache update query with non-zero timeout and zero refresh rate", connectionStr: "some-connection-string", - cacheUpdateQuery: "SELECT * FROM table WHERE $1", + cacheUpdateQuery: "SELECT * FROM table WHERE $LAST_UPDATED", cacheUpdateRefreshRate: 0, cacheUpdateTimeout: 1, wantErrorCount: 1, @@ -286,29 +178,29 @@ func TestPostgresConfigValidation(t *testing.T) { description: "Multiple errors: valid queries missing timeouts and refresh rates plus existing error", connectionStr: "some-connection-string", cacheInitQuery: "SELECT * FROM table;", - cacheUpdateQuery: "SELECT * FROM table WHERE $1", + cacheUpdateQuery: "SELECT * FROM table WHERE $LAST_UPDATED", existingErrors: []error{errors.New("existing error before calling validate")}, wantErrorCount: 4, }, } for _, tt := range tests { - pgConfig := &PostgresConfig{ - ConnectionInfo: PostgresConnection{ + dbConfig := &DatabaseConfig{ + ConnectionInfo: DatabaseConnection{ Database: tt.connectionStr, }, - CacheInitialization: PostgresCacheInitializer{ + CacheInitialization: DatabaseCacheInitializer{ Query: tt.cacheInitQuery, Timeout: tt.cacheInitTimeout, }, - PollUpdates: 
PostgresUpdatePolling{ + PollUpdates: DatabaseUpdatePolling{ Query: tt.cacheUpdateQuery, RefreshRate: tt.cacheUpdateRefreshRate, Timeout: tt.cacheUpdateTimeout, }, } - errs := pgConfig.validate(RequestDataType, tt.existingErrors) + errs := dbConfig.validate(RequestDataType, tt.existingErrors) assert.Equal(t, tt.wantErrorCount, len(errs), tt.description) } } @@ -327,24 +219,6 @@ func assertNoErrs(t *testing.T, err []error) { } } -func assertHasValue(t *testing.T, m map[string]string, key string, val string) { - t.Helper() - realVal, ok := m[key] - if !ok { - t.Errorf("Map missing required key: %s", key) - } - if val != realVal { - t.Errorf("Unexpected value at key %s. Expected %s, Got %s", key, val, realVal) - } -} - -func buildQuery(template string, numReqs int, numImps int) string { - cfg := PostgresFetcherQueries{} - cfg.QueryTemplate = template - - return cfg.MakeQuery(numReqs, numImps) -} - func assertStringsEqual(t *testing.T, actual string, expected string) { if actual != expected { t.Errorf("Queries did not match.\n\"%s\" -- expected\n\"%s\" -- actual", expected, actual) @@ -358,21 +232,22 @@ func TestResolveConfig(t *testing.T) { Files: FileFetcherConfig{ Enabled: true, Path: "/test-path"}, - Postgres: PostgresConfig{ - ConnectionInfo: PostgresConnection{ + Database: DatabaseConfig{ + ConnectionInfo: DatabaseConnection{ + Driver: "postgres", Database: "db", Host: "pghost", Port: 5, Username: "user", Password: "pass", }, - FetcherQueries: PostgresFetcherQueries{ + FetcherQueries: DatabaseFetcherQueries{ AmpQueryTemplate: "amp-fetcher-query", }, - CacheInitialization: PostgresCacheInitializer{ + CacheInitialization: DatabaseCacheInitializer{ AmpQuery: "amp-cache-init-query", }, - PollUpdates: PostgresUpdatePolling{ + PollUpdates: DatabaseUpdatePolling{ AmpQuery: "amp-poll-query", }, }, @@ -394,9 +269,9 @@ func TestResolveConfig(t *testing.T) { }, } - cfg.StoredRequests.Postgres.FetcherQueries.QueryTemplate = "auc-fetcher-query" - 
cfg.StoredRequests.Postgres.CacheInitialization.Query = "auc-cache-init-query" - cfg.StoredRequests.Postgres.PollUpdates.Query = "auc-poll-query" + cfg.StoredRequests.Database.FetcherQueries.QueryTemplate = "auc-fetcher-query" + cfg.StoredRequests.Database.CacheInitialization.Query = "auc-cache-init-query" + cfg.StoredRequests.Database.PollUpdates.Query = "auc-poll-query" cfg.StoredRequests.HTTP.Endpoint = "auc-http-fetcher-endpoint" cfg.StoredRequests.HTTPEvents.Endpoint = "auc-http-events-endpoint" @@ -408,9 +283,9 @@ func TestResolveConfig(t *testing.T) { assertStringsEqual(t, auc.CacheEvents.Endpoint, "/storedrequests/openrtb2") // Amp should have the amp values in it - assertStringsEqual(t, amp.Postgres.FetcherQueries.QueryTemplate, cfg.StoredRequests.Postgres.FetcherQueries.AmpQueryTemplate) - assertStringsEqual(t, amp.Postgres.CacheInitialization.Query, cfg.StoredRequests.Postgres.CacheInitialization.AmpQuery) - assertStringsEqual(t, amp.Postgres.PollUpdates.Query, cfg.StoredRequests.Postgres.PollUpdates.AmpQuery) + assertStringsEqual(t, amp.Database.FetcherQueries.QueryTemplate, cfg.StoredRequests.Database.FetcherQueries.AmpQueryTemplate) + assertStringsEqual(t, amp.Database.CacheInitialization.Query, cfg.StoredRequests.Database.CacheInitialization.AmpQuery) + assertStringsEqual(t, amp.Database.PollUpdates.Query, cfg.StoredRequests.Database.PollUpdates.AmpQuery) assertStringsEqual(t, amp.HTTP.Endpoint, cfg.StoredRequests.HTTP.AmpEndpoint) assertStringsEqual(t, amp.HTTPEvents.Endpoint, cfg.StoredRequests.HTTPEvents.AmpEndpoint) assertStringsEqual(t, amp.CacheEvents.Endpoint, "/storedrequests/amp") diff --git a/docs/developers/stored-requests.md b/docs/developers/stored-requests.md index 9adf4ed1309..0f24391c04d 100644 --- a/docs/developers/stored-requests.md +++ b/docs/developers/stored-requests.md @@ -194,14 +194,32 @@ with different [configuration options](configuration.md). 
For example: ```yaml stored_requests: - postgres: - host: localhost - port: 5432 - user: db-username - dbname: database-name - query: SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in %REQUEST_ID_LIST% UNION ALL SELECT id, impData, 'imp' as type FROM stored_imps WHERE id in %IMP_ID_LIST%; + database: + connection: + driver: postgres + host: localhost + port: 5432 + user: db-username + dbname: database-name + fetcher: + query: SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in $REQUEST_ID_LIST UNION ALL SELECT id, impData, 'imp' as type FROM stored_imps WHERE id in $IMP_ID_LIST; ``` +### Supported Databases +- postgres +- mysql + +### Query Syntax +All database queries should be expressed using the native SQL syntax of your supported database of choice with one caveat. + +For all supported database drivers, wherever you need to specify a query parameter, you must not use the native syntax (e.g. `$1`, `%%`, `?`, etc.), but rather a PBS-specific syntax to represent the parameter which is of the format `$VARIABLE_NAME`. PBS currently supports just four query parameters, each of which pertains to particular config queries, and here is how they should be specified in your queries: +- last updated at timestamp --> `$LAST_UPDATED` +- stored request ID list --> `$REQUEST_ID_LIST` +- stored imp ID list --> `$IMP_ID_LIST` +- stored response ID list --> `$ID_LIST` + +See the query defined at `stored_requests.database.fetcher.query` in the yaml config above as an example of how to mix these variables in with native SQL syntax. + ```yaml stored_requests: http: @@ -234,17 +252,20 @@ Any concrete Fetcher in the project will be composed with any Cache(s) to create EventProducer events are used to Save or Invalidate values from the Cache(s). Saves and invalidates will propagate to all Cache layers. -Here is an example `pbs.yaml` file which looks for Stored Requests first from Postgres, and then from an HTTP endpoint.
+Here is an example `pbs.yaml` file which looks for Stored Requests first from a database (Postgres in this example), and then from an HTTP endpoint. It will use an in-memory LRU cache to store data locally, and poll another HTTP endpoint to listen for updates. ```yaml stored_requests: - postgres: - host: localhost - port: 5432 - user: db-username - dbname: database-name - query: SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in %REQUEST_ID_LIST% UNION ALL SELECT id, impData, 'imp' as type FROM stored_imps WHERE id in %IMP_ID_LIST%; + database: + connection: + driver: postgres + host: localhost + port: 5432 + user: db-username + dbname: database-name + fetcher: + query: SELECT id, requestData, 'request' as type FROM stored_requests WHERE id in $REQUEST_ID_LIST UNION ALL SELECT id, impData, 'imp' as type FROM stored_imps WHERE id in $IMP_ID_LIST; http: endpoint: http://stored-requests.prebid.com amp_endpoint: http://stored-requests.prebid.com?amp=true diff --git a/go.mod b/go.mod index be5ea481437..2beada03894 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/chasex/glog v0.0.0-20160217080310-c62392af379c github.com/coocood/freecache v1.2.1 github.com/docker/go-units v0.4.0 + github.com/go-sql-driver/mysql v1.6.0 github.com/gofrs/uuid v4.2.0+incompatible github.com/golang/glog v1.0.0 github.com/julienschmidt/httprouter v1.3.0 diff --git a/go.sum b/go.sum index 5cce1e88dae..383623149b0 100644 --- a/go.sum +++ b/go.sum @@ -149,6 +149,8 @@ github.com/go-ldap/ldap v3.0.2+incompatible/go.mod h1:qfd9rJvER9Q0/D/Sqn1DfHRoBp github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod
h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= diff --git a/router/router.go b/router/router.go index fc92a2f8eb5..3b863f68a22 100644 --- a/router/router.go +++ b/router/router.go @@ -35,6 +35,7 @@ import ( "github.com/prebid/prebid-server/util/uuidutil" "github.com/prebid/prebid-server/version" + _ "github.com/go-sql-driver/mysql" "github.com/golang/glog" "github.com/julienschmidt/httprouter" _ "github.com/lib/pq" diff --git a/stored_requests/backends/db_fetcher/fetcher.go b/stored_requests/backends/db_fetcher/fetcher.go index 1ad64a3ca3f..7963751bcf3 100644 --- a/stored_requests/backends/db_fetcher/fetcher.go +++ b/stored_requests/backends/db_fetcher/fetcher.go @@ -2,36 +2,42 @@ package db_fetcher import ( "context" - "database/sql" "encoding/json" + "github.com/lib/pq" "github.com/golang/glog" "github.com/prebid/prebid-server/stored_requests" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" ) -func NewFetcher(db *sql.DB, queryMaker func(int, int) string, responseQueryMaker func(int) string) stored_requests.AllFetcher { - if db == nil { - glog.Fatalf("The Postgres Stored Request Fetcher requires a database connection. Please report this as a bug.") +func NewFetcher( + provider db_provider.DbProvider, + queryTemplate string, + responseQueryTemplate string, +) stored_requests.AllFetcher { + + if provider == nil { + glog.Fatalf("The Database Stored Request Fetcher requires a database connection. Please report this as a bug.") } - if queryMaker == nil { - glog.Fatalf("The Postgres Stored Request Fetcher requires a queryMaker function. 
Please report this as a bug.") + if queryTemplate == "" { + glog.Fatalf("The Database Stored Request Fetcher requires a queryTemplate. Please report this as a bug.") } - if responseQueryMaker == nil { - glog.Fatalf("The Postgres Stored Response Fetcher requires a responseQueryMaker function. Please report this as a bug.") + if responseQueryTemplate == "" { + glog.Fatalf("The Database Stored Response Fetcher requires a responseQueryTemplate. Please report this as a bug.") } return &dbFetcher{ - db: db, - queryMaker: queryMaker, - responseQueryMaker: responseQueryMaker, + provider: provider, + queryTemplate: queryTemplate, + responseQueryTemplate: responseQueryTemplate, } } // dbFetcher fetches Stored Requests from a database. This should be instantiated through the NewFetcher() function. type dbFetcher struct { - db *sql.DB - queryMaker func(numReqs int, numImps int) (query string) - responseQueryMaker func(numIds int) (query string) + provider db_provider.DbProvider + queryTemplate string + responseQueryTemplate string } func (fetcher *dbFetcher) FetchRequests(ctx context.Context, requestIDs []string, impIDs []string) (map[string]json.RawMessage, map[string]json.RawMessage, []error) { @@ -39,16 +45,21 @@ func (fetcher *dbFetcher) FetchRequests(ctx context.Context, requestIDs []string return nil, nil, nil } - query := fetcher.queryMaker(len(requestIDs), len(impIDs)) - idInterfaces := make([]interface{}, len(requestIDs)+len(impIDs)) + requestIDsParam := make([]interface{}, len(requestIDs)) for i := 0; i < len(requestIDs); i++ { - idInterfaces[i] = requestIDs[i] + requestIDsParam[i] = requestIDs[i] } + impIDsParam := make([]interface{}, len(impIDs)) for i := 0; i < len(impIDs); i++ { - idInterfaces[i+len(requestIDs)] = impIDs[i] + impIDsParam[i] = impIDs[i] } - rows, err := fetcher.db.QueryContext(ctx, query, idInterfaces...) 
+ params := []db_provider.QueryParam{ + {Name: "REQUEST_ID_LIST", Value: requestIDsParam}, + {Name: "IMP_ID_LIST", Value: impIDsParam}, + } + + rows, err := fetcher.provider.QueryContext(ctx, fetcher.queryTemplate, params...) if err != nil { if err != context.DeadlineExceeded && !isBadInput(err) { glog.Errorf("Error reading from Stored Request DB: %s", err.Error()) @@ -82,7 +93,7 @@ func (fetcher *dbFetcher) FetchRequests(ctx context.Context, requestIDs []string case "imp": storedImpData[id] = data default: - glog.Errorf("Postgres result set with id=%s has invalid type: %s. This will be ignored.", id, dataType) + glog.Errorf("Database result set with id=%s has invalid type: %s. This will be ignored.", id, dataType) } } @@ -102,13 +113,15 @@ func (fetcher *dbFetcher) FetchResponses(ctx context.Context, ids []string) (dat return nil, nil } - query := fetcher.responseQueryMaker(len(ids)) idInterfaces := make([]interface{}, len(ids)) for i := 0; i < len(ids); i++ { idInterfaces[i] = ids[i] } + params := []db_provider.QueryParam{ + {Name: "ID_LIST", Value: idInterfaces}, + } - rows, err := fetcher.db.QueryContext(ctx, query, idInterfaces...) + rows, err := fetcher.provider.QueryContext(ctx, fetcher.responseQueryTemplate, params...) 
if err != nil { return nil, []error{err} } diff --git a/stored_requests/backends/db_fetcher/fetcher_test.go b/stored_requests/backends/db_fetcher/fetcher_test.go index 8959736dbbb..04753fb8af5 100644 --- a/stored_requests/backends/db_fetcher/fetcher_test.go +++ b/stored_requests/backends/db_fetcher/fetcher_test.go @@ -11,20 +11,21 @@ import ( "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" "github.com/stretchr/testify/assert" ) func TestEmptyQuery(t *testing.T) { - db, _, err := sqlmock.New() + provider, _, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Unexpected error stubbing DB: %v", err) } - defer db.Close() + defer provider.Close() fetcher := dbFetcher{ - db: db, - queryMaker: successfulQueryMaker(""), - responseQueryMaker: successfulResponseQueryMaker(""), + provider: provider, + queryTemplate: "", + responseQueryTemplate: "", } storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), nil, nil) assertErrorCount(t, 0, errs) @@ -45,7 +46,7 @@ func TestGoodResponse(t *testing.T) { AddRow("imp-id-2", `{"imp":true,"value":2}`, "imp") mock, fetcher := newFetcher(t, mockReturn, mockQuery, "request-id") - defer fetcher.db.Close() + defer fetcher.provider.Close() storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"request-id"}, nil) @@ -98,7 +99,7 @@ func TestFetchResponses(t *testing.T) { for _, test := range testCases { mock, fetcher := newFetcher(t, test.mockReturn, test.mockQuery, test.arguments...) 
- defer fetcher.db.Close() + defer fetcher.provider.Close() storedResponses, errs := fetcher.FetchResponses(context.Background(), test.respIds) @@ -120,7 +121,7 @@ func TestPartialResponse(t *testing.T) { AddRow("stored-req-id", "{}", "request") mock, fetcher := newFetcher(t, mockReturn, mockQuery, "stored-req-id", "stored-req-id-2") - defer fetcher.db.Close() + defer fetcher.provider.Close() storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id", "stored-req-id-2"}, nil) @@ -137,7 +138,7 @@ func TestEmptyResponse(t *testing.T) { mockReturn := sqlmock.NewRows([]string{"id", "data", "dataType"}) mock, fetcher := newFetcher(t, mockReturn, mockQuery, "stored-req-id", "stored-req-id-2", "stored-imp-id") - defer fetcher.db.Close() + defer fetcher.provider.Close() storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id", "stored-req-id-2"}, []string{"stored-imp-id"}) @@ -149,7 +150,7 @@ func TestEmptyResponse(t *testing.T) { // TestDatabaseError makes sure we exit with an error if the DB query fails. func TestDatabaseError(t *testing.T) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) } @@ -157,8 +158,8 @@ func TestDatabaseError(t *testing.T) { mock.ExpectQuery(".*").WillReturnError(errors.New("Invalid query.")) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker("SELECT id, data, dataType FROM my_table WHERE id IN (?, ?)"), + provider: provider, + queryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?, ?)", } storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id"}, nil) @@ -169,7 +170,7 @@ func TestDatabaseError(t *testing.T) { // TestContextDeadlines makes sure a hung query returns when the timeout expires. 
func TestContextDeadlines(t *testing.T) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) } @@ -177,9 +178,9 @@ func TestContextDeadlines(t *testing.T) { mock.ExpectQuery(".*").WillDelayFor(2 * time.Minute) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker("SELECT id, requestData FROM my_table WHERE id IN (?, ?)"), - responseQueryMaker: successfulResponseQueryMaker("SELECT id, responseData FROM my_table WHERE id IN (?, ?)"), + provider: provider, + queryTemplate: "SELECT id, requestData FROM my_table WHERE id IN (?, ?)", + responseQueryTemplate: "SELECT id, responseData FROM my_table WHERE id IN (?, ?)", } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) @@ -197,7 +198,7 @@ func TestContextDeadlines(t *testing.T) { // TestContextCancelled makes sure a hung query returns when the context is cancelled. func TestContextCancelled(t *testing.T) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) } @@ -205,9 +206,9 @@ func TestContextCancelled(t *testing.T) { mock.ExpectQuery(".*").WillDelayFor(2 * time.Minute) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker("SELECT id, requestData FROM my_table WHERE id IN (?, ?)"), - responseQueryMaker: successfulResponseQueryMaker("SELECT id, responseData FROM my_table WHERE id IN (?, ?)"), + provider: provider, + queryTemplate: "SELECT id, requestData FROM my_table WHERE id IN (?, ?)", + responseQueryTemplate: "SELECT id, responseData FROM my_table WHERE id IN (?, ?)", } ctx, cancel := context.WithCancel(context.Background()) @@ -224,7 +225,7 @@ func TestContextCancelled(t *testing.T) { // Prevents #338 func TestRowErrors(t *testing.T) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: 
%v", err) } @@ -234,8 +235,8 @@ func TestRowErrors(t *testing.T) { rows.RowError(1, errors.New("Error reading from row 1")) mock.ExpectQuery(".*").WillReturnRows(rows) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker("SELECT id, data, dataType FROM my_table WHERE id IN (?)"), + provider: provider, + queryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?)", } data, _, errs := fetcher.FetchRequests(context.Background(), []string{"foo", "bar"}, nil) assertErrorCount(t, 1, errs) @@ -246,7 +247,7 @@ func TestRowErrors(t *testing.T) { } func TestRowErrorsFetchResponses(t *testing.T) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) } @@ -256,9 +257,9 @@ func TestRowErrorsFetchResponses(t *testing.T) { rows.RowError(1, errors.New("Error reading from row 1")) mock.ExpectQuery(".*").WillReturnRows(rows) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker("SELECT id, data, dataType FROM my_table WHERE id IN (?)"), - responseQueryMaker: successfulResponseQueryMaker("SELECT id, data, dataType FROM my_table WHERE id IN (?)"), + provider: provider, + queryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?)", + responseQueryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?)", } data, errs := fetcher.FetchResponses(context.Background(), []string{"foo", "bar"}) assertErrorCount(t, 1, errs) @@ -269,7 +270,7 @@ func TestRowErrorsFetchResponses(t *testing.T) { } func newFetcher(t *testing.T, rows *sqlmock.Rows, query string, args ...driver.Value) (sqlmock.Sqlmock, *dbFetcher) { - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) return nil, nil @@ -278,9 +279,9 @@ func newFetcher(t *testing.T, rows *sqlmock.Rows, query string, args ...driver.V queryRegex := fmt.Sprintf("^%s$", regexp.QuoteMeta(query)) 
mock.ExpectQuery(queryRegex).WithArgs(args...).WillReturnRows(rows) fetcher := &dbFetcher{ - db: db, - queryMaker: successfulQueryMaker(query), - responseQueryMaker: successfulResponseQueryMaker(query), + provider: provider, + queryTemplate: query, + responseQueryTemplate: query, } return mock, fetcher @@ -317,15 +318,3 @@ func assertErrorCount(t *testing.T, num int, errs []error) { t.Errorf("Wrong number of errors. Expected %d. Got %d. Errors are %v", num, len(errs), errs) } } - -func successfulQueryMaker(response string) func(int, int) string { - return func(numReqs int, numImps int) string { - return response - } -} - -func successfulResponseQueryMaker(response string) func(int) string { - return func(numIds int) string { - return response - } -} diff --git a/stored_requests/backends/db_provider/db_provider.go b/stored_requests/backends/db_provider/db_provider.go new file mode 100644 index 00000000000..df6ae81e8e6 --- /dev/null +++ b/stored_requests/backends/db_provider/db_provider.go @@ -0,0 +1,51 @@ +package db_provider + +import ( + "context" + "database/sql" + + "github.com/golang/glog" + "github.com/prebid/prebid-server/config" +) + +type DbProvider interface { + Config() config.DatabaseConnection + ConnString() string + Open() error + Close() error + Ping() error + PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) + QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) +} + +func NewDbProvider(dataType config.DataType, cfg config.DatabaseConnection) DbProvider { + var provider DbProvider + + switch cfg.Driver { + case "mysql": + provider = &MySqlDbProvider{ + cfg: cfg, + } + case "postgres": + provider = &PostgresDbProvider{ + cfg: cfg, + } + default: + glog.Fatalf("Unsupported database driver %s", cfg.Driver) + return nil + } + + if err := provider.Open(); err != nil { + glog.Fatalf("Failed to open %s database connection: %v", dataType, err) + } + if err := provider.Ping(); err != 
nil { + glog.Fatalf("Failed to ping %s database: %v", dataType, err) + } + + return provider +} + +type QueryParam struct { + Name string + Value interface{} +} diff --git a/stored_requests/backends/db_provider/db_provider_mock.go b/stored_requests/backends/db_provider/db_provider_mock.go new file mode 100644 index 00000000000..3432f62e713 --- /dev/null +++ b/stored_requests/backends/db_provider/db_provider_mock.go @@ -0,0 +1,67 @@ +package db_provider + +import ( + "context" + "database/sql" + "reflect" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/prebid/prebid-server/config" +) + +func NewDbProviderMock() (*DbProviderMock, sqlmock.Sqlmock, error) { + db, mock, err := sqlmock.New() + if err != nil { + return nil, nil, err + } + + provider := &DbProviderMock{ + db: db, + mock: mock, + } + + return provider, mock, nil +} + +type DbProviderMock struct { + db *sql.DB + mock sqlmock.Sqlmock +} + +func (provider DbProviderMock) Config() config.DatabaseConnection { + return config.DatabaseConnection{} +} + +func (provider DbProviderMock) ConnString() string { + return "" +} + +func (provider DbProviderMock) Open() error { + return nil +} + +func (provider DbProviderMock) Close() error { + return nil +} + +func (provider DbProviderMock) Ping() error { + return nil +} + +func (provider DbProviderMock) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) { + for _, param := range params { + if reflect.TypeOf(param.Value).Kind() == reflect.Slice { + idList := param.Value.([]interface{}) + args = append(args, idList...) + } else { + args = append(args, param.Value) + } + } + return template, args +} + +func (provider DbProviderMock) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) { + query, args := provider.PrepareQuery(template, params...) + + return provider.db.QueryContext(ctx, query, args...) 
+} diff --git a/stored_requests/backends/db_provider/db_provider_test.go b/stored_requests/backends/db_provider/db_provider_test.go new file mode 100644 index 00000000000..1bb70419f3f --- /dev/null +++ b/stored_requests/backends/db_provider/db_provider_test.go @@ -0,0 +1,154 @@ +package db_provider + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPrepareQuery(t *testing.T) { + tests := []struct { + description string + + template string + params []QueryParam + mySqlQuery string + mySqlArgs []interface{} + postgresQuery string + postgresArgs []interface{} + }{ + { + description: "No parameters", + template: "SELECT * FROM table", + params: []QueryParam{}, + mySqlQuery: "SELECT * FROM table", + mySqlArgs: []interface{}{}, + postgresQuery: "SELECT * FROM table", + postgresArgs: []interface{}{}, + }, + { + description: "One simple parameter", + template: "SELECT * FROM table WHERE id = $ID", + params: []QueryParam{{Name: "ID", Value: "1001"}}, + mySqlQuery: "SELECT * FROM table WHERE id = ?", + mySqlArgs: []interface{}{"1001"}, + postgresQuery: "SELECT * FROM table WHERE id = $1", + postgresArgs: []interface{}{"1001"}, + }, + { + description: "Two simple parameters", + template: "SELECT * FROM table WHERE id = $ID AND name = $NAME", + params: []QueryParam{ + {Name: "ID", Value: "1001"}, + {Name: "NAME", Value: "Alice"}, + }, + mySqlQuery: "SELECT * FROM table WHERE id = ? AND name = ?", + mySqlArgs: []interface{}{"1001", "Alice"}, + postgresQuery: "SELECT * FROM table WHERE id = $1 AND name = $2", + postgresArgs: []interface{}{"1001", "Alice"}, + }, + { + description: "Two simple parameters, used several times", + template: "SELECT $ID, $NAME, * FROM table WHERE id = $ID AND name = $NAME", + params: []QueryParam{ + {Name: "ID", Value: "1001"}, + {Name: "NAME", Value: "Alice"}, + }, + mySqlQuery: "SELECT ?, ?, * FROM table WHERE id = ? 
AND name = ?", + mySqlArgs: []interface{}{"1001", "Alice", "1001", "Alice"}, + postgresQuery: "SELECT $1, $2, * FROM table WHERE id = $1 AND name = $2", + postgresArgs: []interface{}{"1001", "Alice"}, + }, + { + description: "Empty list parameter", + template: "SELECT * FROM table WHERE id IN $IDS", + params: []QueryParam{{Name: "IDS", Value: []interface{}{}}}, + mySqlQuery: "SELECT * FROM table WHERE id IN (NULL)", + mySqlArgs: []interface{}{}, + postgresQuery: "SELECT * FROM table WHERE id IN (NULL)", + postgresArgs: []interface{}{}, + }, + { + description: "One list parameter", + template: "SELECT * FROM table WHERE id IN $IDS", + params: []QueryParam{{Name: "IDS", Value: []interface{}{"1001", "1002"}}}, + mySqlQuery: "SELECT * FROM table WHERE id IN (?, ?)", + mySqlArgs: []interface{}{"1001", "1002"}, + postgresQuery: "SELECT * FROM table WHERE id IN ($1, $2)", + postgresArgs: []interface{}{"1001", "1002"}, + }, + { + description: "Two list parameters", + template: "SELECT * FROM table WHERE id IN $IDS OR name in $NAMES", + params: []QueryParam{ + {Name: "IDS", Value: []interface{}{"1001"}}, + {Name: "NAMES", Value: []interface{}{"Bob", "Nancy"}}, + }, + mySqlQuery: "SELECT * FROM table WHERE id IN (?) OR name in (?, ?)", + mySqlArgs: []interface{}{"1001", "Bob", "Nancy"}, + postgresQuery: "SELECT * FROM table WHERE id IN ($1) OR name in ($2, $3)", + postgresArgs: []interface{}{"1001", "Bob", "Nancy"}, + }, + { + description: "Mix of simple and list parameters", + template: ` + SELECT * FROM table1 + WHERE last_updated > $LAST_UPDATED + AND (id IN $IDS OR name in $NAMES) + UNION ALL + SELECT * FROM table1 + WHERE last_updated > $LAST_UPDATED + AND (id IN $IDS OR name in $NAMES) + `, + params: []QueryParam{ + {Name: "LAST_UPDATED", Value: "1970-01-01"}, + {Name: "IDS", Value: []interface{}{"1001"}}, + {Name: "NAMES", Value: []interface{}{"Bob", "Nancy"}}, + }, + mySqlQuery: ` + SELECT * FROM table1 + WHERE last_updated > ? + AND (id IN (?) 
OR name in (?, ?)) + UNION ALL + SELECT * FROM table1 + WHERE last_updated > ? + AND (id IN (?) OR name in (?, ?)) + `, + mySqlArgs: []interface{}{ + "1970-01-01", + "1001", + "Bob", "Nancy", + "1970-01-01", + "1001", + "Bob", "Nancy", + }, + postgresQuery: ` + SELECT * FROM table1 + WHERE last_updated > $1 + AND (id IN ($2) OR name in ($3, $4)) + UNION ALL + SELECT * FROM table1 + WHERE last_updated > $1 + AND (id IN ($2) OR name in ($3, $4)) + `, + postgresArgs: []interface{}{ + "1970-01-01", + "1001", + "Bob", "Nancy", + }, + }, + } + + for _, tt := range tests { + mySqlDbProvider := MySqlDbProvider{} + mySqlQuery, mySqlArgs := mySqlDbProvider.PrepareQuery(tt.template, tt.params...) + assert.Equal(t, tt.mySqlQuery, mySqlQuery, fmt.Sprintf("MySql: %s", tt.description)) + assert.Equal(t, tt.mySqlArgs, mySqlArgs, fmt.Sprintf("MySql: %s", tt.description)) + + postgresDbProvider := PostgresDbProvider{} + postgresQuery, postgresArgs := postgresDbProvider.PrepareQuery(tt.template, tt.params...) 
+ assert.Equal(t, tt.postgresQuery, postgresQuery, fmt.Sprintf("Postgres: %s", tt.description)) + assert.Equal(t, tt.postgresArgs, postgresArgs, fmt.Sprintf("Postgres: %s", tt.description)) + } +} diff --git a/stored_requests/backends/db_provider/mysql_dbprovider.go b/stored_requests/backends/db_provider/mysql_dbprovider.go new file mode 100644 index 00000000000..91c37f04910 --- /dev/null +++ b/stored_requests/backends/db_provider/mysql_dbprovider.go @@ -0,0 +1,151 @@ +package db_provider + +import ( + "bytes" + "context" + "database/sql" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/prebid/prebid-server/config" +) + +type MySqlDbProvider struct { + cfg config.DatabaseConnection + db *sql.DB +} + +func (provider *MySqlDbProvider) Config() config.DatabaseConnection { + return provider.cfg +} + +func (provider *MySqlDbProvider) Open() error { + db, err := sql.Open(provider.cfg.Driver, provider.ConnString()) + + if err != nil { + return err + } + + provider.db = db + return nil +} + +func (provider *MySqlDbProvider) Close() error { + if provider.db != nil { + db := provider.db + provider.db = nil + return db.Close() + } + + return nil +} + +func (provider *MySqlDbProvider) Ping() error { + return provider.db.Ping() +} + +func (provider *MySqlDbProvider) ConnString() string { + buffer := bytes.NewBuffer(nil) + + if provider.cfg.Username != "" { + buffer.WriteString(provider.cfg.Username) + if provider.cfg.Password != "" { + buffer.WriteString(":") + buffer.WriteString(provider.cfg.Password) + } + buffer.WriteString("@") + } + + buffer.WriteString("tcp(") + if provider.cfg.Host != "" { + buffer.WriteString(provider.cfg.Host) + } + + if provider.cfg.Port > 0 { + buffer.WriteString(":") + buffer.WriteString(strconv.Itoa(provider.cfg.Port)) + } + buffer.WriteString(")") + + buffer.WriteString("/") + + if provider.cfg.Database != "" { + buffer.WriteString(provider.cfg.Database) + } + + return buffer.String() +} + +func (provider *MySqlDbProvider) 
PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) { + query = template + args = []interface{}{} + + type occurrence struct { + startIndex int + param QueryParam + } + occurrences := []occurrence{} + + for _, param := range params { + re := regexp.MustCompile("\\$" + param.Name) + matches := re.FindAllIndex([]byte(query), -1) + for _, match := range matches { + occurrences = append(occurrences, + occurrence{ + startIndex: match[0], + param: param, + }) + } + } + sort.Slice(occurrences, func(i, j int) bool { + return occurrences[i].startIndex < occurrences[j].startIndex + }) + + for _, occurrence := range occurrences { + switch occurrence.param.Value.(type) { + case []interface{}: + idList := occurrence.param.Value.([]interface{}) + args = append(args, idList...) + default: + args = append(args, occurrence.param.Value) + } + } + + for _, param := range params { + switch param.Value.(type) { + case []interface{}: + len := len(param.Value.([]interface{})) + idList := provider.createIdList(len) + query = strings.Replace(query, "$"+param.Name, idList, -1) + default: + query = strings.Replace(query, "$"+param.Name, "?", -1) + } + } + return +} + +func (provider *MySqlDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) { + query, args := provider.PrepareQuery(template, params...) + return provider.db.QueryContext(ctx, query, args...) +} + +func (provider *MySqlDbProvider) createIdList(numArgs int) string { + // Any empty list like "()" is illegal in MySql. A (NULL) is the next best thing, + // though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set. 
+ if numArgs == 0 { + return "(NULL)" + } + + result := bytes.NewBuffer(make([]byte, 0, 2+3*numArgs)) + result.WriteString("(") + for i := 1; i < numArgs; i++ { + result.WriteString("?") + result.WriteString(", ") + } + result.WriteString("?") + result.WriteString(")") + + return result.String() +} diff --git a/stored_requests/backends/db_provider/mysql_dbprovider_test.go b/stored_requests/backends/db_provider/mysql_dbprovider_test.go new file mode 100644 index 00000000000..e0796d00e95 --- /dev/null +++ b/stored_requests/backends/db_provider/mysql_dbprovider_test.go @@ -0,0 +1,89 @@ +package db_provider + +import ( + "testing" + + "github.com/prebid/prebid-server/config" + "github.com/stretchr/testify/assert" +) + +func TestConnStringMySql(t *testing.T) { + type Params struct { + db string + host string + port int + username string + password string + } + + tests := []struct { + name string + params Params + connString string + }{ + { + params: Params{ + db: "", + }, + connString: "tcp()/", + }, + { + params: Params{ + db: "TestDB", + }, + connString: "tcp()/TestDB", + }, + { + params: Params{ + host: "example.com", + }, + connString: "tcp(example.com)/", + }, + { + params: Params{ + port: 20, + }, + connString: "tcp(:20)/", + }, + { + params: Params{ + username: "someuser", + }, + connString: "someuser@tcp()/", + }, + { + params: Params{ + username: "someuser", + password: "somepassword", + }, + connString: "someuser:somepassword@tcp()/", + }, + { + params: Params{ + db: "TestDB", + host: "example.com", + port: 20, + username: "someuser", + password: "somepassword", + }, + connString: "someuser:somepassword@tcp(example.com:20)/TestDB", + }, + } + + for _, test := range tests { + cfg := config.DatabaseConnection{ + Database: test.params.db, + Host: test.params.host, + Port: test.params.port, + Username: test.params.username, + Password: test.params.password, + } + + provider := MySqlDbProvider{ + cfg: cfg, + } + + connString := provider.ConnString() + 
assert.Equal(t, test.connString, connString, "Strings did not match") + } +} diff --git a/stored_requests/backends/db_provider/postgres_dbprovider.go b/stored_requests/backends/db_provider/postgres_dbprovider.go new file mode 100644 index 00000000000..cbd8d9d0913 --- /dev/null +++ b/stored_requests/backends/db_provider/postgres_dbprovider.go @@ -0,0 +1,138 @@ +package db_provider + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "strconv" + "strings" + + "github.com/prebid/prebid-server/config" +) + +type PostgresDbProvider struct { + cfg config.DatabaseConnection + db *sql.DB +} + +func (provider *PostgresDbProvider) Config() config.DatabaseConnection { + return provider.cfg +} + +func (provider *PostgresDbProvider) Open() error { + db, err := sql.Open(provider.cfg.Driver, provider.ConnString()) + + if err != nil { + return err + } + + provider.db = db + return nil +} + +func (provider *PostgresDbProvider) Close() error { + if provider.db != nil { + db := provider.db + provider.db = nil + return db.Close() + } + + return nil +} + +func (provider *PostgresDbProvider) Ping() error { + return provider.db.Ping() +} + +func (provider *PostgresDbProvider) ConnString() string { + buffer := bytes.NewBuffer(nil) + + if provider.cfg.Host != "" { + buffer.WriteString("host=") + buffer.WriteString(provider.cfg.Host) + buffer.WriteString(" ") + } + + if provider.cfg.Port > 0 { + buffer.WriteString("port=") + buffer.WriteString(strconv.Itoa(provider.cfg.Port)) + buffer.WriteString(" ") + } + + if provider.cfg.Username != "" { + buffer.WriteString("user=") + buffer.WriteString(provider.cfg.Username) + buffer.WriteString(" ") + } + + if provider.cfg.Password != "" { + buffer.WriteString("password=") + buffer.WriteString(provider.cfg.Password) + buffer.WriteString(" ") + } + + if provider.cfg.Database != "" { + buffer.WriteString("dbname=") + buffer.WriteString(provider.cfg.Database) + buffer.WriteString(" ") + } + + buffer.WriteString("sslmode=disable") + return 
buffer.String() +} + +func (provider *PostgresDbProvider) PrepareQuery(template string, params ...QueryParam) (query string, args []interface{}) { + query = template + args = []interface{}{} + + for _, param := range params { + switch v := param.Value.(type) { + case []interface{}: + idList := v + idListStr := provider.createIdList(len(args), len(idList)) + args = append(args, idList...) + query = strings.Replace(query, "$"+param.Name, idListStr, -1) + default: + args = append(args, param.Value) + query = strings.Replace(query, "$"+param.Name, fmt.Sprintf("$%d", len(args)), -1) + } + } + return +} + +func (provider *PostgresDbProvider) QueryContext(ctx context.Context, template string, params ...QueryParam) (*sql.Rows, error) { + query, args := provider.PrepareQuery(template, params...) + return provider.db.QueryContext(ctx, query, args...) +} + +func (provider *PostgresDbProvider) createIdList(numSoFar int, numArgs int) string { + // Any empty list like "()" is illegal in Postgres. A (NULL) is the next best thing, + // though, since `id IN (NULL)` is valid for all "id" column types, and evaluates to an empty set. 
+ // + // The query plan also suggests that it's basically free: + // + // explain SELECT id, requestData FROM stored_requests WHERE id in $ID_LIST; + // + // QUERY PLAN + // ------------------------------------------- + // Result (cost=0.00..0.00 rows=0 width=16) + // One-Time Filter: false + // (2 rows) + if numArgs == 0 { + return "(NULL)" + } + + final := bytes.NewBuffer(make([]byte, 0, 2+4*numArgs)) + final.WriteString("(") + for i := numSoFar + 1; i < numSoFar+numArgs; i++ { + final.WriteString("$") + final.WriteString(strconv.Itoa(i)) + final.WriteString(", ") + } + final.WriteString("$") + final.WriteString(strconv.Itoa(numSoFar + numArgs)) + final.WriteString(")") + + return final.String() +} diff --git a/stored_requests/backends/db_provider/postgres_dbprovider_test.go b/stored_requests/backends/db_provider/postgres_dbprovider_test.go new file mode 100644 index 00000000000..50d825dfdfb --- /dev/null +++ b/stored_requests/backends/db_provider/postgres_dbprovider_test.go @@ -0,0 +1,89 @@ +package db_provider + +import ( + "testing" + + "github.com/prebid/prebid-server/config" + "github.com/stretchr/testify/assert" +) + +func TestConnStringPostgres(t *testing.T) { + type Params struct { + db string + host string + port int + username string + password string + } + + tests := []struct { + name string + params Params + connString string + }{ + { + params: Params{ + db: "", + }, + connString: "sslmode=disable", + }, + { + params: Params{ + db: "TestDB", + }, + connString: "dbname=TestDB sslmode=disable", + }, + { + params: Params{ + host: "example.com", + }, + connString: "host=example.com sslmode=disable", + }, + { + params: Params{ + port: 20, + }, + connString: "port=20 sslmode=disable", + }, + { + params: Params{ + username: "someuser", + }, + connString: "user=someuser sslmode=disable", + }, + { + params: Params{ + username: "someuser", + password: "somepassword", + }, + connString: "user=someuser password=somepassword sslmode=disable", + }, + { + params: 
Params{ + db: "TestDB", + host: "example.com", + port: 20, + username: "someuser", + password: "somepassword", + }, + connString: "host=example.com port=20 user=someuser password=somepassword dbname=TestDB sslmode=disable", + }, + } + + for _, test := range tests { + cfg := config.DatabaseConnection{ + Database: test.params.db, + Host: test.params.host, + Port: test.params.port, + Username: test.params.username, + Password: test.params.password, + } + + provider := PostgresDbProvider{ + cfg: cfg, + } + + connString := provider.ConnString() + assert.Equal(t, test.connString, connString, "Strings did not match") + } +} diff --git a/stored_requests/config/config.go b/stored_requests/config/config.go index f091126800a..9cb349d1f72 100644 --- a/stored_requests/config/config.go +++ b/stored_requests/config/config.go @@ -2,7 +2,6 @@ package config import ( "context" - "database/sql" "net/http" "time" @@ -13,6 +12,7 @@ import ( "github.com/prebid/prebid-server/config" "github.com/prebid/prebid-server/stored_requests" "github.com/prebid/prebid-server/stored_requests/backends/db_fetcher" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" "github.com/prebid/prebid-server/stored_requests/backends/empty_fetcher" "github.com/prebid/prebid-server/stored_requests/backends/file_fetcher" "github.com/prebid/prebid-server/stored_requests/backends/http_fetcher" @@ -20,18 +20,11 @@ import ( "github.com/prebid/prebid-server/stored_requests/caches/nil_cache" "github.com/prebid/prebid-server/stored_requests/events" apiEvents "github.com/prebid/prebid-server/stored_requests/events/api" + databaseEvents "github.com/prebid/prebid-server/stored_requests/events/database" httpEvents "github.com/prebid/prebid-server/stored_requests/events/http" - postgresEvents "github.com/prebid/prebid-server/stored_requests/events/postgres" "github.com/prebid/prebid-server/util/task" ) -// This gets set to the connection string used when a database connection is made. 
We only support a single -// database currently, so all fetchers need to share the same db connection for now. -type dbConnection struct { - conn string - db *sql.DB -} - // CreateStoredRequests returns three things: // // 1. A Fetcher which can be used to get Stored Requests @@ -42,31 +35,28 @@ type dbConnection struct { // // As a side-effect, it will add some endpoints to the router if the config calls for it. // In the future we should look for ways to simplify this so that it's not doing two things. -func CreateStoredRequests(cfg *config.StoredRequests, metricsEngine metrics.MetricsEngine, client *http.Client, router *httprouter.Router, dbc *dbConnection) (fetcher stored_requests.AllFetcher, shutdown func()) { +func CreateStoredRequests(cfg *config.StoredRequests, metricsEngine metrics.MetricsEngine, client *http.Client, router *httprouter.Router, provider db_provider.DbProvider) (fetcher stored_requests.AllFetcher, shutdown func()) { // Create database connection if given options for one - if cfg.Postgres.ConnectionInfo.Database != "" { - conn := cfg.Postgres.ConnectionInfo.ConnString() - - if dbc.conn == "" { - glog.Infof("Connecting to Postgres for Stored %s. DB=%s, host=%s, port=%d, user=%s", + if cfg.Database.ConnectionInfo.Database != "" { + if provider == nil { + glog.Infof("Connecting to Database for Stored %s. 
Driver=%s, DB=%s, host=%s, port=%d, user=%s", cfg.DataType(), - cfg.Postgres.ConnectionInfo.Database, - cfg.Postgres.ConnectionInfo.Host, - cfg.Postgres.ConnectionInfo.Port, - cfg.Postgres.ConnectionInfo.Username) - db := newPostgresDB(cfg.DataType(), cfg.Postgres.ConnectionInfo) - dbc.conn = conn - dbc.db = db + cfg.Database.ConnectionInfo.Driver, + cfg.Database.ConnectionInfo.Database, + cfg.Database.ConnectionInfo.Host, + cfg.Database.ConnectionInfo.Port, + cfg.Database.ConnectionInfo.Username) + provider = db_provider.NewDbProvider(cfg.DataType(), cfg.Database.ConnectionInfo) } // Error out if config is trying to use multiple database connections for different stored requests (not supported yet) - if conn != dbc.conn { + if provider.Config() != cfg.Database.ConnectionInfo { glog.Fatal("Multiple database connection settings found in config, only a single database connection is currently supported.") } } - eventProducers := newEventProducers(cfg, client, dbc.db, metricsEngine, router) - fetcher = newFetcher(cfg, client, dbc.db) + eventProducers := newEventProducers(cfg, client, provider, metricsEngine, router) + fetcher = newFetcher(cfg, client, provider) var shutdown1 func() @@ -80,13 +70,13 @@ func CreateStoredRequests(cfg *config.StoredRequests, metricsEngine metrics.Metr if shutdown1 != nil { shutdown1() } - if dbc.db != nil { - db := dbc.db - dbc.db = nil - dbc.conn = "" - if err := db.Close(); err != nil { - glog.Errorf("Error closing DB connection: %v", err) - } + + if provider == nil { + return + } + + if err := provider.Close(); err != nil { + glog.Errorf("Error closing DB connection: %v", err) } } @@ -115,14 +105,14 @@ func NewStoredRequests(cfg *config.Configuration, metricsEngine metrics.MetricsE videoFetcher stored_requests.Fetcher, storedRespFetcher stored_requests.Fetcher) { - var dbc dbConnection + var provider db_provider.DbProvider - fetcher1, shutdown1 := CreateStoredRequests(&cfg.StoredRequests, metricsEngine, client, router, &dbc) - fetcher2, 
shutdown2 := CreateStoredRequests(&cfg.StoredRequestsAMP, metricsEngine, client, router, &dbc) - fetcher3, shutdown3 := CreateStoredRequests(&cfg.CategoryMapping, metricsEngine, client, router, &dbc) - fetcher4, shutdown4 := CreateStoredRequests(&cfg.StoredVideo, metricsEngine, client, router, &dbc) - fetcher5, shutdown5 := CreateStoredRequests(&cfg.Accounts, metricsEngine, client, router, &dbc) - fetcher6, shutdown6 := CreateStoredRequests(&cfg.StoredResponses, metricsEngine, client, router, &dbc) + fetcher1, shutdown1 := CreateStoredRequests(&cfg.StoredRequests, metricsEngine, client, router, provider) + fetcher2, shutdown2 := CreateStoredRequests(&cfg.StoredRequestsAMP, metricsEngine, client, router, provider) + fetcher3, shutdown3 := CreateStoredRequests(&cfg.CategoryMapping, metricsEngine, client, router, provider) + fetcher4, shutdown4 := CreateStoredRequests(&cfg.StoredVideo, metricsEngine, client, router, provider) + fetcher5, shutdown5 := CreateStoredRequests(&cfg.Accounts, metricsEngine, client, router, provider) + fetcher6, shutdown6 := CreateStoredRequests(&cfg.StoredResponses, metricsEngine, client, router, provider) fetcher = fetcher1.(stored_requests.Fetcher) ampFetcher = fetcher2.(stored_requests.Fetcher) @@ -159,17 +149,18 @@ func addListeners(cache stored_requests.Cache, eventProducers []events.EventProd } } -func newFetcher(cfg *config.StoredRequests, client *http.Client, db *sql.DB) (fetcher stored_requests.AllFetcher) { +func newFetcher(cfg *config.StoredRequests, client *http.Client, provider db_provider.DbProvider) (fetcher stored_requests.AllFetcher) { idList := make(stored_requests.MultiFetcher, 0, 3) if cfg.Files.Enabled { fFetcher := newFilesystem(cfg.DataType(), cfg.Files.Path) idList = append(idList, fFetcher) } - if cfg.Postgres.FetcherQueries.QueryTemplate != "" { - glog.Infof("Loading Stored %s data via Postgres.\nQuery: %s", cfg.DataType(), cfg.Postgres.FetcherQueries.QueryTemplate) - idList = append(idList, 
db_fetcher.NewFetcher(db, cfg.Postgres.FetcherQueries.MakeQuery, cfg.Postgres.FetcherQueries.MakeQueryResponses)) - } else if cfg.Postgres.CacheInitialization.Query != "" && cfg.Postgres.PollUpdates.Query != "" { + if cfg.Database.FetcherQueries.QueryTemplate != "" { + glog.Infof("Loading Stored %s data via Database.\nQuery: %s", cfg.DataType(), cfg.Database.FetcherQueries.QueryTemplate) + idList = append(idList, db_fetcher.NewFetcher(provider, + cfg.Database.FetcherQueries.QueryTemplate, cfg.Database.FetcherQueries.QueryTemplate)) + } else if cfg.Database.CacheInitialization.Query != "" && cfg.Database.PollUpdates.Query != "" { //in this case data will be loaded to cache via poll for updates event idList = append(idList, empty_fetcher.EmptyFetcher{}) } @@ -202,28 +193,28 @@ func newCache(cfg *config.StoredRequests) stored_requests.Cache { return cache } -func newEventProducers(cfg *config.StoredRequests, client *http.Client, db *sql.DB, metricsEngine metrics.MetricsEngine, router *httprouter.Router) (eventProducers []events.EventProducer) { +func newEventProducers(cfg *config.StoredRequests, client *http.Client, provider db_provider.DbProvider, metricsEngine metrics.MetricsEngine, router *httprouter.Router) (eventProducers []events.EventProducer) { if cfg.CacheEvents.Enabled { eventProducers = append(eventProducers, newEventsAPI(router, cfg.CacheEvents.Endpoint)) } if cfg.HTTPEvents.RefreshRate != 0 && cfg.HTTPEvents.Endpoint != "" { eventProducers = append(eventProducers, newHttpEvents(client, cfg.HTTPEvents.TimeoutDuration(), cfg.HTTPEvents.RefreshRateDuration(), cfg.HTTPEvents.Endpoint)) } - if cfg.Postgres.CacheInitialization.Query != "" { - pgEventCfg := postgresEvents.PostgresEventProducerConfig{ - DB: db, + if cfg.Database.CacheInitialization.Query != "" { + dbEventCfg := databaseEvents.DatabaseEventProducerConfig{ + Provider: provider, RequestType: cfg.DataType(), - CacheInitQuery: cfg.Postgres.CacheInitialization.Query, - CacheInitTimeout: 
time.Duration(cfg.Postgres.CacheInitialization.Timeout) * time.Millisecond, - CacheUpdateQuery: cfg.Postgres.PollUpdates.Query, - CacheUpdateTimeout: time.Duration(cfg.Postgres.PollUpdates.Timeout) * time.Millisecond, + CacheInitQuery: cfg.Database.CacheInitialization.Query, + CacheInitTimeout: time.Duration(cfg.Database.CacheInitialization.Timeout) * time.Millisecond, + CacheUpdateQuery: cfg.Database.PollUpdates.Query, + CacheUpdateTimeout: time.Duration(cfg.Database.PollUpdates.Timeout) * time.Millisecond, MetricsEngine: metricsEngine, } - pgEventProducer := postgresEvents.NewPostgresEventProducer(pgEventCfg) - fetchInterval := time.Duration(cfg.Postgres.PollUpdates.RefreshRate) * time.Second - pgEventTickerTask := task.NewTickerTask(fetchInterval, pgEventProducer) - pgEventTickerTask.Start() - eventProducers = append(eventProducers, pgEventProducer) + dbEventProducer := databaseEvents.NewDatabaseEventProducer(dbEventCfg) + fetchInterval := time.Duration(cfg.Database.PollUpdates.RefreshRate) * time.Second + dbEventTickerTask := task.NewTickerTask(fetchInterval, dbEventProducer) + dbEventTickerTask.Start() + eventProducers = append(eventProducers, dbEventProducer) } return } @@ -251,19 +242,6 @@ func newFilesystem(dataType config.DataType, configPath string) stored_requests. return fetcher } -func newPostgresDB(dataType config.DataType, cfg config.PostgresConnection) *sql.DB { - db, err := sql.Open("postgres", cfg.ConnString()) - if err != nil { - glog.Fatalf("Failed to open %s postgres connection: %v", dataType, err) - } - - if err := db.Ping(); err != nil { - glog.Fatalf("Failed to ping %s postgres: %v", dataType, err) - } - - return db -} - // consolidate returns a single Fetcher from an array of fetchers of any size. 
func consolidate(dataType config.DataType, fetchers []stored_requests.AllFetcher) stored_requests.AllFetcher { if len(fetchers) == 0 { diff --git a/stored_requests/config/config_test.go b/stored_requests/config/config_test.go index 7cf6c38af0c..b06feea7d31 100644 --- a/stored_requests/config/config_test.go +++ b/stored_requests/config/config_test.go @@ -2,7 +2,6 @@ package config import ( "context" - "database/sql" "encoding/json" "errors" "net/http" @@ -17,6 +16,7 @@ import ( "github.com/prebid/prebid-server/config" "github.com/prebid/prebid-server/metrics" "github.com/prebid/prebid-server/stored_requests" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" "github.com/prebid/prebid-server/stored_requests/backends/empty_fetcher" "github.com/prebid/prebid-server/stored_requests/backends/http_fetcher" "github.com/prebid/prebid-server/stored_requests/events" @@ -56,59 +56,68 @@ func TestNewEmptyFetcher(t *testing.T) { }, { config: &config.StoredRequests{ - Postgres: config.PostgresConfig{ - CacheInitialization: config.PostgresCacheInitializer{ + Database: config.DatabaseConfig{ + ConnectionInfo: config.DatabaseConnection{ + Driver: "postgres", + }, + CacheInitialization: config.DatabaseCacheInitializer{ Query: "test query", }, - PollUpdates: config.PostgresUpdatePolling{ + PollUpdates: config.DatabaseUpdatePolling{ Query: "test poll query", }, - FetcherQueries: config.PostgresFetcherQueries{ + FetcherQueries: config.DatabaseFetcherQueries{ QueryTemplate: "", }, }, }, emptyFetcher: true, - description: "If Postgres fetcher query is not defined, but Postgres Cache init query and Postgres update polling query are defined EmptyFetcher should be returned", + description: "If Database fetcher query is not defined, but Database Cache init query and Database update polling query are defined EmptyFetcher should be returned", }, { config: &config.StoredRequests{ - Postgres: config.PostgresConfig{ - CacheInitialization: config.PostgresCacheInitializer{ + 
Database: config.DatabaseConfig{ + ConnectionInfo: config.DatabaseConnection{ + Driver: "postgres", + }, + CacheInitialization: config.DatabaseCacheInitializer{ Query: "", }, - PollUpdates: config.PostgresUpdatePolling{ + PollUpdates: config.DatabaseUpdatePolling{ Query: "", }, - FetcherQueries: config.PostgresFetcherQueries{ + FetcherQueries: config.DatabaseFetcherQueries{ QueryTemplate: "test fetcher query", }, }, }, emptyFetcher: false, - description: "If Postgres fetcher query is defined, but Postgres Cache init query and Postgres update polling query are not defined not EmptyFetcher (DBFetcher) should be returned", + description: "If Database fetcher query is defined, but Database Cache init query and Database update polling query are not defined not EmptyFetcher (DBFetcher) should be returned", }, { config: &config.StoredRequests{ - Postgres: config.PostgresConfig{ - CacheInitialization: config.PostgresCacheInitializer{ + Database: config.DatabaseConfig{ + ConnectionInfo: config.DatabaseConnection{ + Driver: "postgres", + }, + CacheInitialization: config.DatabaseCacheInitializer{ Query: "test cache query", }, - PollUpdates: config.PostgresUpdatePolling{ + PollUpdates: config.DatabaseUpdatePolling{ Query: "test poll query", }, - FetcherQueries: config.PostgresFetcherQueries{ + FetcherQueries: config.DatabaseFetcherQueries{ QueryTemplate: "test fetcher query", }, }, }, emptyFetcher: false, - description: "If Postgres fetcher query is defined and Postgres Cache init query and Postgres update polling query are defined not EmptyFetcher (DBFetcher) should be returned", + description: "If Database fetcher query is defined and Database Cache init query and Database update polling query are defined not EmptyFetcher (DBFetcher) should be returned", }, } for _, test := range testCases { - fetcher := newFetcher(test.config, nil, &sql.DB{}) + fetcher := newFetcher(test.config, nil, db_provider.DbProviderMock{}) assert.NotNil(t, fetcher, "The fetcher should be non-nil.") 
if test.emptyFetcher { assert.Equal(t, empty_fetcher.EmptyFetcher{}, fetcher, "Empty fetcher should be returned") @@ -190,18 +199,18 @@ func TestNewInMemoryAccountCache(t *testing.T) { assert.True(t, isEmptyCacheType(cache.Responses), "The newCache method should return an empty Responses cache for Accounts config") } -func TestNewPostgresEventProducers(t *testing.T) { +func TestNewDatabaseEventProducers(t *testing.T) { metricsMock := &metrics.MetricsEngineMock{} metricsMock.Mock.On("RecordStoredDataFetchTime", mock.Anything, mock.Anything).Return() metricsMock.Mock.On("RecordStoredDataError", mock.Anything).Return() cfg := &config.StoredRequests{ - Postgres: config.PostgresConfig{ - CacheInitialization: config.PostgresCacheInitializer{ + Database: config.DatabaseConfig{ + CacheInitialization: config.DatabaseCacheInitializer{ Timeout: 50, Query: "SELECT id, requestData, type FROM stored_data", }, - PollUpdates: config.PostgresUpdatePolling{ + PollUpdates: config.DatabaseUpdatePolling{ RefreshRate: 20, Timeout: 50, Query: "SELECT id, requestData, type FROM stored_data WHERE last_updated > $1", @@ -209,13 +218,13 @@ func TestNewPostgresEventProducers(t *testing.T) { }, } client := &http.Client{} - db, mock, err := sqlmock.New() + provider, mock, err := db_provider.NewDbProviderMock() if err != nil { t.Fatalf("Failed to create mock: %v", err) } - mock.ExpectQuery("^" + regexp.QuoteMeta(cfg.Postgres.CacheInitialization.Query) + "$").WillReturnError(errors.New("Query failed")) + mock.ExpectQuery("^" + regexp.QuoteMeta(cfg.Database.CacheInitialization.Query) + "$").WillReturnError(errors.New("Query failed")) - evProducers := newEventProducers(cfg, client, db, metricsMock, nil) + evProducers := newEventProducers(cfg, client, provider, metricsMock, nil) assertProducerLength(t, evProducers, 1) assertExpectationsMet(t, mock) diff --git a/stored_requests/events/postgres/database.go b/stored_requests/events/database/database.go similarity index 82% rename from 
stored_requests/events/postgres/database.go rename to stored_requests/events/database/database.go index 9d69e84b164..24eddf214eb 100644 --- a/stored_requests/events/postgres/database.go +++ b/stored_requests/events/database/database.go @@ -1,4 +1,4 @@ -package postgres +package database import ( "bytes" @@ -11,6 +11,7 @@ import ( "github.com/golang/glog" "github.com/prebid/prebid-server/config" "github.com/prebid/prebid-server/metrics" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" "github.com/prebid/prebid-server/stored_requests/events" "github.com/prebid/prebid-server/util/timeutil" ) @@ -28,8 +29,8 @@ var storedDataTypeMetricMap = map[config.DataType]metrics.StoredDataType{ config.ResponseDataType: metrics.ResponseDataType, } -type PostgresEventProducerConfig struct { - DB *sql.DB +type DatabaseEventProducerConfig struct { + Provider db_provider.DbProvider RequestType config.DataType CacheInitQuery string CacheInitTimeout time.Duration @@ -38,20 +39,20 @@ type PostgresEventProducerConfig struct { MetricsEngine metrics.MetricsEngine } -type PostgresEventProducer struct { - cfg PostgresEventProducerConfig +type DatabaseEventProducer struct { + cfg DatabaseEventProducerConfig lastUpdate time.Time invalidations chan events.Invalidation saves chan events.Save time timeutil.Time } -func NewPostgresEventProducer(cfg PostgresEventProducerConfig) (eventProducer *PostgresEventProducer) { - if cfg.DB == nil { - glog.Fatalf("The Postgres Stored %s Loader needs a database connection to work.", cfg.RequestType) +func NewDatabaseEventProducer(cfg DatabaseEventProducerConfig) (eventProducer *DatabaseEventProducer) { + if cfg.Provider == nil { + glog.Fatalf("The Database Stored %s Loader needs a database connection to work.", cfg.RequestType) } - return &PostgresEventProducer{ + return &DatabaseEventProducer{ cfg: cfg, lastUpdate: time.Time{}, saves: make(chan events.Save, 1), @@ -60,7 +61,7 @@ func NewPostgresEventProducer(cfg 
PostgresEventProducerConfig) (eventProducer *P } } -func (e *PostgresEventProducer) Run() error { +func (e *DatabaseEventProducer) Run() error { if e.lastUpdate.IsZero() { return e.fetchAll() } @@ -68,21 +69,21 @@ func (e *PostgresEventProducer) Run() error { return e.fetchDelta() } -func (e *PostgresEventProducer) Saves() <-chan events.Save { +func (e *DatabaseEventProducer) Saves() <-chan events.Save { return e.saves } -func (e *PostgresEventProducer) Invalidations() <-chan events.Invalidation { +func (e *DatabaseEventProducer) Invalidations() <-chan events.Invalidation { return e.invalidations } -func (e *PostgresEventProducer) fetchAll() (fetchErr error) { +func (e *DatabaseEventProducer) fetchAll() (fetchErr error) { timeout := e.cfg.CacheInitTimeout * time.Millisecond ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() startTime := e.time.Now().UTC() - rows, err := e.cfg.DB.QueryContext(ctx, e.cfg.CacheInitQuery) + rows, err := e.cfg.Provider.QueryContext(ctx, e.cfg.CacheInitQuery) elapsedTime := time.Since(startTime) e.recordFetchTime(elapsedTime, metrics.FetchAll) @@ -113,13 +114,18 @@ func (e *PostgresEventProducer) fetchAll() (fetchErr error) { return nil } -func (e *PostgresEventProducer) fetchDelta() (fetchErr error) { +func (e *DatabaseEventProducer) fetchDelta() (fetchErr error) { timeout := e.cfg.CacheUpdateTimeout * time.Millisecond ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() startTime := e.time.Now().UTC() - rows, err := e.cfg.DB.QueryContext(ctx, e.cfg.CacheUpdateQuery, e.lastUpdate) + + params := []db_provider.QueryParam{ + {Name: "LAST_UPDATED", Value: e.lastUpdate}, + } + + rows, err := e.cfg.Provider.QueryContext(ctx, e.cfg.CacheUpdateQuery, params...) 
elapsedTime := time.Since(startTime) e.recordFetchTime(elapsedTime, metrics.FetchDelta) @@ -150,7 +156,7 @@ func (e *PostgresEventProducer) fetchDelta() (fetchErr error) { return nil } -func (e *PostgresEventProducer) recordFetchTime(elapsedTime time.Duration, fetchType metrics.StoredDataFetchType) { +func (e *DatabaseEventProducer) recordFetchTime(elapsedTime time.Duration, fetchType metrics.StoredDataFetchType) { e.cfg.MetricsEngine.RecordStoredDataFetchTime( metrics.StoredDataLabels{ DataType: storedDataTypeMetricMap[e.cfg.RequestType], @@ -158,7 +164,7 @@ func (e *PostgresEventProducer) recordFetchTime(elapsedTime time.Duration, fetch }, elapsedTime) } -func (e *PostgresEventProducer) recordError(errorType metrics.StoredDataError) { +func (e *DatabaseEventProducer) recordError(errorType metrics.StoredDataError) { e.cfg.MetricsEngine.RecordStoredDataError( metrics.StoredDataLabels{ DataType: storedDataTypeMetricMap[e.cfg.RequestType], @@ -168,7 +174,7 @@ func (e *PostgresEventProducer) recordError(errorType metrics.StoredDataError) { // sendEvents reads the rows and sends notifications into the channel for any updates. // If it returns an error, then callers can be certain that no events were sent to the channels. 
-func (e *PostgresEventProducer) sendEvents(rows *sql.Rows) (err error) { +func (e *DatabaseEventProducer) sendEvents(rows *sql.Rows) (err error) { storedRequestData := make(map[string]json.RawMessage) storedImpData := make(map[string]json.RawMessage) storedRespData := make(map[string]json.RawMessage) diff --git a/stored_requests/events/postgres/database_test.go b/stored_requests/events/database/database_test.go similarity index 96% rename from stored_requests/events/postgres/database_test.go rename to stored_requests/events/database/database_test.go index c3dfa0ae70e..8ce21bfde95 100644 --- a/stored_requests/events/postgres/database_test.go +++ b/stored_requests/events/database/database_test.go @@ -1,4 +1,4 @@ -package postgres +package database import ( "encoding/json" @@ -9,6 +9,7 @@ import ( "github.com/prebid/prebid-server/config" "github.com/prebid/prebid-server/metrics" + "github.com/prebid/prebid-server/stored_requests/backends/db_provider" "github.com/prebid/prebid-server/stored_requests/events" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -113,7 +114,7 @@ func TestFetchAllSuccess(t *testing.T) { } for _, tt := range tests { - db, dbMock, _ := sqlmock.New() + provider, dbMock, _ := db_provider.NewDbProviderMock() dbMock.ExpectQuery(fakeQueryRegex()).WillReturnRows(tt.giveMockRows) metricsMock := &metrics.MetricsEngineMock{} @@ -122,8 +123,8 @@ func TestFetchAllSuccess(t *testing.T) { DataFetchType: metrics.FetchAll, }, mock.Anything).Return() - eventProducer := NewPostgresEventProducer(PostgresEventProducerConfig{ - DB: db, + eventProducer := NewDatabaseEventProducer(DatabaseEventProducerConfig{ + Provider: provider, RequestType: config.RequestDataType, CacheInitTimeout: 100 * time.Millisecond, CacheInitQuery: fakeQuery, @@ -196,7 +197,7 @@ func TestFetchAllErrors(t *testing.T) { } for _, tt := range tests { - db, dbMock, _ := sqlmock.New() + provider, dbMock, _ := db_provider.NewDbProviderMock() if tt.giveMockRows == nil { 
dbMock.ExpectQuery(fakeQueryRegex()).WillReturnError(errors.New("Query failed.")) } else { @@ -213,8 +214,8 @@ func TestFetchAllErrors(t *testing.T) { Error: tt.wantRecordedError, }).Return() - eventProducer := NewPostgresEventProducer(PostgresEventProducerConfig{ - DB: db, + eventProducer := NewDatabaseEventProducer(DatabaseEventProducerConfig{ + Provider: provider, RequestType: config.RequestDataType, CacheInitTimeout: time.Duration(tt.giveTimeoutMS) * time.Millisecond, CacheInitQuery: fakeQuery, @@ -349,7 +350,7 @@ func TestFetchDeltaSuccess(t *testing.T) { } for _, tt := range tests { - db, dbMock, _ := sqlmock.New() + provider, dbMock, _ := db_provider.NewDbProviderMock() dbMock.ExpectQuery(fakeQueryRegex()).WillReturnRows(tt.giveMockRows) metricsMock := &metrics.MetricsEngineMock{} @@ -358,8 +359,8 @@ func TestFetchDeltaSuccess(t *testing.T) { DataFetchType: metrics.FetchDelta, }, mock.Anything).Return() - eventProducer := NewPostgresEventProducer(PostgresEventProducerConfig{ - DB: db, + eventProducer := NewDatabaseEventProducer(DatabaseEventProducerConfig{ + Provider: provider, RequestType: config.RequestDataType, CacheUpdateTimeout: 100 * time.Millisecond, CacheUpdateQuery: fakeQuery, @@ -437,7 +438,7 @@ func TestFetchDeltaErrors(t *testing.T) { } for _, tt := range tests { - db, dbMock, _ := sqlmock.New() + provider, dbMock, _ := db_provider.NewDbProviderMock() if tt.giveMockRows == nil { dbMock.ExpectQuery(fakeQueryRegex()).WillReturnError(errors.New("Query failed.")) } else { @@ -454,8 +455,8 @@ func TestFetchDeltaErrors(t *testing.T) { Error: tt.wantRecordedError, }).Return() - eventProducer := NewPostgresEventProducer(PostgresEventProducerConfig{ - DB: db, + eventProducer := NewDatabaseEventProducer(DatabaseEventProducerConfig{ + Provider: provider, RequestType: config.RequestDataType, CacheUpdateTimeout: time.Duration(tt.giveTimeoutMS) * time.Millisecond, CacheUpdateQuery: fakeQuery,