Skip to content

Commit

Permalink
refactoring for azure arc
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyOHart committed Oct 22, 2024
1 parent e881f85 commit 6c92a9e
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 76 deletions.
142 changes: 88 additions & 54 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,34 +281,87 @@ func acquireIMDS(ctx context.Context, client Client, resource string, fakeAuthPa
}

func acquireAzureArc(ctx context.Context, client Client, resource string, fakeAuthParams authority.AuthParams) (base.AuthResult, error) {
req, err := createAzureArcAuthRequest(ctx, resource)
req, err := createAzureArcAuthRequest(ctx, resource, "")
if err != nil {
return base.AuthResult{}, err
}

tokenResponse, err := client.getTokenForRequest(req)
response, err := client.httpClient.Do(req)
if err != nil {
return handleAzureArcExpectedError(ctx, client, resource, fakeAuthParams, err)
return base.AuthResult{}, err
}
defer response.Body.Close()

return authResultFromToken(fakeAuthParams, tokenResponse)
}
if response.StatusCode != http.StatusUnauthorized {
return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode)
}

func handleAzureArcExpectedError(ctx context.Context, client Client, resource string, fakeAuthParams authority.AuthParams, err error) (base.AuthResult, error) {
var newCallErr errors.CallErr
secret, err := client.getAzureArcSecretKey(ctx, response, resource, runtime.GOOS)
if err != nil {
return base.AuthResult{}, err
}

if errors.As(err, &newCallErr) {
response, err := client.handleAzureArcResponse(ctx, newCallErr.Resp, resource, runtime.GOOS)
if err != nil {
return base.AuthResult{}, err
}
secondRequest, err := createAzureArcAuthRequest(ctx, resource, string(secret))
if err != nil {
return base.AuthResult{}, err
}

return authResultFromToken(fakeAuthParams, response)
secondResponse, err := client.httpClient.Do(secondRequest)
if err != nil {
return base.AuthResult{}, err
}
defer secondResponse.Body.Close()

if err != nil {
return base.AuthResult{}, err
}

responseBytes, err := io.ReadAll(secondResponse.Body)
if err != nil {
return base.AuthResult{}, fmt.Errorf("failed to read second azure arc response body: %w", err)
}

var r accesstokens.TokenResponse
err = json.Unmarshal(responseBytes, &r)
if err != nil {
return base.AuthResult{}, fmt.Errorf("failed to unmarshal second response body: %w", err)
}

return base.AuthResult{}, err
r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, secondRequest.URL.Query().Get(resourceQueryParameterName))

return authResultFromToken(client.authParams, r)
// tokenResponse, err := client.getTokenForRequest(req)
// if err == nil {
// return base.AuthResult{}, fmt.Errorf("expected a 401 error response")
// }

// the endpoint is expected to return a 401 with the WWW-Authenticate header set to the location
// of the secret key file. Any other status code indicates an error in the request.
// var newCallErr errors.CallErr
// if errors.As(err, &newCallErr) {
// if newCallErr.Resp.StatusCode != http.StatusUnauthorized {
// return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", newCallErr.Resp.StatusCode)
// }
// }

// if errors.As(err, &newCallErr) {
// response, err := client.handleAzureArcResponse(ctx, newCallErr.Resp, resource, runtime.GOOS)
// if err != nil {
// return base.AuthResult{}, err
// }

// return authResultFromToken(fakeAuthParams, response)
// }

// return base.AuthResult{}, err

// return authResultFromToken(fakeAuthParams, tokenResponse)
}

// func handleAzureArcExpectedError(ctx context.Context, client Client, resource string, fakeAuthParams authority.AuthParams, err error) (base.AuthResult, error) {

// }

func createFakeAuthParams(client Client) (authority.AuthParams, error) {
fakeAuthInfo, err := authority.NewInfoFromAuthorityURI("https://login.microsoftonline.com/managed_identity", false, true)
if err != nil {
Expand Down Expand Up @@ -407,7 +460,7 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.R
return req, nil
}

func createAzureArcAuthRequest(ctx context.Context, resource string) (*http.Request, error) {
func createAzureArcAuthRequest(ctx context.Context, resource string, key string) (*http.Request, error) {
identityEndpoint := azureArcEndpoint
msiEndpoint, parseErr := url.Parse(identityEndpoint)

Expand All @@ -426,6 +479,10 @@ func createAzureArcAuthRequest(ctx context.Context, resource string) (*http.Requ
return nil, fmt.Errorf("error creating http request %s", err)
}
req.Header.Set(metaHTTPHeaderName, "true")

if condition := key != ""; condition {
req.Header.Set("Authorization", fmt.Sprintf("Basic %s", key))
}
return req, nil
}

Expand All @@ -445,77 +502,54 @@ func isAzureArcEnvironment(identityEndpoint, imdsEndpoint string, platform strin
return false
}

func (c *Client) handleAzureArcResponse(ctx context.Context, response *http.Response, resource string, platform string) (accesstokens.TokenResponse, error) {
if response.StatusCode == http.StatusUnauthorized {
wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName)

if len(wwwAuthenticateHeader) == 0 {
return accesstokens.TokenResponse{}, errors.New("response has no www-authenticate header")
}

// check if the platform is supported
expectedSecretFilePath := getAzureArcPlatformPath(platform)
if expectedSecretFilePath == "" {
return accesstokens.TokenResponse{}, fmt.Errorf("platform not supported, expected linux or windows, got %s", platform)
}

secret, err := handleSecretFile(wwwAuthenticateHeader, expectedSecretFilePath)
if err != nil {
return accesstokens.TokenResponse{}, err
}

authHeaderValue := fmt.Sprintf("Basic %s", string(secret))
func (c *Client) getAzureArcSecretKey(ctx context.Context, response *http.Response, resource string, platform string) (string, error) {
wwwAuthenticateHeader := response.Header.Get(wwwAuthenticateHeaderName)

req, err := createAzureArcAuthRequest(ctx, resource)
if err != nil {
return accesstokens.TokenResponse{}, err
}

req.Header.Set("Authorization", authHeaderValue)

return c.getTokenForRequest(req)
if len(wwwAuthenticateHeader) == 0 {
return "", errors.New("response has no www-authenticate header")
}

return accesstokens.TokenResponse{}, fmt.Errorf("managed identity error: %d", response.StatusCode)
}
// check if the platform is supported
expectedSecretFilePath := getAzureArcPlatformPath(platform)
if expectedSecretFilePath == "" {
return "", fmt.Errorf("platform not supported, expected linux or windows, got %s", platform)
}

func handleSecretFile(wwwAuthenticateHeader, expectedSecretFilePath string) ([]byte, error) {
// split the header to get the secret file path
parts := strings.Split(wwwAuthenticateHeader, "Basic realm=")
if len(parts) < 2 {
return nil, fmt.Errorf("basic realm= not found in the string, instead found: %s", wwwAuthenticateHeader)
return "", fmt.Errorf("basic realm= not found in the string, instead found: %s", wwwAuthenticateHeader)
}

secretFilePath := parts

// check that the file in the file path is a .key file
fileName := filepath.Base(secretFilePath[1])
if !strings.HasSuffix(fileName, azureArcFileExtension) {
return nil, fmt.Errorf("invalid file extension, expected %s, got %s", azureArcFileExtension, filepath.Ext(fileName))
return "", fmt.Errorf("invalid file extension, expected %s, got %s", azureArcFileExtension, filepath.Ext(fileName))
}

// check that file path from header matches the expected file path for the platform
if expectedSecretFilePath != filepath.Dir(secretFilePath[1]) {
return nil, fmt.Errorf("invalid file path, expected %s, got %s", expectedSecretFilePath, filepath.Dir(secretFilePath[1]))
return "", fmt.Errorf("invalid file path, expected %s, got %s", expectedSecretFilePath, filepath.Dir(secretFilePath[1]))
}

fileInfo, err := os.Stat(secretFilePath[1])
if err != nil {
return nil, fmt.Errorf("failed to get metadata for %s due to error: %s", secretFilePath[1], err)
return "", fmt.Errorf("failed to get metadata for %s due to error: %s", secretFilePath[1], err)
}

secretFileSize := fileInfo.Size()

// Throw an error if the secret file's size is greater than 4096 bytes
if s := fileInfo.Size(); s > azureArcMaxFileSizeBytes {
return nil, fmt.Errorf("invalid secret file size, expected %d, file size was %d", azureArcMaxFileSizeBytes, secretFileSize)
return "", fmt.Errorf("invalid secret file size, expected %d, file size was %d", azureArcMaxFileSizeBytes, secretFileSize)
}

// Attempt to read the contents of the secret file
secret, err := os.ReadFile(secretFilePath[1])
if err != nil {
return nil, fmt.Errorf("failed to read %q due to error: %s", secretFilePath[1], err)
return "", fmt.Errorf("failed to read %q due to error: %s", secretFilePath[1], err)
}

return secret, nil
return string(secret), nil
}
46 changes: 24 additions & 22 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ func createMockFile(t *testing.T, path string, size int64) {
t.Fatalf("failed to truncate file: %v", err)
}
}

// Write the content to the file
if _, err := f.WriteString("secret file data"); err != nil {
t.Fatalf("failed to write to file: %v", err)
}
}

func getMockFilePath(t *testing.T) (string, error) {
Expand Down Expand Up @@ -339,16 +344,15 @@ func TestAzureArcAcquireTokenReturnsTokenSuccess(t *testing.T) {
}

testCases := []struct {
source Source
endpoint string
resource string
miType ID
apiVersion string
failFirstResponse bool
source Source
endpoint string
resource string
miType ID
apiVersion string
}{
{source: AzureArc, endpoint: azureArcEndpoint, resource: resource, miType: SystemAssigned(), apiVersion: azureArcAPIVersion, failFirstResponse: false},
{source: AzureArc, endpoint: azureArcEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned(), apiVersion: azureArcAPIVersion, failFirstResponse: false},
{source: AzureArc, endpoint: azureArcEndpoint, resource: resource, miType: SystemAssigned(), apiVersion: azureArcAPIVersion, failFirstResponse: true},
{source: AzureArc, endpoint: azureArcEndpoint, resource: resource, miType: SystemAssigned(), apiVersion: azureArcAPIVersion},
{source: AzureArc, endpoint: azureArcEndpoint, resource: resourceDefaultSuffix, miType: SystemAssigned(), apiVersion: azureArcAPIVersion},
{source: AzureArc, endpoint: azureArcEndpoint, resource: resource, miType: SystemAssigned(), apiVersion: azureArcAPIVersion},
}

for _, testCase := range testCases {
Expand All @@ -364,27 +368,25 @@ func TestAzureArcAcquireTokenReturnsTokenSuccess(t *testing.T) {
t.Fatalf(errorFormingJsonResponse, err.Error())
}

if testCase.failFirstResponse {
mockFilePath := filepath.Join(testCaseFilePath, secretKey)
setCustomAzureArcPlatformPath(t, testCaseFilePath)
createMockFile(t, mockFilePath, 0)
mockFilePath := filepath.Join(testCaseFilePath, secretKey)
setCustomAzureArcPlatformPath(t, testCaseFilePath)
createMockFile(t, mockFilePath, 0)

defer os.Remove(mockFilePath)
defer os.Remove(mockFilePath)

headers := http.Header{}
headers.Add(wwwAuthenticateHeaderName, basicRealm+filepath.Join(testCaseFilePath, secretKey))
headers := http.Header{}
headers.Add(wwwAuthenticateHeaderName, basicRealm+filepath.Join(testCaseFilePath, secretKey))

mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized),
mock.WithHTTPHeader(headers),
mock.WithCallback(func(r *http.Request) { localUrl = r.URL }))
}
mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized),
mock.WithHTTPHeader(headers),
mock.WithCallback(func(r *http.Request) { localUrl = r.URL }))

responseBody, err := getSuccessfulResponse(resource)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}

mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) {
mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusUnauthorized), mock.WithBody(responseBody), mock.WithCallback(func(r *http.Request) {
localUrl = r.URL
}))

Expand Down Expand Up @@ -830,7 +832,7 @@ func TestHandleAzureArcResponse(t *testing.T) {
tc.context = nil
}

_, err := client.handleAzureArcResponse(tc.context, response, "", tc.platform)
_, err := client.getAzureArcSecretKey(tc.context, response, "", tc.platform)

if err == nil || err.Error() != tc.expectedError {
t.Fatalf("expected error: \"%v\"\ngot error: \"%v\"", tc.expectedError, err)
Expand Down

0 comments on commit 6c92a9e

Please sign in to comment.