Skip to content

perf: Add Get() on InstanceType provider #8118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pkg/apis/v1/ec2nodeclass_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package v1

import (
"github.com/awslabs/operatorpkg/status"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand Down Expand Up @@ -144,3 +145,15 @@ func (in *EC2NodeClass) GetConditions() []status.Condition {
func (in *EC2NodeClass) SetConditions(conditions []status.Condition) {
in.Status.Conditions = conditions
}

func (in *EC2NodeClass) ZoneIDMap() map[string]string {
return lo.SliceToMap(in.Status.Subnets, func(s Subnet) (string, string) {
return s.Zone, s.ZoneID
})
}

func (in *EC2NodeClass) Zones() []string {
return lo.Map(in.Status.Subnets, func(s Subnet, _ int) string {
return s.Zone
})
}
37 changes: 24 additions & 13 deletions pkg/cloudprovider/cloudprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ func (c *CloudProvider) Create(ctx context.Context, nodeClaim *karpv1.NodeClaim)
if errors.IsNotFound(err) {
// We treat a failure to resolve the NodeClass as an ICE since this means there is no capacity possibilities for this NodeClaim
c.recorder.Publish(cloudproviderevents.NodeClaimFailedToResolveNodeClass(nodeClaim))
return nil, cloudprovider.NewInsufficientCapacityError(fmt.Errorf("resolving node class, %w", err))
return nil, cloudprovider.NewInsufficientCapacityError(fmt.Errorf("resolving nodeclass, %w", err))
}
// Transient error when resolving the NodeClass
return nil, fmt.Errorf("resolving node class, %w", err)
return nil, fmt.Errorf("resolving nodeclass, %w", err)
}

nodeClassReady := nodeClass.StatusConditions().Get(status.ConditionReady)
Expand Down Expand Up @@ -142,16 +142,16 @@ func (c *CloudProvider) List(ctx context.Context) ([]*karpv1.NodeClaim, error) {
return nil, fmt.Errorf("listing instances, %w", err)
}
var nodeClaims []*karpv1.NodeClaim
for _, instance := range instances {
instanceType, err := c.resolveInstanceTypeFromInstance(ctx, instance)
for _, it := range instances {
instanceType, err := c.resolveInstanceTypeFromInstance(ctx, it)
if err != nil {
return nil, fmt.Errorf("resolving instance type, %w", err)
}
nc, err := c.resolveNodeClassFromInstance(ctx, instance)
nc, err := c.resolveNodeClassFromInstance(ctx, it)
if client.IgnoreNotFound(err) != nil {
return nil, fmt.Errorf("resolving nodeclass, %w", err)
}
nodeClaims = append(nodeClaims, c.instanceToNodeClaim(instance, instanceType, nc))
nodeClaims = append(nodeClaims, c.instanceToNodeClaim(it, instanceType, nc))
}
return nodeClaims, nil
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (c *CloudProvider) GetInstanceTypes(ctx context.Context, nodePool *karpv1.N
c.recorder.Publish(cloudproviderevents.NodePoolFailedToResolveNodeClass(nodePool))
return nil, nil
}
return nil, fmt.Errorf("resolving node class, %w", err)
return nil, fmt.Errorf("resolving nodeclass, %w", err)
}
// TODO, break this coupling
instanceTypes, err := c.instanceTypeProvider.List(ctx, nodeClass)
Expand All @@ -196,6 +196,20 @@ func (c *CloudProvider) GetInstanceTypes(ctx context.Context, nodePool *karpv1.N
return instanceTypes, nil
}

// getInstanceType returns a specific instance type to avoid re-constructing all InstanceTypes
func (c *CloudProvider) getInstanceType(ctx context.Context, nodePool *karpv1.NodePool, name ec2types.InstanceType) (*cloudprovider.InstanceType, error) {
nodeClass, err := c.resolveNodeClassFromNodePool(ctx, nodePool)
if err != nil {
if errors.IsNotFound(err) {
// If we can't resolve the NodeClass, then it's impossible for us to resolve the instance types
c.recorder.Publish(cloudproviderevents.NodePoolFailedToResolveNodeClass(nodePool))
return nil, nil
}
return nil, fmt.Errorf("resolving nodeclass, %w", err)
}
return c.instanceTypeProvider.Get(ctx, nodeClass, name)
}

func (c *CloudProvider) Delete(ctx context.Context, nodeClaim *karpv1.NodeClaim) error {
id, err := utils.ParseInstanceID(nodeClaim.Status.ProviderID)
if err != nil {
Expand Down Expand Up @@ -233,7 +247,7 @@ func (c *CloudProvider) IsDrifted(ctx context.Context, nodeClaim *karpv1.NodeCla
c.recorder.Publish(cloudproviderevents.NodePoolFailedToResolveNodeClass(nodePool))
return "", nil
}
return "", fmt.Errorf("resolving node class, %w", err)
return "", fmt.Errorf("resolving nodeclass, %w", err)
}
driftReason, err := c.isNodeClassDrifted(ctx, nodeClaim, nodePool, nodeClass)
if err != nil {
Expand Down Expand Up @@ -350,14 +364,11 @@ func (c *CloudProvider) resolveInstanceTypeFromInstance(ctx context.Context, ins
// If we can't resolve the NodePool, we fall back to not getting instance type info
return nil, client.IgnoreNotFound(fmt.Errorf("resolving nodepool, %w", err))
}
instanceTypes, err := c.GetInstanceTypes(ctx, nodePool)
instanceType, err := c.getInstanceType(ctx, nodePool, instance.Type)
if err != nil {
// If we can't resolve the NodePool, we fall back to not getting instance type info
return nil, client.IgnoreNotFound(fmt.Errorf("resolving nodeclass, %w", err))
return nil, client.IgnoreNotFound(fmt.Errorf("resolving instance type, %w", err))
}
instanceType, _ := lo.Find(instanceTypes, func(i *cloudprovider.InstanceType) bool {
return i.Name == string(instance.Type)
})
return instanceType, nil
}

Expand Down
6 changes: 4 additions & 2 deletions pkg/controllers/providers/instancetype/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ var _ = Describe("InstanceType", func() {
},
})
Expect(err).To(BeNil())
for i := range instanceTypes {
Expect(instanceTypes[i].Name).To(Equal(string(ec2InstanceTypes[i].InstanceType)))

Expect(instanceTypes).To(HaveLen(len(ec2InstanceTypes)))
for _, it := range instanceTypes {
Expect(lo.ContainsBy(ec2InstanceTypes, func(i ec2types.InstanceTypeInfo) bool { return string(i.InstanceType) == it.Name })).To(BeTrue())
}
})
It("should update instance type offering date with response from the DescribeInstanceTypesOfferings API", func() {
Expand Down
135 changes: 86 additions & 49 deletions pkg/providers/instancetype/instancetype.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
)

type Provider interface {
Get(context.Context, *v1.EC2NodeClass, ec2types.InstanceType) (*cloudprovider.InstanceType, error)
List(context.Context, *v1.EC2NodeClass) ([]*cloudprovider.InstanceType, error)
}

Expand All @@ -65,10 +66,10 @@ type DefaultProvider struct {

muInstanceTypesInfo sync.RWMutex
// TODO @engedaam: Look into only storing the needed EC2InstanceTypeInfo
instanceTypesInfo []ec2types.InstanceTypeInfo
instanceTypesInfo map[ec2types.InstanceType]ec2types.InstanceTypeInfo

muInstanceTypesOfferings sync.RWMutex
instanceTypesOfferings map[string]sets.Set[string]
instanceTypesOfferings map[ec2types.InstanceType]sets.Set[string]
allZones sets.Set[string]

instanceTypesCache *cache.Cache
Expand Down Expand Up @@ -96,8 +97,8 @@ func NewDefaultProvider(
return &DefaultProvider{
ec2api: ec2api,
subnetProvider: subnetProvider,
instanceTypesInfo: []ec2types.InstanceTypeInfo{},
instanceTypesOfferings: map[string]sets.Set[string]{},
instanceTypesInfo: map[ec2types.InstanceType]ec2types.InstanceTypeInfo{},
instanceTypesOfferings: map[ec2types.InstanceType]sets.Set[string]{},
instanceTypesResolver: instanceTypesResolver,
instanceTypesCache: instanceTypesCache,
discoveredCapacityCache: discoveredCapacityCache,
Expand Down Expand Up @@ -129,28 +130,21 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass)
return nil, fmt.Errorf("no subnets found")
}

subnetZones := sets.New(lo.Map(nodeClass.Status.Subnets, func(s v1.Subnet, _ int) string {
return lo.FromPtr(&s.Zone)
})...)

// Compute fully initialized instance types hash key
subnetZonesHash, _ := hashstructure.Hash(subnetZones, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
// Compute hash key against node class AMIs (used to force cache rebuild when AMIs change)
amiHash, _ := hashstructure.Hash(nodeClass.Status.AMIs, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
key := fmt.Sprintf("%d-%d-%016x-%016x-%016x",
p.instanceTypesSeqNum,
p.instanceTypesOfferingsSeqNum,
amiHash,
subnetZonesHash,
p.instanceTypesResolver.CacheKey(nodeClass),
)
key := p.cacheKey(nodeClass)
var instanceTypes []*cloudprovider.InstanceType
if item, ok := p.instanceTypesCache.Get(key); ok {
// Ensure what's returned from this function is a shallow-copy of the slice (not a deep-copy of the data itself)
// so that modifications to the ordering of the data don't affect the original
instanceTypes = item.([]*cloudprovider.InstanceType)
} else {
instanceTypes = p.resolveInstanceTypes(ctx, nodeClass, amiHash)
zonesToZoneIDs := nodeClass.ZoneIDMap()
instanceTypes = lo.FilterMapToSlice(p.instanceTypesInfo, func(name ec2types.InstanceType, info ec2types.InstanceTypeInfo) (*cloudprovider.InstanceType, bool) {
it, err := p.get(ctx, nodeClass, zonesToZoneIDs, name)
if err != nil {
return nil, false
}
return it, true
})
p.instanceTypesCache.SetDefault(key, instanceTypes)
}
// Offerings aren't cached along with the rest of the instance type info because reserved offerings need to have up to
Expand All @@ -165,30 +159,71 @@ func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass)
), nil
}

func (p *DefaultProvider) resolveInstanceTypes(
ctx context.Context,
nodeClass *v1.EC2NodeClass,
amiHash uint64,
) []*cloudprovider.InstanceType {
zonesToZoneIDs := lo.SliceToMap(nodeClass.Status.Subnets, func(s v1.Subnet) (string, string) {
return s.Zone, s.ZoneID
})
return lo.FilterMap(p.instanceTypesInfo, func(info ec2types.InstanceTypeInfo, _ int) (*cloudprovider.InstanceType, bool) {
it := p.instanceTypesResolver.Resolve(ctx, info, p.instanceTypesOfferings[string(info.InstanceType)].UnsortedList(), zonesToZoneIDs, nodeClass)
if it == nil {
return nil, false
}
if cached, ok := p.discoveredCapacityCache.Get(fmt.Sprintf("%s-%016x", it.Name, amiHash)); ok {
it.Capacity[corev1.ResourceMemory] = cached.(resource.Quantity)
}
InstanceTypeVCPU.Set(float64(lo.FromPtr(info.VCpuInfo.DefaultVCpus)), map[string]string{
instanceTypeLabel: string(info.InstanceType),
})
InstanceTypeMemory.Set(float64(lo.FromPtr(info.MemoryInfo.SizeInMiB)*1024*1024), map[string]string{
instanceTypeLabel: string(info.InstanceType),
func (p *DefaultProvider) Get(ctx context.Context, nodeClass *v1.EC2NodeClass, name ec2types.InstanceType) (*cloudprovider.InstanceType, error) {
p.muInstanceTypesInfo.RLock()
p.muInstanceTypesOfferings.RLock()
defer p.muInstanceTypesInfo.RUnlock()
defer p.muInstanceTypesOfferings.RUnlock()

if len(p.instanceTypesInfo) == 0 {
return nil, fmt.Errorf("no instance types found")
}
if len(p.instanceTypesOfferings) == 0 {
return nil, fmt.Errorf("no instance types offerings found")
}
if len(nodeClass.Status.Subnets) == 0 {
return nil, fmt.Errorf("no subnets found")
}
var instanceType *cloudprovider.InstanceType
if item, ok := p.instanceTypesCache.Get(p.cacheKey(nodeClass)); ok {
instanceType, _ = lo.Find(item.([]*cloudprovider.InstanceType), func(i *cloudprovider.InstanceType) bool {
return ec2types.InstanceType(i.Name) == name
})
return it, true
}
if instanceType == nil {
var err error
instanceType, err = p.get(ctx, nodeClass, nodeClass.ZoneIDMap(), name)
if err != nil {
return nil, err
}
}
return p.offeringProvider.InjectOfferings(ctx, []*cloudprovider.InstanceType{instanceType}, nodeClass, p.allZones)[0], nil
}

func (p *DefaultProvider) get(ctx context.Context, nodeClass *v1.EC2NodeClass, zoneIDMap map[string]string, name ec2types.InstanceType) (*cloudprovider.InstanceType, error) {
info, ok := p.instanceTypesInfo[name]
if !ok {
return nil, fmt.Errorf("instance type %s not found in cache", name)
}
it := p.instanceTypesResolver.Resolve(ctx, info, p.instanceTypesOfferings[info.InstanceType].UnsortedList(), zoneIDMap, nodeClass)
if it == nil {
return nil, fmt.Errorf("failed to generate instance type %s", name)
}
amiHash, _ := hashstructure.Hash(nodeClass.Status.AMIs, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
if cached, ok := p.discoveredCapacityCache.Get(fmt.Sprintf("%s-%016x", it.Name, amiHash)); ok {
it.Capacity[corev1.ResourceMemory] = cached.(resource.Quantity)
}
InstanceTypeVCPU.Set(float64(lo.FromPtr(info.VCpuInfo.DefaultVCpus)), map[string]string{
instanceTypeLabel: string(info.InstanceType),
})
InstanceTypeMemory.Set(float64(lo.FromPtr(info.MemoryInfo.SizeInMiB)*1024*1024), map[string]string{
instanceTypeLabel: string(info.InstanceType),
})
return it, nil
}

func (p *DefaultProvider) cacheKey(nodeClass *v1.EC2NodeClass) string {
// Compute fully initialized instance types hash key
subnetZonesHash, _ := hashstructure.Hash(nodeClass.Zones(), hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
// Compute hash key against node class AMIs (used to force cache rebuild when AMIs change)
amiHash, _ := hashstructure.Hash(nodeClass.Status.AMIs, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
return fmt.Sprintf("%d-%d-%016x-%016x-%016x",
p.instanceTypesSeqNum,
p.instanceTypesOfferingsSeqNum,
amiHash,
subnetZonesHash,
p.instanceTypesResolver.CacheKey(nodeClass),
)
}

func (p *DefaultProvider) UpdateInstanceTypes(ctx context.Context) error {
Expand Down Expand Up @@ -225,7 +260,9 @@ func (p *DefaultProvider) UpdateInstanceTypes(ctx context.Context) error {
atomic.AddUint64(&p.instanceTypesSeqNum, 1)
log.FromContext(ctx).WithValues("count", len(instanceTypes)).V(1).Info("discovered instance types")
}
p.instanceTypesInfo = instanceTypes
p.instanceTypesInfo = lo.SliceToMap(instanceTypes, func(i ec2types.InstanceTypeInfo) (ec2types.InstanceType, ec2types.InstanceTypeInfo) {
return i.InstanceType, i
})
return nil
}

Expand All @@ -239,7 +276,7 @@ func (p *DefaultProvider) UpdateInstanceTypeOfferings(ctx context.Context) error
defer p.muInstanceTypesOfferings.Unlock()

// Get offerings from EC2
instanceTypeOfferings := map[string]sets.Set[string]{}
instanceTypeOfferings := map[ec2types.InstanceType]sets.Set[string]{}

paginator := ec2.NewDescribeInstanceTypeOfferingsPaginator(p.ec2api, &ec2.DescribeInstanceTypeOfferingsInput{
LocationType: ec2types.LocationTypeAvailabilityZone,
Expand All @@ -252,10 +289,10 @@ func (p *DefaultProvider) UpdateInstanceTypeOfferings(ctx context.Context) error
}

for _, offering := range page.InstanceTypeOfferings {
if _, ok := instanceTypeOfferings[string(offering.InstanceType)]; !ok {
instanceTypeOfferings[string(offering.InstanceType)] = sets.New[string]()
if _, ok := instanceTypeOfferings[offering.InstanceType]; !ok {
instanceTypeOfferings[offering.InstanceType] = sets.New[string]()
}
instanceTypeOfferings[string(offering.InstanceType)].Insert(lo.FromPtr(offering.Location))
instanceTypeOfferings[offering.InstanceType].Insert(lo.FromPtr(offering.Location))
}
}

Expand Down Expand Up @@ -307,8 +344,8 @@ func (p *DefaultProvider) UpdateInstanceTypeCapacityFromNode(ctx context.Context
}

func (p *DefaultProvider) Reset() {
p.instanceTypesInfo = []ec2types.InstanceTypeInfo{}
p.instanceTypesOfferings = map[string]sets.Set[string]{}
p.instanceTypesInfo = map[ec2types.InstanceType]ec2types.InstanceTypeInfo{}
p.instanceTypesOfferings = map[ec2types.InstanceType]sets.Set[string]{}
p.instanceTypesCache.Flush()
p.discoveredCapacityCache.Flush()
}
Loading