Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
Adds support for MIC to authenticate with azure using system assigned…
Browse files Browse the repository at this point in the history
… or user assigned MSI.

Resolves the item in #261.
This PR adds the capability for MIC to look at azure.json or environment variables
to determine whether the system assigned or user assigned MSI has to be used for accessing
azure resources. The MIC requests for token based on MSI. Also contains changes in NMI to determine
if the request is originating from an MIC replicaset. If so, NMI directly generates the tokens
instead of looking up the azure assigned identity for the pod-binding match.
  • Loading branch information
kkmsft committed Jun 25, 2019
1 parent 8515541 commit fbf2a47
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 31 deletions.
22 changes: 22 additions & 0 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,28 @@ const (
activeDirectoryEndpoint = "https://login.microsoftonline.com/"
)

// GetServicePrincipalTokenFromMSI return the token for the assigned user
func GetServicePrincipalTokenFromMSI(resource string) (*adal.Token, error) {
// Get the MSI endpoint accoriding with the OS (Linux/Windows)
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, fmt.Errorf("Failed to get the MSI endpoint. Error: %v", err)
}
// Set up the configuration of the service principal
spt, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
if err != nil {
return nil, fmt.Errorf("Failed to acquire a token for MSI. Error: %v", err)
}
// Evectively acqurie the token
err = spt.Refresh()
if err != nil {
return nil, err
}
token := spt.Token()

return &token, nil
}

// GetServicePrincipalTokenFromMSIWithUserAssignedID return the token for the assigned user
func GetServicePrincipalTokenFromMSIWithUserAssignedID(clientID, resource string) (*adal.Token, error) {
// Get the MSI endpoint accoriding with the OS (Linux/Windows)
Expand Down
46 changes: 37 additions & 9 deletions pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"path"
"regexp"
"strings"
"time"

config "github.com/Azure/aad-pod-identity/pkg/config"
Expand Down Expand Up @@ -55,6 +56,8 @@ func NewCloudProvider(configFile string) (c *Client, e error) {
azureConfig.SubscriptionID = os.Getenv("SUBSCRIPTION_ID")
azureConfig.ResourceGroupName = os.Getenv("RESOURCE_GROUP")
azureConfig.VMType = os.Getenv("VM_TYPE")
azureConfig.UseManagedIdentityExtension = strings.EqualFold(os.Getenv("USE_MSI"), "True")
azureConfig.UserAssignedIdentityID = os.Getenv("USER_ASSIGNED_MSI_CLIENTID")
}

azureEnv, err := azure.EnvironmentFromName(azureConfig.Cloud)
Expand All @@ -67,15 +70,40 @@ func NewCloudProvider(configFile string) (c *Client, e error) {
glog.Errorf("Create OAuth config error: %+v", err)
return nil, err
}
spt, err := adal.NewServicePrincipalToken(
*oauthConfig,
azureConfig.ClientID,
azureConfig.ClientSecret,
azureEnv.ResourceManagerEndpoint,
)
if err != nil {
glog.Errorf("Get service principle token error: %+v", err)
return nil, err

var spt *adal.ServicePrincipalToken
if azureConfig.UseManagedIdentityExtension {
// MSI endpoing is required for both types of MSI - system assigned and user assigned.
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
glog.Errorf("Failed to get msiEndpoint: %+v", err)
return nil, err
}
// UserAssignedIdentityID is empty, so we are going to use system assigned MSI
if azureConfig.UserAssignedIdentityID == "" {
spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, azureEnv.ResourceManagerEndpoint)
if err != nil {
glog.Errorf("Get token from system assigned MSI error: %+v", err)
return nil, err
}
} else { // User assigned identity usage.
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, azureEnv.ResourceManagerEndpoint, azureConfig.UserAssignedIdentityID)
if err != nil {
glog.Errorf("Get token from user assigned MSI error: %+v", err)
return nil, err
}
}
} else { // This is the default scenario - use service principal to get the token.
spt, err = adal.NewServicePrincipalToken(
*oauthConfig,
azureConfig.ClientID,
azureConfig.ClientSecret,
azureEnv.ResourceManagerEndpoint,
)
if err != nil {
glog.Errorf("Get service principle token error: %+v", err)
return nil, err
}
}

extClient := compute.NewVirtualMachineExtensionsClient(azureConfig.SubscriptionID)
Expand Down
18 changes: 10 additions & 8 deletions pkg/config/azureconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package config

// AzureConfig is representing /etc/kubernetes/azure.json
type AzureConfig struct {
Cloud string `json:"cloud" yaml:"cloud"`
TenantID string `json:"tenantId" yaml:"tenantId"`
ClientID string `json:"aadClientId" yaml:"aadClientId"`
ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"`
SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"`
ResourceGroupName string `json:"resourceGroup" yaml:"resourceGroup"`
SecurityGroupName string `json:"securityGroupName" yaml:"securityGroupName"`
VMType string `json:"vmType" yaml:"vmType"`
Cloud string `json:"cloud" yaml:"cloud"`
TenantID string `json:"tenantId" yaml:"tenantId"`
ClientID string `json:"aadClientId" yaml:"aadClientId"`
ClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"`
SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"`
ResourceGroupName string `json:"resourceGroup" yaml:"resourceGroup"`
SecurityGroupName string `json:"securityGroupName" yaml:"securityGroupName"`
VMType string `json:"vmType" yaml:"vmType"`
UseManagedIdentityExtension bool `json:"useManagedIdentityExtension,omitempty" yaml:"useManagedIdentityExtension,omitempty"`
UserAssignedIdentityID string `json:"userAssignedIdentityID,omitempty" yaml:"userAssignedIdentityID,omitempty"`
}
21 changes: 13 additions & 8 deletions pkg/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func getPodPhaseFilter() string {

// Client api client
type Client interface {
// GetPodName return the matching azure identity or nil
GetPodName(podip string) (podns, podname string, err error)
// GetPodInfo returns the pod name, namespace & deployment for a given pod ip
GetPodInfo(podip string) (podns, podname, deployment string, err error)
// ListPodIds pod matching azure identity or nil
ListPodIds(podns, podname string) (*[]aadpodid.AzureIdentity, error)
// GetSecret returns secret the secretRef represents
Expand Down Expand Up @@ -81,23 +81,28 @@ func NewKubeClient() (Client, error) {
return kubeClient, nil
}

// GetPodName get pod ns,name from apiserver
func (c *KubeClient) GetPodName(podip string) (podns, poddname string, err error) {
// GetPodInfo get pod ns,name from apiserver
func (c *KubeClient) GetPodInfo(podip string) (podns, poddname, deployment string, err error) {
if podip == "" {
return "", "", fmt.Errorf("podip is empty")
return "", "", "", fmt.Errorf("podip is empty")
}

podList, err := c.getPodListWithTries(podip, getPodListTries, getPodListSleepTimeMilliseconds)

if err != nil {
return "", "", err
return "", "", "", err
}
numMatching := len(podList.Items)
if numMatching == 1 {
return podList.Items[0].Namespace, podList.Items[0].Name, nil
// TODO: Filter out the deployment owner references.
deployment := ""
if podList.Items[0].OwnerReferences[0].Kind == "ReplicaSet" {
deployment = podList.Items[0].OwnerReferences[0].Name
}
return podList.Items[0].Namespace, podList.Items[0].Name, deployment, nil
}

return "", "", fmt.Errorf("match failed, ip:%s matching pods:%v", podip, podList)
return "", "", "", fmt.Errorf("match failed, ip:%s matching pods:%v", podip, podList)
}

func (c *KubeClient) getPodListWithTries(podip string, tries int, sleeptime time.Duration) (*v1.PodList, error) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/k8s/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package k8s

import (
aadpodid "github.com/Azure/aad-pod-identity/pkg/apis/aadpodidentity/v1"
"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
)

// FakeClient implements Interface
Expand All @@ -17,9 +17,9 @@ func NewFakeClient() (Client, error) {
return fakeClient, nil
}

// GetPodName returns fake pod name
func (c *FakeClient) GetPodName(podip string) (podns, podname string, err error) {
return "ns", "podname", nil
// GetPodInfo returns fake pod name
func (c *FakeClient) GetPodInfo(podip string) (podns, podname, deployment string, err error) {
return "ns", "podname", "deployment", nil
}

// ListPodIds for pod
Expand Down
38 changes: 36 additions & 2 deletions pkg/nmi/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"os"
"os/signal"
"regexp"
"runtime"
"runtime/debug"
"strings"
Expand Down Expand Up @@ -226,20 +227,53 @@ func (s *Server) msiHandler(logger *log.Entry, w http.ResponseWriter, r *http.Re
http.Error(w, msg, http.StatusInternalServerError)
return
}
podns, podname, err := s.KubeClient.GetPodName(podIP)
podns, podname, deployment, err := s.KubeClient.GetPodInfo(podIP)
if err != nil {
logger.Errorf("missing podname for podip:%s, %+v", podIP, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// TODO: make it generic for whitelisting of applications
micRegEx := regexp.MustCompile(`^mic-*`)
micMatch := micRegEx.MatchString(deployment)

// Request id and request resource extraction are common steps required for
// requests from mic as well as other applications.
rqClientID, rqResource := parseRequestClientIDAndResource(r)

// If its mic, then just directly get the token and pass back.
if micMatch {
var token *adal.Token
// UserAssignedIdentity clientID is empty, so we are going to use system assigned MSI
if rqClientID == "" {
token, err = auth.GetServicePrincipalTokenFromMSI(rqResource)
} else { // User assigned identity usage.
token, err = auth.GetServicePrincipalTokenFromMSIWithUserAssignedID(rqClientID, rqResource)
}
if err != nil {
logger.Errorf("failed to get service principal token for pod:%s/%s, %+v", podns, podname, err)
http.Error(w, err.Error(), http.StatusForbidden)
return
}
response, err := json.Marshal(*token)
if err != nil {
logger.Errorf("failed to marshal service principal token for pod:%s/%s, %+v", podns, podname, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(response)
return
}

podIDs, err := s.KubeClient.ListPodIds(podns, podname)
if err != nil || len(*podIDs) == 0 {
msg := fmt.Sprintf("no AzureAssignedIdentity found for pod:%s/%s", podns, podname)
logger.Errorf("%s, %+v", msg, err)
http.Error(w, msg, http.StatusForbidden)
return
}
rqClientID, rqResource := parseRequestClientIDAndResource(r)

token, _, err := getTokenForMatchingID(s.KubeClient, logger, rqClientID, rqResource, podIDs)
if err != nil {
logger.Errorf("failed to get service principal token for pod:%s/%s, %+v", podns, podname, err)
Expand Down

0 comments on commit fbf2a47

Please sign in to comment.