Skip to content

Commit

Permalink
improve unit test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
synfinatic committed Jun 29, 2024
1 parent 6ae7292 commit b196c7c
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 24 deletions.
15 changes: 1 addition & 14 deletions cmd/aws-sso/ecs_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package main

import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -122,21 +120,10 @@ func (cc *EcsCertCmd) Run(ctx *RunContext) error {
return fmt.Errorf("failed to read private key file: %w", err)
}

block, _ := pem.Decode(privateKey)
if _, err := x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return fmt.Errorf("private key file is not a valid private key: %s", err)
}
certChain, err := os.ReadFile(ctx.Cli.Ecs.Cert.CertChain)
if err != nil {
return fmt.Errorf("failed to read certificate chain file: %w", err)
}
block, _ = pem.Decode(certChain)
if err != nil {
return fmt.Errorf("failed to decode certificate chain file: %w", err)
}
if _, err := x509.ParseCertificate(block.Bytes); err != nil {
return fmt.Errorf("certificate chain file is not a valid certificate: %w", err)
}

return ctx.Store.SaveEcsSslKeyPair(string(privateKey), string(certChain))
return ctx.Store.SaveEcsSslKeyPair(privateKey, certChain)
}
3 changes: 1 addition & 2 deletions docs/remote-ssh.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ encrypted over ssh.

**Note:** The root user or anyone with [CAP_NET_RAW or CAP_NET_ADMIN](https://man7.org/linux/man-pages/man7/capabilities.7.html)
will be able to intercept the HTTP traffic on either endpoint and obtain the bearer token
and/or any IAM Credentials stored in the ECS Server. As of this time, `aws-sso` does
[not support HTTPS](https://github.com/synfinatic/aws-sso-cli/issues/518) for full end-to-end encryption.
and/or any IAM Credentials stored in the ECS Server if you have not [enabled SSL](ecs-server.md#enable-ssl).

## On your local system

Expand Down
43 changes: 42 additions & 1 deletion internal/ecs/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -54,6 +55,11 @@ func TestNewECSClient(t *testing.T) {
assert.NotEmpty(t, c.loadSlotUrl)
assert.NotEmpty(t, c.profileUrl)
assert.NotEmpty(t, c.listUrl)

certChain, err := os.ReadFile("../server/testdata/localhost.crt")
assert.NoError(t, err)
c = NewECSClient(4144, "token", string(certChain))
assert.NotNil(t, c)
}

func TestECSClientLoadUrl(t *testing.T) {
Expand Down Expand Up @@ -206,8 +212,8 @@ func TestECSGetProfile(t *testing.T) {
defer ts.Close()

c := NewECSClient(4144, "token", "")
c.profileUrl = ts.URL
assert.NotNil(t, c)
c.profileUrl = ts.URL

lprResp, err := c.GetProfile()
assert.NoError(t, err)
Expand All @@ -231,6 +237,30 @@ func TestECSGetProfile(t *testing.T) {
assert.Error(t, err)
}

func TestECSAuthFailures(t *testing.T) {
t.Parallel()

// create mocked http server
ts := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
ecs.WriteMessage(w, "Invalid authorization token", http.StatusForbidden)
},
),
)
defer ts.Close()

c := NewECSClient(4144, "token", "")
assert.NotNil(t, c)
c.profileUrl = ts.URL

_, err := c.GetProfile()
assert.Error(t, err)

_, err = c.ListProfiles()
assert.Error(t, err)
}

func TestECSListProfiles(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -300,3 +330,14 @@ func TestECSDelete(t *testing.T) {
err = c.Delete("foo")
assert.Error(t, err)
}

func TestNewHTTPClient(t *testing.T) {
t.Parallel()

cert, err := os.ReadFile("../server/testdata/localhost.crt")
assert.NoError(t, err)

c, err := NewHTTPClient(string(cert))
assert.NoError(t, err)
assert.NotNil(t, c)
}
6 changes: 6 additions & 0 deletions internal/ecs/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ func TestBaseURL(t *testing.T) {

str := es.BaseURL()
assert.Regexp(t, regexp.MustCompile(`^http://`), str)

// check ssl
es.privateKey = "test"
es.certChain = "test"
str = es.BaseURL()
assert.Regexp(t, regexp.MustCompile(`^https://`), str)
}

func TestAuthToken(t *testing.T) {
Expand Down
13 changes: 10 additions & 3 deletions internal/storage/json_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,16 @@ func (jc *JsonStore) DeleteEcsBearerToken() error {
}

// SaveEcsSslKeyPair stores the SSL private key and certificate chain in the json file
func (jc *JsonStore) SaveEcsSslKeyPair(privateKey, certChain string) error {
jc.EcsPrivateKey = privateKey
jc.EcsCertChain = certChain
func (jc *JsonStore) SaveEcsSslKeyPair(privateKey, certChain []byte) error {
if err := ValidateSSLCertificate(certChain); err != nil {
return err
}
jc.EcsCertChain = string(certChain)

if err := ValidateSSLPrivateKey(privateKey); err != nil {
return err

Check warning on line 219 in internal/storage/json_store.go

View check run for this annotation

Codecov / codecov/patch

internal/storage/json_store.go#L219

Added line #L219 was not covered by tests
}
jc.EcsPrivateKey = string(privateKey)
return jc.save()
}

Expand Down
43 changes: 43 additions & 0 deletions internal/storage/json_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,46 @@ func (s *JsonStoreTestSuite) TestEcsBearerToken() {
assert.NoError(t, err)
assert.Empty(t, token)
}

func (s *JsonStoreTestSuite) TestEcsSslKeyPair() { // nolint: dupl
t := s.T()

cert, err := s.json.GetEcsSslCert()
assert.NoError(t, err)
assert.Empty(t, cert)

key, err := s.json.GetEcsSslKey()
assert.NoError(t, err)
assert.Empty(t, key)

certBytes, err := os.ReadFile("../ecs/server/testdata/localhost.crt")
assert.NoError(t, err)
keyBytes, err := os.ReadFile("../ecs/server/testdata/localhost.key")
assert.NoError(t, err)
err = s.json.SaveEcsSslKeyPair(keyBytes, certBytes)
assert.NoError(t, err)

err = s.json.SaveEcsSslKeyPair(certBytes, keyBytes)
assert.Error(t, err)
err = s.json.SaveEcsSslKeyPair(keyBytes, keyBytes)
assert.Error(t, err)

cert, err = s.json.GetEcsSslCert()
assert.NoError(t, err)
assert.Equal(t, string(certBytes), cert)

key, err = s.json.GetEcsSslKey()
assert.NoError(t, err)
assert.Equal(t, string(keyBytes), key)

err = s.json.DeleteEcsSslKeyPair()
assert.NoError(t, err)

cert, err = s.json.GetEcsSslCert()
assert.NoError(t, err)
assert.Empty(t, cert)

key, err = s.json.GetEcsSslKey()
assert.NoError(t, err)
assert.Empty(t, key)
}
13 changes: 10 additions & 3 deletions internal/storage/keyring.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,16 @@ func (kr *KeyringStore) DeleteEcsBearerToken() error {
}

// SaveEcsSslKeyPair stores the private key and certificate chain in the keyring
func (kr *KeyringStore) SaveEcsSslKeyPair(privateKey, certChain string) error {
kr.cache.EcsCertChain = certChain
kr.cache.EcsPrivateKey = privateKey
func (kr *KeyringStore) SaveEcsSslKeyPair(privateKey, certChain []byte) error {
if err := ValidateSSLCertificate(certChain); err != nil {
return err
}
kr.cache.EcsCertChain = string(certChain)

if err := ValidateSSLPrivateKey(privateKey); err != nil {
return err

Check warning on line 464 in internal/storage/keyring.go

View check run for this annotation

Codecov / codecov/patch

internal/storage/keyring.go#L464

Added line #L464 was not covered by tests
}
kr.cache.EcsPrivateKey = string(privateKey)
return kr.saveStorageData()
}

Expand Down
44 changes: 44 additions & 0 deletions internal/storage/keyring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,50 @@ func (suite *KeyringSuite) TestEcsBearerToken() {
assert.Empty(t, token)
}

func (suite *KeyringSuite) TestEcsSslKeyPair() { // nolint: dupl
t := suite.T()

cert, err := suite.store.GetEcsSslCert()
assert.NoError(t, err)
assert.Empty(t, cert)

key, err := suite.store.GetEcsSslKey()
assert.NoError(t, err)
assert.Empty(t, key)

certBytes, err := os.ReadFile("../ecs/server/testdata/localhost.crt")
assert.NoError(t, err)
keyBytes, err := os.ReadFile("../ecs/server/testdata/localhost.key")
assert.NoError(t, err)
err = suite.store.SaveEcsSslKeyPair(keyBytes, certBytes)
assert.NoError(t, err)

err = suite.store.SaveEcsSslKeyPair(certBytes, keyBytes)
assert.Error(t, err)

err = suite.store.SaveEcsSslKeyPair(keyBytes, keyBytes)
assert.Error(t, err)

cert, err = suite.store.GetEcsSslCert()
assert.NoError(t, err)
assert.Equal(t, string(certBytes), cert)

key, err = suite.store.GetEcsSslKey()
assert.NoError(t, err)
assert.Equal(t, string(keyBytes), key)

err = suite.store.DeleteEcsSslKeyPair()
assert.NoError(t, err)

cert, err = suite.store.GetEcsSslCert()
assert.NoError(t, err)
assert.Empty(t, cert)

key, err = suite.store.GetEcsSslKey()
assert.NoError(t, err)
assert.Empty(t, key)
}

func (suite *KeyringSuite) TestErrorReadKeyring() {
t := suite.T()
// Read non existent key
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/secure_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type SecureStorage interface {
DeleteEcsBearerToken() error

// ECS Server SSL Cert
SaveEcsSslKeyPair(string, string) error
SaveEcsSslKeyPair([]byte, []byte) error
DeleteEcsSslKeyPair() error
GetEcsSslCert() (string, error)
GetEcsSslKey() (string, error)
Expand Down
22 changes: 22 additions & 0 deletions internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package storage
*/

import (
"crypto/x509"
"encoding/pem"
"fmt"
"reflect"
"time"
Expand Down Expand Up @@ -162,3 +164,23 @@ func (sc *StaticCredentials) AccountIdStr() string {
}
return s
}

// ValidateSSLCertificate ensures we have a valid SSL certificate
func ValidateSSLCertificate(certChain []byte) error {
block, _ := pem.Decode(certChain)

if _, err := x509.ParseCertificate(block.Bytes); err != nil {
return fmt.Errorf("certificate chain file is not a valid certificate: %w", err)
}
return nil
}

// ValidateSSLPrivateKey ensures we have a valid SSL private key
func ValidateSSLPrivateKey(privateKey []byte) error {
block, _ := pem.Decode(privateKey)

if _, err := x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
return fmt.Errorf("private key file is not a valid private key: %s", err)
}
return nil
}
31 changes: 31 additions & 0 deletions internal/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package storage
*/

import (
"os"
"testing"
"time"

Expand Down Expand Up @@ -169,3 +170,33 @@ func TestRoleCredentialsValidate(t *testing.T) {
k.Expiration = 0
assert.ErrorContains(t, (&k).Validate(), "expiration")
}

func TestValidateSSLCertificate(t *testing.T) {
t.Parallel()
cert, err := os.ReadFile("../ecs/server/testdata/localhost.crt")
assert.NoError(t, err)

err = ValidateSSLCertificate(cert)
assert.NoError(t, err)

cert, err = os.ReadFile("../ecs/server/testdata/localhost.key")
assert.NoError(t, err)

err = ValidateSSLCertificate(cert)
assert.Error(t, err)
}

func TestValidateSSLPrivateKey(t *testing.T) {
t.Parallel()
key, err := os.ReadFile("../ecs/server/testdata/localhost.key")
assert.NoError(t, err)

err = ValidateSSLPrivateKey(key)
assert.NoError(t, err)

key, err = os.ReadFile("../ecs/server/testdata/localhost.crt")
assert.NoError(t, err)

err = ValidateSSLPrivateKey(key)
assert.Error(t, err)
}

0 comments on commit b196c7c

Please sign in to comment.