-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding MSI Login Example #241
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,12 +6,14 @@ import ( | |
"encoding/base64" | ||
"encoding/binary" | ||
"fmt" | ||
"github.com/Azure/go-autorest/autorest/adal" | ||
"io" | ||
"log" | ||
"net" | ||
"net/http" | ||
"net/url" | ||
"os" | ||
"reflect" | ||
"strings" | ||
"time" | ||
|
||
|
@@ -1302,3 +1304,244 @@ func ExampleListBlobsHierarchy() { | |
} | ||
} | ||
} | ||
|
||
//// =========================================================================================== | ||
//type HTTPResponseExtension struct { | ||
// *http.Response | ||
//} | ||
// | ||
//// IsSuccessStatusCode checks if response's status code is contained in specified success status codes. | ||
//func (r HTTPResponseExtension) IsSuccessStatusCode(successStatusCodes ...int) bool { | ||
// if r.Response == nil { | ||
// return false | ||
// } | ||
// for _, i := range successStatusCodes { | ||
// if i == r.StatusCode { | ||
// return true | ||
// } | ||
// } | ||
// return false | ||
//} | ||
// | ||
//type ByteSlice []byte | ||
//type ByteSliceExtension struct { | ||
// ByteSlice | ||
//} | ||
// | ||
//// RemoveBOM removes any BOM from the byte slice | ||
//func (bs ByteSliceExtension) RemoveBOM() []byte { | ||
// if bs.ByteSlice == nil { | ||
// return nil | ||
// } | ||
// // UTF8 | ||
// return bytes.TrimPrefix(bs.ByteSlice, []byte("\xef\xbb\xbf")) | ||
//} | ||
// | ||
//// Resource used in azure storage OAuth authentication | ||
//const ( | ||
// Resource = "https://storage.azure.com" | ||
// DefaultTenantID = "common" | ||
// DefaultActiveDirectoryEndpoint = "https://login.microsoftonline.com" | ||
// IMDSAPIVersion = "2018-02-01" | ||
// MSIEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" | ||
//) | ||
// | ||
//func goSDKHTTPClient() *http.Client { | ||
// return &http.Client{ | ||
// Transport: &http.Transport{ | ||
// Proxy: nil, | ||
// // We use Dial instead of DialContext as DialContext has been reported to cause slower performance. | ||
// Dial /*Context*/ : (&net.Dialer{ | ||
// Timeout: 30 * time.Second, | ||
// KeepAlive: 30 * time.Second, | ||
// DualStack: true, | ||
// }).Dial, /*Context*/ | ||
// MaxIdleConns: 0, // No limit | ||
// MaxIdleConnsPerHost: 1000, | ||
// IdleConnTimeout: 180 * time.Second, | ||
// TLSHandshakeTimeout: 10 * time.Second, | ||
// ExpectContinueTimeout: 1 * time.Second, | ||
// DisableKeepAlives: false, | ||
// DisableCompression: true, | ||
// MaxResponseHeaderBytes: 0, | ||
// }, | ||
// } | ||
//} | ||
// | ||
//func getNewTokenFromMSI(ctx context.Context, clientID, objectID, resourceID string) (*adal.Token, error) { | ||
// // Prepare request to get token from Azure Instance Metadata Service identity endpoint. | ||
// req, err := http.NewRequest("GET", MSIEndpoint, nil) | ||
// if err != nil { | ||
// return nil, fmt.Errorf("failed to create request, %v", err) | ||
// } | ||
// params := req.URL.Query() | ||
// params.Set("resource", Resource) | ||
// params.Set("api-version", IMDSAPIVersion) | ||
// if clientID != "" { | ||
// params.Set("client_id", clientID) | ||
// } | ||
// if objectID != "" { | ||
// params.Set("object_id", objectID) | ||
// } | ||
// if resourceID != "" { | ||
// params.Set("msi_res_id", resourceID) | ||
// } | ||
// req.URL.RawQuery = params.Encode() | ||
// req.Header.Set("Metadata", "true") | ||
// // Set context. | ||
// req.WithContext(ctx) | ||
// | ||
// // Send request | ||
// var msiTokenHTTPClient = goSDKHTTPClient() | ||
// resp, err := msiTokenHTTPClient.Do(req) | ||
// if err != nil { | ||
// return nil, fmt.Errorf("please check whether MSI is enabled on this PC, to enable MSI please refer to https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm#enable-system-assigned-identity-on-an-existing-vm. (Error details: %v)", err) | ||
// } | ||
// defer func() { // resp and Body should not be nil | ||
// io.Copy(ioutil.Discard, resp.Body) | ||
// resp.Body.Close() | ||
// }() | ||
// | ||
// // Check if the status code indicates success | ||
// // The request returns 200 currently, add 201 and 202 as well for possible extension. | ||
// if !(HTTPResponseExtension{Response: resp}).IsSuccessStatusCode(http.StatusOK, http.StatusCreated, http.StatusAccepted) { | ||
// return nil, fmt.Errorf("failed to get token from msi, status code: %v", resp.StatusCode) | ||
// } | ||
// | ||
// b, err := ioutil.ReadAll(resp.Body) | ||
// if err != nil { | ||
// return nil, err | ||
// } | ||
// | ||
// result := &adal.Token{} | ||
// if len(b) > 0 { | ||
// b = ByteSliceExtension{ByteSlice: b}.RemoveBOM() | ||
// if err := json.Unmarshal(b, result); err != nil { | ||
// return nil, fmt.Errorf("failed to unmarshal response body, %v", err) | ||
// } | ||
// } else { | ||
// return nil, errors.New("failed to get token from msi") | ||
// } | ||
// | ||
// return result, nil | ||
//} | ||
// //================================================================================================================================== | ||
|
||
func fetchMSIToken(applicationID string, identityResourceID string, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { | ||
// Both application id and identityResourceId cannot be present at the same time. | ||
if applicationID != "" && identityResourceID != "" { | ||
return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time") | ||
} | ||
|
||
// msiEndpoint is the well known endpoint for getting MSI authentications tokens | ||
// msiEndpoint := "http://169.254.169.254/metadata/identity/oauth2/token" for production Jobs | ||
msiEndpoint, _ := adal.GetMSIVMEndpoint() | ||
|
||
var spt *adal.ServicePrincipalToken | ||
var err error | ||
|
||
// both can be empty, systemAssignedMSI scenario | ||
if applicationID == "" && identityResourceID == "" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. code cleanliness: else ifs would be preferable here |
||
spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...) | ||
} | ||
|
||
// msi login with clientID | ||
if applicationID != "" { | ||
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...) | ||
} | ||
|
||
// msi login with resourceID | ||
if identityResourceID != "" { | ||
spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...) | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return spt, spt.Refresh() | ||
} | ||
|
||
func getOAuthToken(applicationID, identityResourceID, resource string, callbacks ...adal.TokenRefreshCallback) (*TokenCredential, error) { | ||
spt, err := fetchMSIToken(applicationID, identityResourceID, resource, callbacks...) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
// Refresh obtains a fresh token for the Service Principal. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: Service Principal |
||
err = spt.Refresh() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
tc := NewTokenCredential(spt.Token().AccessToken, func(tc TokenCredential) time.Duration { | ||
_ = spt.Refresh() | ||
return time.Until(spt.Token().Expires()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps it's best to deduct a slight buffer from the expiration so that the token credential is always valid. |
||
}) | ||
|
||
return &tc, nil | ||
} | ||
|
||
func ExampleMSILogin() { | ||
var accountName string | ||
// Use the azure resource id of user assigned identity when creating the token. | ||
// identityResourceID := "/subscriptions/{subscriptionID}/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity" | ||
// resource := "https://resource" | ||
var applicationID, identityResourceID, resource string | ||
var err error | ||
|
||
callbacks := func(token adal.Token) error { return nil } | ||
|
||
tokenCredentials, err := getOAuthToken(applicationID, identityResourceID, resource, callbacks) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
// Create pipeline to handle requests | ||
p := NewPipeline(*tokenCredentials, PipelineOptions{}) | ||
blobPrimaryURL, _ := url.Parse("https://" + accountName + ".blob.core.windows.net/") | ||
// Generate a blob service URL | ||
bsu := NewServiceURL(*blobPrimaryURL, p) | ||
|
||
// Create container & upload sample data | ||
containerName := generateContainerName() | ||
containerURL := bsu.NewContainerURL(containerName) | ||
_, err = containerURL.Create(ctx, Metadata{}, PublicAccessNone) | ||
defer containerURL.Delete(ctx, ContainerAccessConditions{}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
// Inside the container, create a test blob with random data. | ||
blobName := generateBlobName() | ||
blobURL := containerURL.NewBlockBlobURL(blobName) | ||
data := "Hello World!" | ||
uploadResp, err := blobURL.Upload(ctx, strings.NewReader(data), BlobHTTPHeaders{ContentType: "text/plain"}, Metadata{}, BlobAccessConditions{}, DefaultAccessTier, nil, ClientProvidedKeyOptions{}) | ||
if err != nil || uploadResp.StatusCode() != 201 { | ||
log.Fatal(err) | ||
} | ||
|
||
// Download data via User Delegation SAS URL; must succeed | ||
downloadResp, err := blobURL.Download(ctx, 0, 0, BlobAccessConditions{}, false, ClientProvidedKeyOptions{}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
downloadedData := &bytes.Buffer{} | ||
reader := downloadResp.Body(RetryReaderOptions{}) | ||
_, err = downloadedData.ReadFrom(reader) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
err = reader.Close() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
// Verify the content | ||
reflect.DeepEqual(data, downloadedData) | ||
|
||
// Delete the item using the User Delegation SAS URL; must succeed | ||
_, err = blobURL.Delete(ctx, DeleteSnapshotsOptionInclude, BlobAccessConditions{}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This commented code should be removed, right?