Skip to content

Commit

Permalink
feat: Adding support for custom AMIs (#2014)
Browse files Browse the repository at this point in the history
* Adding support for custom AMIs
  • Loading branch information
suket22 authored Jul 18, 2022
1 parent 7a818e0 commit 3928e7c
Show file tree
Hide file tree
Showing 12 changed files with 463 additions and 30 deletions.
5 changes: 5 additions & 0 deletions charts/karpenter/crds/karpenter.k8s.aws_awsnodetemplates.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ spec:
amiFamily:
description: AMIFamily is the AMI family that instances use.
type: string
amiSelector:
additionalProperties:
type: string
description: AMISelector discovers AMIs to be used by Amazon EC2 tags.
type: object
apiVersion:
description: 'APIVersion defines the versioned schema of this representation
of an object. Servers should convert recognized schemas to the latest
Expand Down
3 changes: 3 additions & 0 deletions pkg/apis/awsnodetemplate/v1alpha1/awsnodetemplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type AWSNodeTemplateSpec struct {
// +optional
UserData *string `json:"userData,omitempty"`
v1alpha1.AWS `json:",inline"`
// AMISelector discovers AMIs to be used by Amazon EC2 tags.
// +optional
AMISelector map[string]string `json:"amiSelector,omitempty"`
}

// AWSNodeTemplate is the Schema for the AWSNodeTemplate API
Expand Down
7 changes: 7 additions & 0 deletions pkg/apis/awsnodetemplate/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

142 changes: 137 additions & 5 deletions pkg/cloudprovider/aws/amifamily/ami.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,162 @@ import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/mitchellh/hashstructure/v2"
"github.com/patrickmn/go-cache"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"knative.dev/pkg/logging"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/samber/lo"

"github.com/aws/karpenter/pkg/apis/awsnodetemplate/v1alpha1"
"github.com/aws/karpenter/pkg/apis/provisioning/v1alpha5"
"github.com/aws/karpenter/pkg/cloudprovider"
awsv1alpha1 "github.com/aws/karpenter/pkg/cloudprovider/aws/apis/v1alpha1"
"github.com/aws/karpenter/pkg/scheduling"
"github.com/aws/karpenter/pkg/utils/functional"
"github.com/aws/karpenter/pkg/utils/sets"
)

// AMIProvider bundles the API clients and caches used to resolve AMI IDs for instance types.
// NOTE(review): ssm is declared twice and cache sits next to ssmCache/ec2Cache; this looks like the
// removed and added lines of a rendered diff shown together — confirm the intended field set.
type AMIProvider struct {
cache *cache.Cache
ssm ssmiface.SSMAPI
// ssmCache stores AMI IDs keyed by the SSM parameter name that produced them.
ssmCache *cache.Cache
// ec2Cache stores DescribeImages results keyed by a hash of the request filters.
ec2Cache *cache.Cache
// ssm is the client used to read AMI IDs from SSM parameters.
ssm ssmiface.SSMAPI
// kubeClient is used to read the AWSNodeTemplate named by a provider reference.
kubeClient client.Client
// ec2api is the client used to describe images matching an AMI selector.
ec2api ec2iface.EC2API
}

// Get returns a map from AMI ID to the instance types that should use that AMI. AMIs may vary due to architecture, accelerator, etc.
func (p *AMIProvider) Get(ctx context.Context, instanceType cloudprovider.InstanceType, ssmQuery string) (string, error) {
if id, ok := p.cache.Get(ssmQuery); ok {
// If AMI overrides are specified in the AWSNodeTemplate, then only those AMIs will be chosen.
// When the referenced AWSNodeTemplate selects no AMIs (or there is no provider reference), every
// instance type falls back to the default AMI that its AMI family publishes through SSM.
//
// The returned map is keyed by AMI ID. With custom AMIs, an instance type is listed under every
// selected AMI whose requirements it is compatible with, so it may appear under more than one key.
// An error is returned if the AMI requirements cannot be resolved, if an SSM lookup fails, or if
// custom AMIs were selected but no instance type is compatible with any of them.
// The provider argument is not read by this function.
func (p *AMIProvider) Get(ctx context.Context, provider *awsv1alpha1.AWS, nodeRequest *cloudprovider.NodeRequest, options *Options, amiFamily AMIFamily) (map[string][]cloudprovider.InstanceType, error) {
	amiRequirements, err := p.getAMIRequirements(ctx, nodeRequest.Template.ProviderRef)
	if err != nil {
		return nil, err
	}
	amiIDs := map[string][]cloudprovider.InstanceType{}

	// No custom AMIs: resolve the default AMI of each instance type from SSM.
	if len(amiRequirements) == 0 {
		for _, instanceType := range nodeRequest.InstanceTypeOptions {
			amiID, err := p.getDefaultAMIFromSSM(ctx, instanceType, amiFamily.SSMAlias(options.KubernetesVersion, instanceType))
			if err != nil {
				return nil, err
			}
			amiIDs[amiID] = append(amiIDs[amiID], instanceType)
		}
		return amiIDs, nil
	}

	// Custom AMIs: group instance types under each AMI whose requirements they satisfy.
	for _, instanceType := range nodeRequest.InstanceTypeOptions {
		for amiID, requirements := range amiRequirements {
			if err := instanceType.Requirements().Compatible(requirements); err == nil {
				amiIDs[amiID] = append(amiIDs[amiID], instanceType)
			}
		}
	}
	if len(amiIDs) == 0 {
		// Error strings carry no trailing punctuation so they compose cleanly when wrapped.
		return nil, fmt.Errorf("no instance types satisfy requirements of amis %v", lo.Keys(amiRequirements))
	}
	return amiIDs, nil
}

// getDefaultAMIFromSSM returns the AMI ID stored in the SSM parameter named by ssmQuery.
// A value already present in ssmCache under ssmQuery is returned without calling SSM; otherwise
// the parameter is fetched with GetParameterWithContext, cached under ssmQuery, and returned.
// A failed SSM call is returned wrapped with the parameter name. The instanceType argument is
// not read by this function body.
// NOTE(review): the body writes the result to both p.cache and p.ssmCache; this looks like the
// removed and added lines of a rendered diff shown together — confirm which write is intended.
func (p *AMIProvider) getDefaultAMIFromSSM(ctx context.Context, instanceType cloudprovider.InstanceType, ssmQuery string) (string, error) {
// Cache hit: skip the SSM round trip.
if id, ok := p.ssmCache.Get(ssmQuery); ok {
return id.(string), nil
}
output, err := p.ssm.GetParameterWithContext(ctx, &ssm.GetParameterInput{Name: aws.String(ssmQuery)})
if err != nil {
return "", fmt.Errorf("getting ssm parameter %q, %w", ssmQuery, err)
}
ami := aws.StringValue(output.Parameter.Value)
p.cache.SetDefault(ssmQuery, ami)
p.ssmCache.SetDefault(ssmQuery, ami)
logging.FromContext(ctx).Debugf("Discovered %s for query %q", ami, ssmQuery)
return ami, nil
}

// getAMIRequirements resolves the custom AMIs selected by the AWSNodeTemplate that providerRef names.
// It returns a map from AMI ID to the scheduling requirements of that AMI. The map is empty when
// providerRef is nil or when the template has no AMISelector, which callers treat as "no custom AMIs".
// A failure to read the template is returned wrapped, alongside an empty map.
func (p *AMIProvider) getAMIRequirements(ctx context.Context, providerRef *v1alpha5.ProviderRef) (map[string]scheduling.Requirements, error) {
	none := map[string]scheduling.Requirements{}
	if providerRef == nil {
		return none, nil
	}
	nodeTemplate := v1alpha1.AWSNodeTemplate{}
	// Only the name is set on the lookup key; no namespace is supplied.
	key := types.NamespacedName{Name: providerRef.Name}
	if err := p.kubeClient.Get(ctx, key, &nodeTemplate); err != nil {
		return none, fmt.Errorf("retrieving provider reference, %w", err)
	}
	if selector := nodeTemplate.Spec.AMISelector; len(selector) > 0 {
		return p.selectAMIs(ctx, selector)
	}
	return none, nil
}

// selectAMIs looks up the EC2 images matching amiSelector and returns a map from each image ID to
// the scheduling requirements derived from that image. It returns an error if the lookup fails or
// if no image matches the selector.
func (p *AMIProvider) selectAMIs(ctx context.Context, amiSelector map[string]string) (map[string]scheduling.Requirements, error) {
	images, err := p.fetchAMIsFromEC2(ctx, amiSelector)
	switch {
	case err != nil:
		return nil, err
	case len(images) == 0:
		return nil, fmt.Errorf("no amis exist given constraints")
	}
	requirementsByAMI := make(map[string]scheduling.Requirements, len(images))
	for i := range images {
		requirementsByAMI[*images[i].ImageId] = p.getRequirementsFromImage(images[i])
	}
	return requirementsByAMI, nil
}

// fetchAMIsFromEC2 returns the EC2 images matching amiSelector.
// The selector is translated into DescribeImages filters, and a hash of those filters (with slices
// hashed as sets) is used as the key into ec2Cache. On a cache miss a single DescribeImages call is
// issued, its images are cached under that key, and the discovered image IDs are logged at debug level.
func (p *AMIProvider) fetchAMIsFromEC2(ctx context.Context, amiSelector map[string]string) ([]*ec2.Image, error) {
	filters := getFilters(amiSelector)
	hash, err := hashstructure.Hash(filters, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
	if err != nil {
		return nil, err
	}
	cacheKey := fmt.Sprint(hash)
	if cached, found := p.ec2Cache.Get(cacheKey); found {
		return cached.([]*ec2.Image), nil
	}
	// One request only; no pagination is performed here.
	output, err := p.ec2api.DescribeImagesWithContext(ctx, &ec2.DescribeImagesInput{Filters: filters})
	if err != nil {
		return nil, fmt.Errorf("describing images %+v, %w", filters, err)
	}
	images := output.Images
	p.ec2Cache.SetDefault(cacheKey, images)
	discovered := make([]string, 0, len(images))
	for _, image := range images {
		discovered = append(discovered, *image.ImageId)
	}
	logging.FromContext(ctx).Debugf("Discovered images: %s", discovered)
	return images, nil
}

// getFilters converts an AMI selector into EC2 filters, one filter per selector entry.
// The special key "aws-ids" becomes an "image-id" filter whose values are the comma-separated
// entries of the selector value; every other key becomes a "tag:<key>" filter with the selector
// value as its only value. Filter order follows map iteration and is therefore unspecified.
func getFilters(amiSelector map[string]string) []*ec2.Filter {
	filters := make([]*ec2.Filter, 0, len(amiSelector))
	for key, value := range amiSelector {
		filter := &ec2.Filter{}
		switch key {
		case "aws-ids":
			filter.Name = aws.String("image-id")
			filter.Values = aws.StringSlice(functional.SplitCommaSeparatedString(value))
		default:
			filter.Name = aws.String(fmt.Sprintf("tag:%s", key))
			filter.Values = []*string{aws.String(value)}
		}
		filters = append(filters, filter)
	}
	return filters
}

// getRequirementsFromImage builds the scheduling requirements implied by an EC2 image.
// Each image tag whose key is a well-known label contributes a requirement for that key with the
// tag's value. The image architecture is then added under the stable architecture label, translated
// through AWSToKubeArchitectures when a mapping exists and used as-is otherwise.
func (p *AMIProvider) getRequirementsFromImage(ec2Image *ec2.Image) scheduling.Requirements {
	reqs := scheduling.NewRequirements()
	for _, tag := range ec2Image.Tags {
		if !v1alpha5.WellKnownLabels.Has(*tag.Key) {
			// Tags that are not well-known labels carry no scheduling meaning here.
			continue
		}
		reqs.Add(scheduling.Requirements{*tag.Key: sets.NewSet(*tag.Value)})
	}
	// The architecture requirement is added regardless of which tags the image carries.
	arch := *ec2Image.Architecture
	if kubeArch, known := awsv1alpha1.AWSToKubeArchitectures[arch]; known {
		arch = kubeArch
	}
	reqs.Add(scheduling.Requirements{v1.LabelArchStable: sets.NewSet(arch)})
	return reqs
}
20 changes: 10 additions & 10 deletions pkg/cloudprovider/aws/amifamily/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/patrickmn/go-cache"
core "k8s.io/api/core/v1"
Expand Down Expand Up @@ -78,11 +79,14 @@ type AMIFamily interface {
}

// New constructs a new launch template Resolver
func New(ctx context.Context, ssm ssmiface.SSMAPI, c *cache.Cache, client client.Client) *Resolver {
func New(ctx context.Context, ssm ssmiface.SSMAPI, ec2api ec2iface.EC2API, ssmCache *cache.Cache, ec2Cache *cache.Cache, client client.Client) *Resolver {
return &Resolver{
amiProvider: &AMIProvider{
ssm: ssm,
cache: c,
ssm: ssm,
ssmCache: ssmCache,
ec2Cache: ec2Cache,
kubeClient: client,
ec2api: ec2api,
},
UserDataProvider: NewUserDataProvider(client),
}
Expand All @@ -96,13 +100,9 @@ func (r Resolver) Resolve(ctx context.Context, provider *v1alpha1.AWS, nodeReque
return nil, err
}
amiFamily := GetAMIFamily(provider.AMIFamily, options)
amiIDs := map[string][]cloudprovider.InstanceType{}
for _, instanceType := range nodeRequest.InstanceTypeOptions {
amiID, err := r.amiProvider.Get(ctx, instanceType, amiFamily.SSMAlias(options.KubernetesVersion, instanceType))
if err != nil {
return nil, err
}
amiIDs[amiID] = append(amiIDs[amiID], instanceType)
amiIDs, err := r.amiProvider.Get(ctx, provider, nodeRequest, options, amiFamily)
if err != nil {
return nil, err
}
var resolvedTemplates []*LaunchTemplate
for amiID, instanceTypes := range amiIDs {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cloudprovider/aws/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func NewCloudProvider(ctx context.Context, options cloudprovider.Options) *Cloud
ctx,
ec2api,
options.ClientSet,
amifamily.New(ctx, ssm.New(sess), cache.New(CacheTTL, CacheCleanupInterval), options.KubeClient),
amifamily.New(ctx, ssm.New(sess), ec2api, cache.New(CacheTTL, CacheCleanupInterval), cache.New(CacheTTL, CacheCleanupInterval), options.KubeClient),
NewSecurityGroupProvider(ec2api),
getCABundle(ctx),
options.StartAsync,
Expand Down
23 changes: 23 additions & 0 deletions pkg/cloudprovider/aws/fake/ec2api.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type CapacityPool struct {
// pollute each other.
type EC2Behavior struct {
DescribeInstancesOutput AtomicPtr[ec2.DescribeInstancesOutput]
DescribeImagesOutput AtomicPtr[ec2.DescribeImagesOutput]
DescribeLaunchTemplatesOutput AtomicPtr[ec2.DescribeLaunchTemplatesOutput]
DescribeSubnetsOutput AtomicPtr[ec2.DescribeSubnetsOutput]
DescribeSecurityGroupsOutput AtomicPtr[ec2.DescribeSecurityGroupsOutput]
Expand All @@ -52,6 +53,7 @@ type EC2Behavior struct {
DescribeSpotPriceHistoryOutput AtomicPtr[ec2.DescribeSpotPriceHistoryOutput]
CalledWithCreateFleetInput AtomicPtrSlice[ec2.CreateFleetInput]
CalledWithCreateLaunchTemplateInput AtomicPtrSlice[ec2.CreateLaunchTemplateInput]
CalledWithDescribeImagesInput AtomicPtrSlice[ec2.DescribeImagesInput]
Instances sync.Map
LaunchTemplates sync.Map
InsufficientCapacityPools AtomicSlice[CapacityPool]
Expand All @@ -70,6 +72,7 @@ var DefaultSupportedUsageClasses = aws.StringSlice([]string{"on-demand", "spot"}
// each other.
func (e *EC2API) Reset() {
e.DescribeInstancesOutput.Reset()
e.DescribeImagesOutput.Reset()
e.DescribeLaunchTemplatesOutput.Reset()
e.DescribeSubnetsOutput.Reset()
e.DescribeSecurityGroupsOutput.Reset()
Expand All @@ -78,6 +81,7 @@ func (e *EC2API) Reset() {
e.DescribeAvailabilityZonesOutput.Reset()
e.CalledWithCreateFleetInput.Reset()
e.CalledWithCreateLaunchTemplateInput.Reset()
e.CalledWithDescribeImagesInput.Reset()
e.DescribeSpotPriceHistoryOutput.Reset()
e.Instances.Range(func(k, v any) bool {
e.Instances.Delete(k)
Expand Down Expand Up @@ -187,6 +191,25 @@ func (e *EC2API) DescribeInstancesWithContext(_ context.Context, input *ec2.Desc
}, nil
}

// DescribeImagesWithContext is the fake implementation of the EC2 DescribeImages call.
// If NextError is set, it is returned (and reset afterwards) without recording the input.
// Otherwise the input is recorded in CalledWithDescribeImagesInput and the response is either a
// clone of the configured DescribeImagesOutput or, when none is configured, a single x86_64 image
// with a random image ID.
func (e *EC2API) DescribeImagesWithContext(_ context.Context, input *ec2.DescribeImagesInput, _ ...request.Option) (*ec2.DescribeImagesOutput, error) {
	if !e.NextError.IsNil() {
		defer e.NextError.Reset()
		return nil, e.NextError.Get()
	}
	e.CalledWithDescribeImagesInput.Add(input)
	if e.DescribeImagesOutput.IsNil() {
		// No canned output configured: fabricate one default image.
		defaultImage := &ec2.Image{
			ImageId:      aws.String(test.RandomName()),
			Architecture: aws.String("x86_64"),
		}
		return &ec2.DescribeImagesOutput{Images: []*ec2.Image{defaultImage}}, nil
	}
	return e.DescribeImagesOutput.Clone(), nil
}

func (e *EC2API) DescribeLaunchTemplatesWithContext(_ context.Context, input *ec2.DescribeLaunchTemplatesInput, _ ...request.Option) (*ec2.DescribeLaunchTemplatesOutput, error) {
if !e.NextError.IsNil() {
defer e.NextError.Reset()
Expand Down
Loading

0 comments on commit 3928e7c

Please sign in to comment.