From 65b7a66bcaab3e0372f277756a0eb0600f32df30 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar <45962551+tarunKoyalwar@users.noreply.github.com> Date: Tue, 23 Jan 2024 00:50:20 +0530 Subject: [PATCH] validate invalid/unsupported schemes (#311) * validate invalid/unsupported schemes * make apiserver env optional --- auth/pdcp/creds.go | 2 +- url/parsers.go | 5 ++++- url/url_test.go | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/auth/pdcp/creds.go b/auth/pdcp/creds.go index e3049c0..f6f3c03 100644 --- a/auth/pdcp/creds.go +++ b/auth/pdcp/creds.go @@ -74,7 +74,7 @@ func (p *PDCPCredHandler) GetCreds() (*PDCPCredentials, error) { // if not or incomplete credentials are found it return nil func (p *PDCPCredHandler) getCredsFromEnv() *PDCPCredentials { apiKey := env.GetEnvOrDefault(apiKeyEnv, "") - apiServer := env.GetEnvOrDefault(apiServerEnv, "") + apiServer := env.GetEnvOrDefault(apiServerEnv, DefaultApiServer) if apiKey == "" || apiServer == "" { return nil } diff --git a/url/parsers.go b/url/parsers.go index fa575e7..711fbd9 100644 --- a/url/parsers.go +++ b/url/parsers.go @@ -146,7 +146,10 @@ func absoluteURLParser(u *URL) (*URL, error) { FTP + SchemeSeparator, "//", } - if stringsutil.HasPrefixAny(u.Original, allowedSchemes...) { + if strings.Contains(u.Original, SchemeSeparator) || strings.HasPrefix(u.Original, "//") { + if !strings.HasPrefix(u.Original, "//") && !stringsutil.HasPrefixAny(u.Original, allowedSchemes...) { + return nil, errorutil.NewWithTag("urlutil", "failed to parse url got invalid scheme input=%v", u.Original) + } u.IsRelative = false urlparse, parseErr := url.Parse(u.Original) if parseErr != nil { diff --git a/url/url_test.go b/url/url_test.go index 8c001f9..99bbe63 100644 --- a/url/url_test.go +++ b/url/url_test.go @@ -204,3 +204,23 @@ func TestUnicodeEscapeWithUnsafe(t *testing.T) { require.Equal(t, v.expected, urlx.String()) } } + +func TestInvalidScheme(t *testing.T) { + testcases := []struct { + input string + expectedErr bool + }{ + {"//:foo", true}, + {"://foo", true}, + } + for _, v := range testcases { + urlx, err := ParseAbsoluteURL(v.input, true) + if v.expectedErr { + require.NotNil(t, err) + require.Nil(t, urlx) + } else { + require.Nil(t, err) + require.NotNil(t, urlx) + } + } +}