diff --git a/dburl.go b/dburl.go index 31d6c10..44903b2 100644 --- a/dburl.go +++ b/dburl.go @@ -11,6 +11,7 @@ package dburl import ( "database/sql" + "fmt" "io/fs" "net/url" "os" @@ -42,6 +43,17 @@ func Open(urlstr string) (*sql.DB, error) { return sql.Open(driver, u.DSN) } +// OpenMap takes a map of URL components and opens a standard [sql.DB] connection. +// +// See [BuildURL] for information on the recognized map components. +func OpenMap(components map[string]interface{}) (*sql.DB, error) { + urlstr, err := BuildURL(components) + if err != nil { + return nil, err + } + return Open(urlstr) +} + // URL wraps the standard [net/url.URL] type, adding OriginalScheme, Transport, // Driver, Unaliased, and DSN strings. type URL struct { @@ -162,6 +174,30 @@ func Parse(urlstr string) (*URL, error) { return u, nil } +// FromMap creates a [URL] using the mapped components. +// +// Recognized components are: +// +// protocol, proto, scheme +// transport +// username, user +// password, pass +// hostname, host +// port +// path, file, opaque +// database, dbname, db +// instance +// parameters, params, options, opts, query, q +// +// See [BuildURL] for more information. +func FromMap(components map[string]interface{}) (*URL, error) { + urlstr, err := BuildURL(components) + if err != nil { + return nil, err + } + return Parse(urlstr) +} + // String satisfies the [fmt.Stringer] interface. func (u *URL) String() string { p := &url.URL{ @@ -334,6 +370,8 @@ const ( ErrMissingPath Error = "missing path" // ErrMissingUser is the missing user error. ErrMissingUser Error = "missing user" + // ErrInvalidQuery is the invalid query error. + ErrInvalidQuery Error = "invalid query" ) // Stat is the default stat func. @@ -355,6 +393,88 @@ var OpenFile = func(name string) (fs.File, error) { return f, nil } +// BuildURL creates a dsn using the mapped components. +// +// Recognized components are: +// +// protocol, proto, scheme +// transport +// username, user +// password, pass +// hostname, host +// port +// path, file, opaque +// database, dbname, db +// instance +// parameters, params, options, opts, query, q +// +// See [BuildURL] for more information. +func BuildURL(components map[string]interface{}) (string, error) { + if components == nil { + return "", ErrInvalidDatabaseScheme + } + var urlstr string + if proto, ok := getComponent(components, "protocol", "proto", "scheme"); ok { + if transport, ok := getComponent(components, "transport"); ok { + proto += "+" + transport + } + urlstr = proto + ":" + } + if host, ok := getComponent(components, "hostname", "host"); ok { + hostinfo := url.QueryEscape(host) + if port, ok := getComponent(components, "port"); ok { + hostinfo += ":" + port + } + var userinfo string + if user, ok := getComponent(components, "username", "user"); ok { + userinfo += url.QueryEscape(user) + if pass, ok := getComponent(components, "password", "pass"); ok { + userinfo += ":" + url.QueryEscape(pass) + } + hostinfo = userinfo + "@" + hostinfo + } + urlstr += "//" + hostinfo + } + if pathstr, ok := getComponent(components, "path", "file", "opaque"); ok { + if urlstr == "" { + urlstr += "file:" + } + urlstr += pathstr + } else { + var v []string + if instance, ok := getComponent(components, "instance"); ok { + v = append(v, url.PathEscape(instance)) + } + if dbname, ok := getComponent(components, "database", "dbname", "db"); ok { + v = append(v, url.PathEscape(dbname)) + } + if len(v) != 0 { + if s := path.Join(v...); s != "" { + urlstr += "/" + s + } + } + } + if v, ok := getFirst(components, "parameters", "params", "options", "opts", "query", "q"); ok { + switch z := v.(type) { + case string: + if z != "" { + urlstr += "?" + z + } + case map[string]interface{}: + q := url.Values{} + for k, v := range z { + q.Set(k, fmt.Sprintf("%v", v)) + } + if s := q.Encode(); s != "" { + urlstr += "?" + s + } + default: + return "", ErrInvalidQuery + } + } + return urlstr, nil +} + // resolveType tries to resolve a path to a Unix domain socket or directory. func resolveType(s string) (string, bool) { if i := strings.LastIndex(s, "?"); i != -1 { @@ -430,3 +550,23 @@ func mode(s string) os.FileMode { } return 0 } + +// getComponent returns the first defined component in the map. +func getComponent(m map[string]interface{}, v ...string) (string, bool) { + if z, ok := getFirst(m, v...); ok { + str := fmt.Sprintf("%v", z) + return str, str != "" + + } + return "", false +} + +// getFirst returns the first value in the map. +func getFirst(m map[string]interface{}, v ...string) (interface{}, bool) { + for _, s := range v { + if z, ok := m[s]; ok { + return z, ok + } + } + return nil, false +} diff --git a/dburl_test.go b/dburl_test.go index 81e3ec2..0eb7d32 100644 --- a/dburl_test.go +++ b/dburl_test.go @@ -983,6 +983,87 @@ func testParse(t *testing.T, s, d, exp, path string) { } } +func TestBuildURL(t *testing.T) { + tests := []struct { + m map[string]interface{} + exp string + err error + }{ + {nil, "", ErrInvalidDatabaseScheme}, + { + map[string]interface{}{ + "proto": "mysql", + "transport": "tcp", + "host": "localhost", + "port": 999, + "q": map[string]interface{}{ + "foo": "bar", + "opt1": "b", + }, + }, + "mysql+tcp://localhost:999?foo=bar&opt1=b", nil, + }, + { + map[string]interface{}{ + "proto": "sqlserver", + "host": "localhost", + "port": "5555", + "instance": "instance", + "database": "dbname", + "q": map[string]interface{}{ + "foo": "bar", + "opt1": "b", + }, + }, + "sqlserver://localhost:5555/instance/dbname?foo=bar&opt1=b", nil, + }, + { + map[string]interface{}{ + "proto": "pg", + "host": "host name", + "user": "user name", + "password": "P!!!@@@@ 👀", + "database": "my awesome db", + "q": map[string]interface{}{ + "foo": "bar is cool", + "opt1": "b zzzz@@@:/", + }, + }, + "pg://user+name:P%21%21%21%40%40%40%40+%F0%9F%91%80@host+name/my%20awesome%20db?foo=bar+is+cool&opt1=b+zzzz%40%40%40%3A%2F", nil, + }, + { + map[string]interface{}{ + "file": "fake.sqlite3", + "q": map[string]interface{}{ + "foo": "bar", + "opt1": "b", + }, + }, + "file:fake.sqlite3?foo=bar&opt1=b", nil, + }, + } + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + switch s, err := BuildURL(test.m); { + case err != nil && !errors.Is(err, test.err): + t.Fatalf("expected error %v, got: %v", test.err, err) + case err != nil && test.err == nil: + t.Fatalf("expected no error, got: %v", err) + case s != test.exp: + t.Errorf("expected %q, got: %q", test.exp, s) + default: + t.Logf("dsn: %q", s) + } + switch u, err := FromMap(test.m); { + case err != nil: + t.Logf("parse error: %v", err) + default: + t.Logf("url: %q", u.String()) + } + }) + } +} + func init() { statFile, openFile := Stat, OpenFile Stat = func(name string) (fs.FileInfo, error) {