diff --git a/dbx/pg/config.go b/dbx/pg/config.go index 250ec6d..757edce 100644 --- a/dbx/pg/config.go +++ b/dbx/pg/config.go @@ -1,6 +1,7 @@ package pg import ( + "context" "fmt" "strconv" "strings" @@ -8,6 +9,24 @@ import ( "github.com/zeiss/pkg/utilx" ) +type contextKey int + +const ( + configKey contextKey = iota +) + +// Context returns a new Context that carries the provided Config. +func (cfg Config) Context(ctx context.Context) context.Context { + return context.WithValue(ctx, configKey, cfg) +} + +// FromContext will return the Config carried in the provided Context. +// +// It panics if config is not available on the current context. +func FromContext(ctx context.Context) Config { + return ctx.Value(configKey).(Config) +} + // Config represents configuration for PostgreSQL connection type Config struct { Database string `envconfig:"PG_DB_NAME"` diff --git a/dbx/pg/config_test.go b/dbx/pg/config_test.go index 80ed6c0..4a156ba 100644 --- a/dbx/pg/config_test.go +++ b/dbx/pg/config_test.go @@ -1,6 +1,7 @@ package pg_test import ( + "context" "testing" "github.com/zeiss/pkg/dbx/pg" @@ -33,3 +34,19 @@ func TestFormatDSN(t *testing.T) { dsn := config.FormatDSN() assert.Equal(t, "dbname=test_db user=test_user password=password host=localhost port=5432 sslmode=disable", dsn) } + +func TestContext(t *testing.T) { + t.Parallel() + + config := pg.Config{ + Database: "test_db", + Host: "localhost", + Password: "password", + Port: 5432, + SslMode: "disable", + User: "test_user", + } + + ctx := config.Context(context.Background()) + assert.Equal(t, config, pg.FromContext(ctx)) +}