Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding MSI Login Example #241

Merged
merged 2 commits into from
Mar 23, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 243 additions & 0 deletions azblob/zt_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"encoding/base64"
"encoding/binary"
"fmt"
"github.com/Azure/go-autorest/autorest/adal"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"reflect"
"strings"
"time"

Expand Down Expand Up @@ -1302,3 +1304,244 @@ func ExampleListBlobsHierarchy() {
}
}
}

//// ===========================================================================================
//type HTTPResponseExtension struct {
// *http.Response
//}
//
//// IsSuccessStatusCode checks if response's status code is contained in specified success status codes.
//func (r HTTPResponseExtension) IsSuccessStatusCode(successStatusCodes ...int) bool {
// if r.Response == nil {
// return false
// }
// for _, i := range successStatusCodes {
// if i == r.StatusCode {
// return true
// }
// }
// return false
//}
//
//type ByteSlice []byte
//type ByteSliceExtension struct {
// ByteSlice
//}
//
//// RemoveBOM removes any BOM from the byte slice
//func (bs ByteSliceExtension) RemoveBOM() []byte {
// if bs.ByteSlice == nil {
// return nil
// }
// // UTF8
// return bytes.TrimPrefix(bs.ByteSlice, []byte("\xef\xbb\xbf"))
//}
//
//// Resource used in azure storage OAuth authentication
//const (
// Resource = "https://storage.azure.com"
// DefaultTenantID = "common"
// DefaultActiveDirectoryEndpoint = "https://login.microsoftonline.com"
// IMDSAPIVersion = "2018-02-01"
// MSIEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
//)
//
//func goSDKHTTPClient() *http.Client {
// return &http.Client{
// Transport: &http.Transport{
// Proxy: nil,
// // We use Dial instead of DialContext as DialContext has been reported to cause slower performance.
// Dial /*Context*/ : (&net.Dialer{
// Timeout: 30 * time.Second,
// KeepAlive: 30 * time.Second,
// DualStack: true,
// }).Dial, /*Context*/
// MaxIdleConns: 0, // No limit
// MaxIdleConnsPerHost: 1000,
// IdleConnTimeout: 180 * time.Second,
// TLSHandshakeTimeout: 10 * time.Second,
// ExpectContinueTimeout: 1 * time.Second,
// DisableKeepAlives: false,
// DisableCompression: true,
// MaxResponseHeaderBytes: 0,
// },
// }
//}
//
//func getNewTokenFromMSI(ctx context.Context, clientID, objectID, resourceID string) (*adal.Token, error) {
// // Prepare request to get token from Azure Instance Metadata Service identity endpoint.
// req, err := http.NewRequest("GET", MSIEndpoint, nil)
// if err != nil {
// return nil, fmt.Errorf("failed to create request, %v", err)
// }
// params := req.URL.Query()
// params.Set("resource", Resource)
// params.Set("api-version", IMDSAPIVersion)
// if clientID != "" {
// params.Set("client_id", clientID)
// }
// if objectID != "" {
// params.Set("object_id", objectID)
// }
// if resourceID != "" {
// params.Set("msi_res_id", resourceID)
// }
// req.URL.RawQuery = params.Encode()
// req.Header.Set("Metadata", "true")
// // Set context.
// req.WithContext(ctx)
//
// // Send request
// var msiTokenHTTPClient = goSDKHTTPClient()
// resp, err := msiTokenHTTPClient.Do(req)
// if err != nil {
// return nil, fmt.Errorf("please check whether MSI is enabled on this PC, to enable MSI please refer to https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm#enable-system-assigned-identity-on-an-existing-vm. (Error details: %v)", err)
// }
// defer func() { // resp and Body should not be nil
// io.Copy(ioutil.Discard, resp.Body)
// resp.Body.Close()
// }()
//
// // Check if the status code indicates success
// // The request returns 200 currently, add 201 and 202 as well for possible extension.
// if !(HTTPResponseExtension{Response: resp}).IsSuccessStatusCode(http.StatusOK, http.StatusCreated, http.StatusAccepted) {
// return nil, fmt.Errorf("failed to get token from msi, status code: %v", resp.StatusCode)
// }
//
// b, err := ioutil.ReadAll(resp.Body)
// if err != nil {
// return nil, err
// }
//
// result := &adal.Token{}
// if len(b) > 0 {
// b = ByteSliceExtension{ByteSlice: b}.RemoveBOM()
// if err := json.Unmarshal(b, result); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commented code should be removed, right?

// return nil, fmt.Errorf("failed to unmarshal response body, %v", err)
// }
// } else {
// return nil, errors.New("failed to get token from msi")
// }
//
// return result, nil
//}
// //==================================================================================================================================

func fetchMSIToken(applicationID string, identityResourceID string, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
// Both application id and identityResourceId cannot be present at the same time.
if applicationID != "" && identityResourceID != "" {
return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time")
}

// msiEndpoint is the well known endpoint for getting MSI authentications tokens
// msiEndpoint := "http://169.254.169.254/metadata/identity/oauth2/token" for production Jobs
msiEndpoint, _ := adal.GetMSIVMEndpoint()

var spt *adal.ServicePrincipalToken
var err error

// both can be empty, systemAssignedMSI scenario
if applicationID == "" && identityResourceID == "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code cleanliness: else ifs would be preferable here

spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...)
}

// msi login with clientID
if applicationID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...)
}

// msi login with resourceID
if identityResourceID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...)
}

if err != nil {
return nil, err
}

return spt, spt.Refresh()
}

func getOAuthToken(applicationID, identityResourceID, resource string, callbacks ...adal.TokenRefreshCallback) (*TokenCredential, error) {
spt, err := fetchMSIToken(applicationID, identityResourceID, resource, callbacks...)
if err != nil {
log.Fatal(err)
}

// Refresh obtains a fresh token for the Service Principal.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Service Principal

err = spt.Refresh()
if err != nil {
log.Fatal(err)
}

tc := NewTokenCredential(spt.Token().AccessToken, func(tc TokenCredential) time.Duration {
_ = spt.Refresh()
return time.Until(spt.Token().Expires())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it's best to deduct a slight buffer from the expiration so that the token credential is always valid.

})

return &tc, nil
}

func ExampleMSILogin() {
var accountName string
// Use the azure resource id of user assigned identity when creating the token.
// identityResourceID := "/subscriptions/{subscriptionID}/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity"
// resource := "https://resource"
var applicationID, identityResourceID, resource string
var err error

callbacks := func(token adal.Token) error { return nil }

tokenCredentials, err := getOAuthToken(applicationID, identityResourceID, resource, callbacks)
if err != nil {
log.Fatal(err)
}
// Create pipeline to handle requests
p := NewPipeline(*tokenCredentials, PipelineOptions{})
blobPrimaryURL, _ := url.Parse("https://" + accountName + ".blob.core.windows.net/")
// Generate a blob service URL
bsu := NewServiceURL(*blobPrimaryURL, p)

// Create container & upload sample data
containerName := generateContainerName()
containerURL := bsu.NewContainerURL(containerName)
_, err = containerURL.Create(ctx, Metadata{}, PublicAccessNone)
defer containerURL.Delete(ctx, ContainerAccessConditions{})
if err != nil {
log.Fatal(err)
}

// Inside the container, create a test blob with random data.
blobName := generateBlobName()
blobURL := containerURL.NewBlockBlobURL(blobName)
data := "Hello World!"
uploadResp, err := blobURL.Upload(ctx, strings.NewReader(data), BlobHTTPHeaders{ContentType: "text/plain"}, Metadata{}, BlobAccessConditions{}, DefaultAccessTier, nil, ClientProvidedKeyOptions{})
if err != nil || uploadResp.StatusCode() != 201 {
log.Fatal(err)
}

// Download data via User Delegation SAS URL; must succeed
downloadResp, err := blobURL.Download(ctx, 0, 0, BlobAccessConditions{}, false, ClientProvidedKeyOptions{})
if err != nil {
log.Fatal(err)
}
downloadedData := &bytes.Buffer{}
reader := downloadResp.Body(RetryReaderOptions{})
_, err = downloadedData.ReadFrom(reader)
if err != nil {
log.Fatal(err)
}
err = reader.Close()
if err != nil {
log.Fatal(err)
}

// Verify the content
reflect.DeepEqual(data, downloadedData)

// Delete the item using the User Delegation SAS URL; must succeed
_, err = blobURL.Delete(ctx, DeleteSnapshotsOptionInclude, BlobAccessConditions{})
if err != nil {
log.Fatal(err)
}
}