Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cmd/infra/aws/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,8 @@ func controlPlaneOperatorPolicy(hostedZone string, sharedVPC bool) policyBinding
"ec2:RevokeSecurityGroupIngress",
"ec2:RevokeSecurityGroupEgress",
"ec2:DescribeSecurityGroups",
"ec2:DescribeVpcs"
"ec2:DescribeVpcs",
"ec2:DescribeSubnets"
],
"Resource": "*"
}
Expand All @@ -702,7 +703,8 @@ func controlPlaneOperatorPolicy(hostedZone string, sharedVPC bool) policyBinding
"ec2:RevokeSecurityGroupIngress",
"ec2:RevokeSecurityGroupEgress",
"ec2:DescribeSecurityGroups",
"ec2:DescribeVpcs"
"ec2:DescribeVpcs",
"ec2:DescribeSubnets"
],
"Resource": "*"
},
Expand Down Expand Up @@ -776,7 +778,8 @@ func sharedVPCEndpointRole(controlPlaneRoleARN string) sharedVPCPolicyBinding {
"ec2:RevokeSecurityGroupIngress",
"ec2:RevokeSecurityGroupEgress",
"ec2:DescribeSecurityGroups",
"ec2:DescribeVpcs"
"ec2:DescribeVpcs",
"ec2:DescribeSubnets"
],
"Resource": "*"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sort"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -202,6 +203,9 @@ type AWSEndpointServiceReconciler struct {
client.Client
upsert.CreateOrUpdateProvider
awsClientBuilder awsClientProvider

subnetAZMu sync.RWMutex
subnetAZCache map[string]string
}

// awsClientProvider abstracts AWS client creation for testability.
Expand Down Expand Up @@ -562,6 +566,93 @@ func diffIDs(desired []string, existing []string) (added, removed []string) {
return
}

// deduplicateSubnetsByAZ ensures at most one subnet per AZ is passed to
// CreateVpcEndpoint/ModifyVpcEndpoint, since AWS rejects requests with
// multiple subnets in the same AZ.
func (r *AWSEndpointServiceReconciler) deduplicateSubnetsByAZ(ctx context.Context, ec2Client awsapi.EC2API, subnetIDs []string) ([]string, error) {
if len(subnetIDs) <= 1 {
return subnetIDs, nil
}

// Read-only path: all subnets already cached, no AWS call needed.
r.subnetAZMu.RLock()
allCached := r.subnetAZCache != nil
if allCached {
for _, id := range subnetIDs {
if _, ok := r.subnetAZCache[id]; !ok {
allCached = false
break
}
}
}
if allCached {
azForSubnet := make(map[string]string, len(subnetIDs))
for _, id := range subnetIDs {
azForSubnet[id] = r.subnetAZCache[id]
}
r.subnetAZMu.RUnlock()
return pickOneSubnetPerAZ(subnetIDs, azForSubnet), nil
}
r.subnetAZMu.RUnlock()

// Write path: new subnets found, call DescribeSubnets to populate cache.
r.subnetAZMu.Lock()
if r.subnetAZCache == nil {
r.subnetAZCache = make(map[string]string)
}

// Re-check which subnets are uncached — between the RUnlock and Lock above,
// another reconcile may have already fetched them.
var uncached []string
for _, id := range subnetIDs {
if _, ok := r.subnetAZCache[id]; !ok {
uncached = append(uncached, id)
}
}

if len(uncached) > 0 {
output, err := ec2Client.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
SubnetIds: uncached,
})
if err != nil {
r.subnetAZMu.Unlock()
return nil, fmt.Errorf("failed to describe subnets for AZ deduplication: %w", err)
}
for _, subnet := range output.Subnets {
r.subnetAZCache[aws.ToString(subnet.SubnetId)] = aws.ToString(subnet.AvailabilityZone)
}
}

azForSubnet := make(map[string]string, len(subnetIDs))
for _, id := range subnetIDs {
azForSubnet[id] = r.subnetAZCache[id]
}
r.subnetAZMu.Unlock()

return pickOneSubnetPerAZ(subnetIDs, azForSubnet), nil
}

func pickOneSubnetPerAZ(subnetIDs []string, azForSubnet map[string]string) []string {
sorted := make([]string, len(subnetIDs))
copy(sorted, subnetIDs)
sort.Strings(sorted)

azToSubnet := make(map[string]string)
for _, id := range sorted {
az := azForSubnet[id]
if _, exists := azToSubnet[az]; !exists {
azToSubnet[az] = id
}
}

deduped := make([]string, 0, len(azToSubnet))
for _, id := range azToSubnet {
deduped = append(deduped, id)
}
sort.Strings(deduped)
return deduped
}

func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, ec2Client awsapi.EC2API, route53Client awsapi.ROUTE53API) error {
log, err := logr.FromContext(ctx)
if err != nil {
Expand All @@ -577,6 +668,13 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C
return err
}

deduped, err := r.deduplicateSubnetsByAZ(ctx, ec2Client, awsEndpointService.Spec.SubnetIDs)
if err != nil {
log.Error(err, "failed to deduplicate subnets by AZ, proceeding with original list")
} else {
awsEndpointService.Spec.SubnetIDs = deduped
}

endpointID := awsEndpointService.Status.EndpointID
var endpointDNSEntries []ec2types.DnsEntry
if endpointID != "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,120 @@ func Test_diffIDs(t *testing.T) {
}
}

func Test_deduplicateSubnetsByAZ(t *testing.T) {
tests := []struct {
name string
subnetIDs []string
cachedAZs map[string]string
describeOutput *ec2v2.DescribeSubnetsOutput
describeErr error
expectDescribeCall bool
expectedSubnets []string
expectErr bool
}{
{
name: "When empty subnet list it should return empty",
subnetIDs: []string{},
expectedSubnets: []string{},
},
{
name: "When single subnet it should pass through without DescribeSubnets call",
subnetIDs: []string{"subnet-aaa"},
expectedSubnets: []string{"subnet-aaa"},
},
{
name: "When subnets in different AZs it should keep all",
subnetIDs: []string{"subnet-aaa", "subnet-bbb"},
describeOutput: &ec2v2.DescribeSubnetsOutput{
Subnets: []ec2types.Subnet{
{SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")},
{SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1b")},
},
},
expectDescribeCall: true,
expectedSubnets: []string{"subnet-aaa", "subnet-bbb"},
},
{
name: "When multiple subnets in same AZ it should keep lexicographically first",
subnetIDs: []string{"subnet-bbb", "subnet-aaa"},
describeOutput: &ec2v2.DescribeSubnetsOutput{
Subnets: []ec2types.Subnet{
{SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1a")},
{SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")},
},
},
expectDescribeCall: true,
expectedSubnets: []string{"subnet-aaa"},
},
{
name: "When mixed same-AZ and different-AZ subnets it should dedup correctly",
subnetIDs: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"},
describeOutput: &ec2v2.DescribeSubnetsOutput{
Subnets: []ec2types.Subnet{
{SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")},
{SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1a")},
{SubnetId: aws.String("subnet-ccc"), AvailabilityZone: aws.String("us-east-1b")},
},
},
expectDescribeCall: true,
expectedSubnets: []string{"subnet-aaa", "subnet-ccc"},
},
{
name: "When cache covers all subnets it should not call DescribeSubnets",
subnetIDs: []string{"subnet-aaa", "subnet-bbb"},
cachedAZs: map[string]string{"subnet-aaa": "us-east-1a", "subnet-bbb": "us-east-1b"},
expectDescribeCall: false,
expectedSubnets: []string{"subnet-aaa", "subnet-bbb"},
},
{
name: "When cache covers some subnets it should call DescribeSubnets only for new ones",
subnetIDs: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"},
cachedAZs: map[string]string{"subnet-aaa": "us-east-1a"},
describeOutput: &ec2v2.DescribeSubnetsOutput{
Subnets: []ec2types.Subnet{
{SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1b")},
{SubnetId: aws.String("subnet-ccc"), AvailabilityZone: aws.String("us-east-1c")},
},
},
expectDescribeCall: true,
expectedSubnets: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"},
},
{
name: "When DescribeSubnets fails it should return error",
subnetIDs: []string{"subnet-aaa", "subnet-bbb"},
describeErr: fmt.Errorf("access denied"),
expectDescribeCall: true,
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := NewGomegaWithT(t)
mockCtrl := gomock.NewController(t)
mockEC2 := awsapi.NewMockEC2API(mockCtrl)

if tt.expectDescribeCall {
mockEC2.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(tt.describeOutput, tt.describeErr)
}

r := &AWSEndpointServiceReconciler{
subnetAZCache: tt.cachedAZs,
}

gotSubnets, err := r.deduplicateSubnetsByAZ(context.Background(), mockEC2, tt.subnetIDs)

if tt.expectErr {
g.Expect(err).To(HaveOccurred())
return
}

g.Expect(err).ToNot(HaveOccurred())
g.Expect(gotSubnets).To(Equal(tt.expectedSubnets))
})
}
}

func TestRecordForService(t *testing.T) {
testCases := []struct {
name string
Expand Down
Loading