diff --git a/cmd/infra/aws/iam.go b/cmd/infra/aws/iam.go index 2100dfa0f4bb..359156f35cc6 100644 --- a/cmd/infra/aws/iam.go +++ b/cmd/infra/aws/iam.go @@ -676,7 +676,8 @@ func controlPlaneOperatorPolicy(hostedZone string, sharedVPC bool) policyBinding "ec2:RevokeSecurityGroupIngress", "ec2:RevokeSecurityGroupEgress", "ec2:DescribeSecurityGroups", - "ec2:DescribeVpcs" + "ec2:DescribeVpcs", + "ec2:DescribeSubnets" ], "Resource": "*" } @@ -702,7 +703,8 @@ func controlPlaneOperatorPolicy(hostedZone string, sharedVPC bool) policyBinding "ec2:RevokeSecurityGroupIngress", "ec2:RevokeSecurityGroupEgress", "ec2:DescribeSecurityGroups", - "ec2:DescribeVpcs" + "ec2:DescribeVpcs", + "ec2:DescribeSubnets" ], "Resource": "*" }, @@ -776,7 +778,8 @@ func sharedVPCEndpointRole(controlPlaneRoleARN string) sharedVPCPolicyBinding { "ec2:RevokeSecurityGroupIngress", "ec2:RevokeSecurityGroupEgress", "ec2:DescribeSecurityGroups", - "ec2:DescribeVpcs" + "ec2:DescribeVpcs", + "ec2:DescribeSubnets" ], "Resource": "*" } diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go index ed9f5f373170..41996839aa3f 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sort" "strings" "sync" "time" @@ -202,6 +203,9 @@ type AWSEndpointServiceReconciler struct { client.Client upsert.CreateOrUpdateProvider awsClientBuilder awsClientProvider + + subnetAZMu sync.RWMutex + subnetAZCache map[string]string } // awsClientProvider abstracts AWS client creation for testability. @@ -562,6 +566,93 @@ func diffIDs(desired []string, existing []string) (added, removed []string) { return } +// deduplicateSubnetsByAZ ensures at most one subnet per AZ is passed to +// CreateVpcEndpoint/ModifyVpcEndpoint, since AWS rejects requests with +// multiple subnets in the same AZ. +func (r *AWSEndpointServiceReconciler) deduplicateSubnetsByAZ(ctx context.Context, ec2Client awsapi.EC2API, subnetIDs []string) ([]string, error) { + if len(subnetIDs) <= 1 { + return subnetIDs, nil + } + + // Read-only path: all subnets already cached, no AWS call needed. + r.subnetAZMu.RLock() + allCached := r.subnetAZCache != nil + if allCached { + for _, id := range subnetIDs { + if _, ok := r.subnetAZCache[id]; !ok { + allCached = false + break + } + } + } + if allCached { + azForSubnet := make(map[string]string, len(subnetIDs)) + for _, id := range subnetIDs { + azForSubnet[id] = r.subnetAZCache[id] + } + r.subnetAZMu.RUnlock() + return pickOneSubnetPerAZ(subnetIDs, azForSubnet), nil + } + r.subnetAZMu.RUnlock() + + // Write path: new subnets found, call DescribeSubnets to populate cache. + r.subnetAZMu.Lock() + if r.subnetAZCache == nil { + r.subnetAZCache = make(map[string]string) + } + + // Re-check which subnets are uncached — between the RUnlock and Lock above, + // another reconcile may have already fetched them. + var uncached []string + for _, id := range subnetIDs { + if _, ok := r.subnetAZCache[id]; !ok { + uncached = append(uncached, id) + } + } + + if len(uncached) > 0 { + output, err := ec2Client.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: uncached, + }) + if err != nil { + r.subnetAZMu.Unlock() + return nil, fmt.Errorf("failed to describe subnets for AZ deduplication: %w", err) + } + for _, subnet := range output.Subnets { + r.subnetAZCache[aws.ToString(subnet.SubnetId)] = aws.ToString(subnet.AvailabilityZone) + } + } + + azForSubnet := make(map[string]string, len(subnetIDs)) + for _, id := range subnetIDs { + azForSubnet[id] = r.subnetAZCache[id] + } + r.subnetAZMu.Unlock() + + return pickOneSubnetPerAZ(subnetIDs, azForSubnet), nil +} + +func pickOneSubnetPerAZ(subnetIDs []string, azForSubnet map[string]string) []string { + sorted := make([]string, len(subnetIDs)) + copy(sorted, subnetIDs) + sort.Strings(sorted) + + azToSubnet := make(map[string]string) + for _, id := range sorted { + az := azForSubnet[id] + if _, exists := azToSubnet[az]; !exists { + azToSubnet[az] = id + } + } + + deduped := make([]string, 0, len(azToSubnet)) + for _, id := range azToSubnet { + deduped = append(deduped, id) + } + sort.Strings(deduped) + return deduped +} + func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.Context, awsEndpointService *hyperv1.AWSEndpointService, hcp *hyperv1.HostedControlPlane, ec2Client awsapi.EC2API, route53Client awsapi.ROUTE53API) error { log, err := logr.FromContext(ctx) if err != nil { @@ -577,6 +668,13 @@ func (r *AWSEndpointServiceReconciler) reconcileAWSEndpointService(ctx context.C return err } + deduped, err := r.deduplicateSubnetsByAZ(ctx, ec2Client, awsEndpointService.Spec.SubnetIDs) + if err != nil { + log.Error(err, "failed to deduplicate subnets by AZ, proceeding with original list") + } else { + awsEndpointService.Spec.SubnetIDs = deduped + } + endpointID := awsEndpointService.Status.EndpointID var endpointDNSEntries []ec2types.DnsEntry if endpointID != "" { diff --git a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go index 971655e63b20..56ddff942fed 100644 --- a/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go +++ b/control-plane-operator/controllers/awsprivatelink/awsprivatelink_controller_test.go @@ -107,6 +107,120 @@ func Test_diffIDs(t *testing.T) { } } +func Test_deduplicateSubnetsByAZ(t *testing.T) { + tests := []struct { + name string + subnetIDs []string + cachedAZs map[string]string + describeOutput *ec2v2.DescribeSubnetsOutput + describeErr error + expectDescribeCall bool + expectedSubnets []string + expectErr bool + }{ + { + name: "When empty subnet list it should return empty", + subnetIDs: []string{}, + expectedSubnets: []string{}, + }, + { + name: "When single subnet it should pass through without DescribeSubnets call", + subnetIDs: []string{"subnet-aaa"}, + expectedSubnets: []string{"subnet-aaa"}, + }, + { + name: "When subnets in different AZs it should keep all", + subnetIDs: []string{"subnet-aaa", "subnet-bbb"}, + describeOutput: &ec2v2.DescribeSubnetsOutput{ + Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")}, + {SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1b")}, + }, + }, + expectDescribeCall: true, + expectedSubnets: []string{"subnet-aaa", "subnet-bbb"}, + }, + { + name: "When multiple subnets in same AZ it should keep lexicographically first", + subnetIDs: []string{"subnet-bbb", "subnet-aaa"}, + describeOutput: &ec2v2.DescribeSubnetsOutput{ + Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1a")}, + {SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")}, + }, + }, + expectDescribeCall: true, + expectedSubnets: []string{"subnet-aaa"}, + }, + { + name: "When mixed same-AZ and different-AZ subnets it should dedup correctly", + subnetIDs: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"}, + describeOutput: &ec2v2.DescribeSubnetsOutput{ + Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("subnet-aaa"), AvailabilityZone: aws.String("us-east-1a")}, + {SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1a")}, + {SubnetId: aws.String("subnet-ccc"), AvailabilityZone: aws.String("us-east-1b")}, + }, + }, + expectDescribeCall: true, + expectedSubnets: []string{"subnet-aaa", "subnet-ccc"}, + }, + { + name: "When cache covers all subnets it should not call DescribeSubnets", + subnetIDs: []string{"subnet-aaa", "subnet-bbb"}, + cachedAZs: map[string]string{"subnet-aaa": "us-east-1a", "subnet-bbb": "us-east-1b"}, + expectDescribeCall: false, + expectedSubnets: []string{"subnet-aaa", "subnet-bbb"}, + }, + { + name: "When cache covers some subnets it should call DescribeSubnets only for new ones", + subnetIDs: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"}, + cachedAZs: map[string]string{"subnet-aaa": "us-east-1a"}, + describeOutput: &ec2v2.DescribeSubnetsOutput{ + Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("subnet-bbb"), AvailabilityZone: aws.String("us-east-1b")}, + {SubnetId: aws.String("subnet-ccc"), AvailabilityZone: aws.String("us-east-1c")}, + }, + }, + expectDescribeCall: true, + expectedSubnets: []string{"subnet-aaa", "subnet-bbb", "subnet-ccc"}, + }, + { + name: "When DescribeSubnets fails it should return error", + subnetIDs: []string{"subnet-aaa", "subnet-bbb"}, + describeErr: fmt.Errorf("access denied"), + expectDescribeCall: true, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + mockCtrl := gomock.NewController(t) + mockEC2 := awsapi.NewMockEC2API(mockCtrl) + + if tt.expectDescribeCall { + mockEC2.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(tt.describeOutput, tt.describeErr) + } + + r := &AWSEndpointServiceReconciler{ + subnetAZCache: tt.cachedAZs, + } + + gotSubnets, err := r.deduplicateSubnetsByAZ(context.Background(), mockEC2, tt.subnetIDs) + + if tt.expectErr { + g.Expect(err).To(HaveOccurred()) + return + } + + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(gotSubnets).To(Equal(tt.expectedSubnets)) + }) + } +} + func TestRecordForService(t *testing.T) { testCases := []struct { name string