Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #8168 - AWS secrets should not be exposed while running tests #8169

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pkg/repository/config/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import (
"github.com/pkg/errors"
)

// getS3CredentialsFunc is used to make testing more convenient
var getS3CredentialsFunc = GetS3Credentials

const (
// AWS specific environment variable
awsProfileEnvVar = "AWS_PROFILE"
Expand Down Expand Up @@ -63,7 +66,7 @@ func GetS3ResticEnvVars(config map[string]string) (map[string]string, error) {
// GetS3ResticEnvVars reads the AWS config, from files and envs
// if needed assumes the role and returns the session credentials
// setting these variables emulates what would happen for example when using kube2iam
if creds, err := GetS3Credentials(config); err == nil && creds != nil {
if creds, err := getS3CredentialsFunc(config); err == nil && creds != nil {
result[awsKeyIDEnvVar] = creds.AccessKeyID
result[awsSecretKeyEnvVar] = creds.SecretAccessKey
result[awsSessTokenEnvVar] = creds.SessionToken
Expand Down
45 changes: 39 additions & 6 deletions pkg/repository/config/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ import (

func TestGetS3ResticEnvVars(t *testing.T) {
testCases := []struct {
name string
config map[string]string
expected map[string]string
name string
config map[string]string
expected map[string]string
getS3Credentials func(config map[string]string) (*aws.Credentials, error)
}{
{
name: "when config is empty, no env vars are returned",
config: map[string]string{},
expected: map[string]string{},
getS3Credentials: func(config map[string]string) (*aws.Credentials, error) {
return nil, nil
},
},
{
name: "when config contains profile key, profile env var is set with profile value",
Expand All @@ -53,16 +57,39 @@ func TestGetS3ResticEnvVars(t *testing.T) {
expected: map[string]string{
"AWS_SHARED_CREDENTIALS_FILE": "/tmp/credentials/path/to/secret",
},
getS3Credentials: func(config map[string]string) (*aws.Credentials, error) {
return nil, nil
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Mock GetS3Credentials
if tc.getS3Credentials != nil {
getS3CredentialsFunc = tc.getS3Credentials
} else {
getS3CredentialsFunc = GetS3Credentials
}

actual, err := GetS3ResticEnvVars(tc.config)

require.NoError(t, err)

require.Equal(t, tc.expected, actual)
// Avoid direct comparison of expected and actual to prevent exposing secrets.
// This may occur if the test doesn't set getS3Credentials func correctly.
if !reflect.DeepEqual(tc.expected, actual) {
t.Errorf("Expected and actual results do not match for test case %q", tc.name)
for key, value := range actual {
if expVal, err := tc.expected[key]; !err || expVal != value {
if actualVal, ok := actual[key]; !ok {
t.Errorf("Key %q is missing in actual result", key)
} else if expVal != actualVal {
t.Errorf("Key %q: expected value %q", key, expVal)
}
}
}
}
})
}
}
Expand Down Expand Up @@ -117,6 +144,11 @@ func TestGetS3CredentialsCorrectlyUseProfile(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Ensure env variables do not set AWS config entries
t.Setenv("AWS_ACCESS_KEY_ID", "")
t.Setenv("AWS_SECRET_ACCESS_KEY", "")
t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "")

tmpFile, err := os.CreateTemp("", "velero-test-aws-credentials")
defer os.Remove(tmpFile.Name())
if err != nil {
Expand All @@ -129,17 +161,18 @@ func TestGetS3CredentialsCorrectlyUseProfile(t *testing.T) {
t.Errorf("GetS3Credentials() error = %v", err)
return
}

tt.args.config["credentialsFile"] = tmpFile.Name()
got, err := GetS3Credentials(tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("GetS3Credentials() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.AccessKeyID, tt.want.AccessKeyID) {
t.Errorf("GetS3Credentials() got = %v, want %v", got.AccessKeyID, tt.want.AccessKeyID)
t.Errorf("GetS3Credentials() want %v", tt.want.AccessKeyID)
}
if !reflect.DeepEqual(got.SecretAccessKey, tt.want.SecretAccessKey) {
t.Errorf("GetS3Credentials() got = %v, want %v", got.SecretAccessKey, tt.want.SecretAccessKey)
t.Errorf("GetS3Credentials() want %v", tt.want.SecretAccessKey)
}
})
}
Expand Down
Loading