Skip to content

Commit

Permalink
Adding BuildDSN, OpenMap, and FromMap funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Apr 4, 2024
1 parent d86b6db commit 95c81b1
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 0 deletions.
140 changes: 140 additions & 0 deletions dburl.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package dburl

import (
"database/sql"
"fmt"
"io/fs"
"net/url"
"os"
Expand Down Expand Up @@ -42,6 +43,17 @@ func Open(urlstr string) (*sql.DB, error) {
return sql.Open(driver, u.DSN)
}

// OpenMap takes a map of URL components and opens a standard [sql.DB] connection.
//
// See [BuildURL] for information on the recognized map components.
func OpenMap(components map[string]interface{}) (*sql.DB, error) {
urlstr, err := BuildURL(components)
if err != nil {
return nil, err
}
return Open(urlstr)
}

// URL wraps the standard [net/url.URL] type, adding OriginalScheme, Transport,
// Driver, Unaliased, and DSN strings.
type URL struct {
Expand Down Expand Up @@ -162,6 +174,30 @@ func Parse(urlstr string) (*URL, error) {
return u, nil
}

// FromMap creates a [URL] using the mapped components.
//
// Recognized components are:
//
// protocol, proto, scheme
// transport
// username, user
// password, pass
// hostname, host
// port
// path, file, opaque
// database, dbname, db
// instance
// parameters, params, options, opts, query, q
//
// See [BuildURL] for more information.
func FromMap(components map[string]interface{}) (*URL, error) {
urlstr, err := BuildURL(components)
if err != nil {
return nil, err
}
return Parse(urlstr)
}

// String satisfies the [fmt.Stringer] interface.
func (u *URL) String() string {
p := &url.URL{
Expand Down Expand Up @@ -334,6 +370,8 @@ const (
ErrMissingPath Error = "missing path"
// ErrMissingUser is the missing user error.
ErrMissingUser Error = "missing user"
// ErrInvalidQuery is the invalid query error.
ErrInvalidQuery Error = "invalid query"
)

// Stat is the default stat func.
Expand All @@ -355,6 +393,88 @@ var OpenFile = func(name string) (fs.File, error) {
return f, nil
}

// BuildURL creates a dsn using the mapped components.
//
// Recognized components are:
//
// protocol, proto, scheme
// transport
// username, user
// password, pass
// hostname, host
// port
// path, file, opaque
// database, dbname, db
// instance
// parameters, params, options, opts, query, q
//
// See [BuildURL] for more information.
func BuildURL(components map[string]interface{}) (string, error) {
if components == nil {
return "", ErrInvalidDatabaseScheme
}
var urlstr string
if proto, ok := getComponent(components, "protocol", "proto", "scheme"); ok {
if transport, ok := getComponent(components, "transport"); ok {
proto += "+" + transport
}
urlstr = proto + ":"
}
if host, ok := getComponent(components, "hostname", "host"); ok {
hostinfo := url.QueryEscape(host)
if port, ok := getComponent(components, "port"); ok {
hostinfo += ":" + port
}
var userinfo string
if user, ok := getComponent(components, "username", "user"); ok {
userinfo += url.QueryEscape(user)
if pass, ok := getComponent(components, "password", "pass"); ok {
userinfo += ":" + url.QueryEscape(pass)
}
hostinfo = userinfo + "@" + hostinfo
}
urlstr += "//" + hostinfo
}
if pathstr, ok := getComponent(components, "path", "file", "opaque"); ok {
if urlstr == "" {
urlstr += "file:"
}
urlstr += pathstr
} else {
var v []string
if instance, ok := getComponent(components, "instance"); ok {
v = append(v, url.PathEscape(instance))
}
if dbname, ok := getComponent(components, "database", "dbname", "db"); ok {
v = append(v, url.PathEscape(dbname))
}
if len(v) != 0 {
if s := path.Join(v...); s != "" {
urlstr += "/" + s
}
}
}
if v, ok := getFirst(components, "parameters", "params", "options", "opts", "query", "q"); ok {
switch z := v.(type) {
case string:
if z != "" {
urlstr += "?" + z
}
case map[string]interface{}:
q := url.Values{}
for k, v := range z {
q.Set(k, fmt.Sprintf("%v", v))
}
if s := q.Encode(); s != "" {
urlstr += "?" + s
}
default:
return "", ErrInvalidQuery
}
}
return urlstr, nil
}

// resolveType tries to resolve a path to a Unix domain socket or directory.
func resolveType(s string) (string, bool) {
if i := strings.LastIndex(s, "?"); i != -1 {
Expand Down Expand Up @@ -430,3 +550,23 @@ func mode(s string) os.FileMode {
}
return 0
}

// getComponent returns the first defined component in the map.
func getComponent(m map[string]interface{}, v ...string) (string, bool) {
if z, ok := getFirst(m, v...); ok {
str := fmt.Sprintf("%v", z)
return str, str != ""

}
return "", false
}

// getFirst returns the first value in the map.
func getFirst(m map[string]interface{}, v ...string) (interface{}, bool) {
for _, s := range v {
if z, ok := m[s]; ok {
return z, ok
}
}
return nil, false
}
81 changes: 81 additions & 0 deletions dburl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,87 @@ func testParse(t *testing.T, s, d, exp, path string) {
}
}

func TestBuildURL(t *testing.T) {
tests := []struct {
m map[string]interface{}
exp string
err error
}{
{nil, "", ErrInvalidDatabaseScheme},
{
map[string]interface{}{
"proto": "mysql",
"transport": "tcp",
"host": "localhost",
"port": 999,
"q": map[string]interface{}{
"foo": "bar",
"opt1": "b",
},
},
"mysql+tcp://localhost:999?foo=bar&opt1=b", nil,
},
{
map[string]interface{}{
"proto": "sqlserver",
"host": "localhost",
"port": "5555",
"instance": "instance",
"database": "dbname",
"q": map[string]interface{}{
"foo": "bar",
"opt1": "b",
},
},
"sqlserver://localhost:5555/instance/dbname?foo=bar&opt1=b", nil,
},
{
map[string]interface{}{
"proto": "pg",
"host": "host name",
"user": "user name",
"password": "P!!!@@@@ 👀",
"database": "my awesome db",
"q": map[string]interface{}{
"foo": "bar is cool",
"opt1": "b zzzz@@@:/",
},
},
"pg://user+name:P%21%21%21%40%40%40%40+%F0%9F%91%80@host+name/my%20awesome%20db?foo=bar+is+cool&opt1=b+zzzz%40%40%40%3A%2F", nil,
},
{
map[string]interface{}{
"file": "fake.sqlite3",
"q": map[string]interface{}{
"foo": "bar",
"opt1": "b",
},
},
"file:fake.sqlite3?foo=bar&opt1=b", nil,
},
}
for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
switch s, err := BuildURL(test.m); {
case err != nil && !errors.Is(err, test.err):
t.Fatalf("expected error %v, got: %v", test.err, err)
case err != nil && test.err == nil:
t.Fatalf("expected no error, got: %v", err)
case s != test.exp:
t.Errorf("expected %q, got: %q", test.exp, s)
default:
t.Logf("dsn: %q", s)
}
switch u, err := FromMap(test.m); {
case err != nil:
t.Logf("parse error: %v", err)
default:
t.Logf("url: %q", u.String())
}
})
}
}

func init() {
statFile, openFile := Stat, OpenFile
Stat = func(name string) (fs.FileInfo, error) {
Expand Down

0 comments on commit 95c81b1

Please sign in to comment.