Skip to content

Commit

Permalink
Update BuildAuthToken to validate endpoint contains a port (#1837)
Browse files Browse the repository at this point in the history
* validated that the right side of the colon has to be an string representation of an integer
* fixed linter error
* Add changelog description

Co-authored-by: Sean McGrail <[email protected]>
  • Loading branch information
RanVaknin and skmcgrail authored Sep 13, 2022
1 parent b011f04 commit 63566f0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 10 deletions.
8 changes: 8 additions & 0 deletions .changelog/aaba2642a8f64293a3ad7dd5bc0e9ef7.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "aaba2642-a8f6-4293-a3ad-7dd5bc0e9ef7",
"type": "feature",
"description": "Updated `BuildAuthToken` to validate the provided endpoint contains a port.",
"modules": [
"feature/rds/auth"
]
}
29 changes: 29 additions & 0 deletions feature/rds/auth/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{}
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
_, port := validateURL(endpoint)
if port == "" {
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
}

o := BuildAuthTokenOptions{}

for _, fn := range optFns {
Expand Down Expand Up @@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds

return url, nil
}

func validateURL(hostPort string) (host, port string) {
colon := strings.LastIndexByte(hostPort, ':')
if colon != -1 {
host, port = hostPort[:colon], hostPort[colon+1:]
}
if !validatePort(port) {
port = ""
return
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}

return
}

func validatePort(port string) bool {
if _, err := strconv.Atoi(port); err == nil {
return true
}
return false
}
45 changes: 35 additions & 10 deletions feature/rds/auth/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth_test
import (
"context"
"regexp"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) {
region string
user string
expectedRegex string
expectedError string
}{
{
"https://prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
"prod-instance.us-east-1.rds.amazonaws.com:3306",
"us-west-2",
"mysqlUser",
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
region: "us-west-2",
user: "mysqlUser",
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
{
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd",
region: "us-west-2",
user: "mysqlUser",
expectedError: "port",
},
}

for _, c := range cases {
creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"}
url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds)
if err != nil {
t.Errorf("expect no error, got %v", err)
if len(c.expectedError) > 0 {
if err != nil {
if !strings.Contains(err.Error(), c.expectedError) {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
} else {
continue
}
} else {
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
}
} else if err != nil {
t.Fatalf("expect no err, got: %v", err)
}

if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) {
t.Errorf("expect %s to match %s", re, a)
}
Expand Down

0 comments on commit 63566f0

Please sign in to comment.