diff --git a/go.mod b/go.mod index a3c6c2acb..37c69a3ef 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,8 @@ require ( github.com/Azure/karpenter-provider-azure v1.5.1 github.com/crossplane/crossplane-runtime/v2 v2.1.0 github.com/evanphx/json-patch/v5 v5.9.11 - github.com/gofrs/uuid v4.4.0+incompatible github.com/go-logr/logr v1.4.3 + github.com/gofrs/uuid v4.4.0+incompatible github.com/google/go-cmp v0.7.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 github.com/onsi/ginkgo/v2 v2.23.4 diff --git a/pkg/clients/azure/compute/vmsizerecommenderclient.go b/pkg/clients/azure/compute/vmsizerecommenderclient.go index 0c9db0904..5c6fb2822 100644 --- a/pkg/clients/azure/compute/vmsizerecommenderclient.go +++ b/pkg/clients/azure/compute/vmsizerecommenderclient.go @@ -16,7 +16,6 @@ import ( "os" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/gofrs/uuid" "google.golang.org/protobuf/encoding/protojson" "k8s.io/klog/v2" @@ -24,6 +23,7 @@ import ( computev1 "go.goms.io/fleet/apis/protos/azure/compute/v1" "go.goms.io/fleet/pkg/clients/httputil" "go.goms.io/fleet/pkg/utils/controller" + fleetErrors "go.goms.io/fleet/pkg/utils/errors" ) const ( @@ -137,9 +137,14 @@ func (c *AttributeBasedVMSizeRecommenderClient) GenerateAttributeBasedRecommenda return nil, fmt.Errorf("failed to read response body: %w", err) } - // Check status code + // Check status code - categorize based on transient vs non-transient errors if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("request failed with status %d: %w", resp.StatusCode, runtime.NewResponseError(resp)) + desc := fmt.Sprintf("request failed with status %d: %s %s", resp.StatusCode, httpReq.Method, url) + if httputil.IsTransientStatusCode(resp.StatusCode) { + return nil, fleetErrors.NewTransientError(nil, desc) + } + // Non-transient errors (4xx client errors) should not be retried. + return nil, fleetErrors.NewUnexpectedError(nil, desc) } // Unmarshal response using protojson for proper proto3 support @@ -147,7 +152,7 @@ func (c *AttributeBasedVMSizeRecommenderClient) GenerateAttributeBasedRecommenda unmarshaler := protojson.UnmarshalOptions{ DiscardUnknown: true, } - if err := unmarshaler.Unmarshal(respBody, response); err != nil { + if err = unmarshaler.Unmarshal(respBody, response); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %w", err) } diff --git a/pkg/clients/azure/compute/vmsizerecommenderclient_test.go b/pkg/clients/azure/compute/vmsizerecommenderclient_test.go index bfb36c60a..00933938f 100644 --- a/pkg/clients/azure/compute/vmsizerecommenderclient_test.go +++ b/pkg/clients/azure/compute/vmsizerecommenderclient_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/proto" computev1 "go.goms.io/fleet/apis/protos/azure/compute/v1" + fleetErrors "go.goms.io/fleet/pkg/utils/errors" "go.goms.io/fleet/test/utils/azure/compute" ) @@ -81,13 +82,14 @@ func TestNewAttributeBasedVMSizeRecommenderClient(t *testing.T) { func TestClient_GenerateAttributeBasedRecommendations(t *testing.T) { tests := []struct { - name string - request *computev1.GenerateAttributeBasedRecommendationsRequest - mockStatusCode int - mockResponse string - wantResponse *computev1.GenerateAttributeBasedRecommendationsResponse - wantErr bool - wantErrMsg string + name string + request *computev1.GenerateAttributeBasedRecommendationsRequest + mockStatusCode int + mockResponse string + wantResponse *computev1.GenerateAttributeBasedRecommendationsResponse + wantErr bool + wantErrMsg string + wantIsTransient bool // true if error should be a transient HTTPError }{ { name: "successful request with regular priority profile", @@ -169,7 +171,7 @@ func TestClient_GenerateAttributeBasedRecommendations(t *testing.T) { wantErrMsg: "either regular priority profile or spot priority profile must be provided", }, { - name: "HTTP 400 error", + name: "HTTP 400 error is not transient", request: &computev1.GenerateAttributeBasedRecommendationsRequest{ SubscriptionId: "sub-123", Location: "eastus", @@ -177,13 +179,14 @@ func TestClient_GenerateAttributeBasedRecommendations(t *testing.T) { TargetCapacity: 5, }, }, - mockStatusCode: http.StatusBadRequest, - mockResponse: `{"error":"invalid request"}`, - wantErr: true, - wantErrMsg: "request failed with status 400", + mockStatusCode: http.StatusBadRequest, + mockResponse: `{"error":"invalid request"}`, + wantErr: true, + wantErrMsg: "request failed with status 400: POST", + wantIsTransient: false, // 400 is NOT transient }, { - name: "HTTP 500 error", + name: "HTTP 500 error is transient", request: &computev1.GenerateAttributeBasedRecommendationsRequest{ SubscriptionId: "sub-123", Location: "eastus", @@ -191,10 +194,41 @@ func TestClient_GenerateAttributeBasedRecommendations(t *testing.T) { TargetCapacity: 5, }, }, - mockStatusCode: http.StatusInternalServerError, - mockResponse: `{"error":"internal server error"}`, - wantErr: true, - wantErrMsg: "request failed with status 500", + mockStatusCode: http.StatusInternalServerError, + mockResponse: `{"error":"internal server error"}`, + wantErr: true, + wantErrMsg: "request failed with status 500: POST", + wantIsTransient: true, // 500 IS transient + }, + { + name: "HTTP 503 error is transient", + request: &computev1.GenerateAttributeBasedRecommendationsRequest{ + SubscriptionId: "sub-123", + Location: "eastus", + RegularPriorityProfile: &computev1.RegularPriorityProfile{ + TargetCapacity: 5, + }, + }, + mockStatusCode: http.StatusServiceUnavailable, + mockResponse: `{"error":"service unavailable"}`, + wantErr: true, + wantErrMsg: "request failed with status 503: POST", + wantIsTransient: true, // 503 IS transient + }, + { + name: "HTTP 429 error is transient", + request: &computev1.GenerateAttributeBasedRecommendationsRequest{ + SubscriptionId: "sub-123", + Location: "eastus", + RegularPriorityProfile: &computev1.RegularPriorityProfile{ + TargetCapacity: 5, + }, + }, + mockStatusCode: http.StatusTooManyRequests, + mockResponse: `{"error":"too many requests"}`, + wantErr: true, + wantErrMsg: "request failed with status 429: POST", + wantIsTransient: true, // 429 IS transient }, { name: "invalid JSON response", @@ -240,6 +274,20 @@ func TestClient_GenerateAttributeBasedRecommendations(t *testing.T) { return } + // Check if error has retry policy using fleetErrors.IsRetryable. + if tt.wantErr && tt.wantIsTransient { + isRetryable, hasRetryPolicy := fleetErrors.IsRetryable(err) + if !hasRetryPolicy || !isRetryable { + t.Errorf("GenerateAttributeBasedRecommendations() error = %v, want retryable error", err) + } + } + if tt.wantErr && !tt.wantIsTransient && tt.mockStatusCode >= 400 && tt.mockStatusCode < 500 { + isRetryable, hasRetryPolicy := fleetErrors.IsRetryable(err) + if hasRetryPolicy && isRetryable { + t.Errorf("GenerateAttributeBasedRecommendations() error = %v, should NOT be retryable for 4xx errors", err) + } + } + // Compare response. if !proto.Equal(tt.wantResponse, got) { t.Errorf("GenerateAttributeBasedRecommendations() = %+v, want %+v", got, tt.wantResponse) diff --git a/pkg/clients/httputil/httputil.go b/pkg/clients/httputil/httputil.go index 7ccbe84f6..3af8faed4 100644 --- a/pkg/clients/httputil/httputil.go +++ b/pkg/clients/httputil/httputil.go @@ -41,3 +41,19 @@ var ( // DefaultClientForAzure is the default HTTP client to access Azure services. DefaultClientForAzure = &http.Client{Timeout: HTTPTimeoutAzure} ) + +// transientHTTPStatusCodes defines HTTP status codes that indicate transient errors +// which may succeed on retry. +var transientHTTPStatusCodes = map[int]bool{ + http.StatusTooManyRequests: true, // 429 - Rate limiting + http.StatusInternalServerError: true, // 500 - Server error + http.StatusBadGateway: true, // 502 - Bad gateway + http.StatusServiceUnavailable: true, // 503 - Service unavailable + http.StatusGatewayTimeout: true, // 504 - Gateway timeout +} + +// IsTransientStatusCode returns true if the HTTP status code indicates a transient +// error that may succeed on retry (429, 5xx). +func IsTransientStatusCode(statusCode int) bool { + return transientHTTPStatusCodes[statusCode] +} diff --git a/pkg/clients/httputil/httputil_test.go b/pkg/clients/httputil/httputil_test.go new file mode 100644 index 000000000..aaf6affcf --- /dev/null +++ b/pkg/clients/httputil/httputil_test.go @@ -0,0 +1,37 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +*/ + +package httputil + +import ( + "net/http" + "testing" +) + +func TestIsTransientStatusCode(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"400 Bad Request", http.StatusBadRequest, false}, + {"401 Unauthorized", http.StatusUnauthorized, false}, + {"403 Forbidden", http.StatusForbidden, false}, + {"404 Not Found", http.StatusNotFound, false}, + {"429 Too Many Requests", http.StatusTooManyRequests, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + {"502 Bad Gateway", http.StatusBadGateway, true}, + {"503 Service Unavailable", http.StatusServiceUnavailable, true}, + {"504 Gateway Timeout", http.StatusGatewayTimeout, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsTransientStatusCode(tt.statusCode); got != tt.want { + t.Errorf("IsTransientStatusCode(%d) = %v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} diff --git a/pkg/propertychecker/azure/checker.go b/pkg/propertychecker/azure/checker.go index c031beee6..3aad55c3b 100644 --- a/pkg/propertychecker/azure/checker.go +++ b/pkg/propertychecker/azure/checker.go @@ -18,6 +18,7 @@ import ( placementv1beta1 "go.goms.io/fleet/apis/placement/v1beta1" computev1 "go.goms.io/fleet/apis/protos/azure/compute/v1" "go.goms.io/fleet/pkg/clients/azure/compute" + fleetErrors "go.goms.io/fleet/pkg/utils/errors" "go.goms.io/fleet/pkg/utils/labels" ) @@ -58,6 +59,10 @@ func NewPropertyChecker(vmSizeRecommenderClient compute.AttributeBasedVMSizeReco // // The cluster must have both Azure location and subscription ID labels configured. // Returns true if the SKU capacity requirement can be met, false otherwise. +// +// Errors returned implement the RetryableError interface to indicate whether the operation +// can be retried. Configuration errors (missing labels, invalid capacity) are non-retryable, +// while Azure API errors preserve the retryability of the underlying HTTP error. func (s *PropertyChecker) CheckIfMeetSKUCapacityRequirement( cluster *clusterv1beta1.MemberCluster, req placementv1beta1.PropertySelectorRequirement, @@ -65,18 +70,24 @@ func (s *PropertyChecker) CheckIfMeetSKUCapacityRequirement( ) (bool, error) { location, err := labels.ExtractLabelFromMemberCluster(cluster, labels.AzureLocationLabel) if err != nil { - return false, fmt.Errorf("failed to extract Azure location label from cluster %s: %w", cluster.Name, err) + // Missing label is a configuration error; not retryable. + return false, fleetErrors.NewUserError(err, + fmt.Sprintf("failed to extract Azure location label from cluster %s", cluster.Name)) } subID, err := labels.ExtractLabelFromMemberCluster(cluster, labels.AzureSubscriptionIDLabel) if err != nil { - return false, fmt.Errorf("failed to extract Azure subscription ID label from cluster %s: %w", cluster.Name, err) + // Missing label is a configuration error; not retryable. + return false, fleetErrors.NewUserError(err, + fmt.Sprintf("failed to extract Azure subscription ID label from cluster %s", cluster.Name)) } // Extract capacity requirements from the property selector requirement. capacity, err := extractCapacityRequirements(req) if err != nil { - return false, fmt.Errorf("failed to extract capacity requirements from property selector requirement: %w", err) + // Invalid capacity specification is a user error; not retryable. + return false, fleetErrors.NewUserError(err, + "failed to extract capacity requirements from property selector requirement") } // Request VM size recommendations to validate SKU availability and capacity. @@ -101,7 +112,11 @@ func (s *PropertyChecker) CheckIfMeetSKUCapacityRequirement( respObj, err := s.vmSizeRecommenderClient.GenerateAttributeBasedRecommendations(context.Background(), request) if err != nil { - return false, fmt.Errorf("failed to generate VM size recommendations from Azure: %w", err) + // Wrap the error with context. The underlying error already has the appropriate + // category set (transient for 429/5xx, API server error for other failures), + // so fleetErrors.IsRetryable() will detect it in the error chain. + return false, fleetErrors.Wraps(err, + fmt.Sprintf("failed to generate VM size recommendations from Azure for SKU %s in cluster %s", sku, cluster.Name)) } // This check is a defense mechanism; vmSizeRecommenderClient should return a VM size recommendation diff --git a/pkg/scheduler/framework/framework.go b/pkg/scheduler/framework/framework.go index d0018214e..58484b44e 100644 --- a/pkg/scheduler/framework/framework.go +++ b/pkg/scheduler/framework/framework.go @@ -27,6 +27,7 @@ import ( "time" "golang.org/x/sync/errgroup" + corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -45,6 +46,7 @@ import ( "go.goms.io/fleet/pkg/utils/annotations" "go.goms.io/fleet/pkg/utils/condition" "go.goms.io/fleet/pkg/utils/controller" + fleetErrors "go.goms.io/fleet/pkg/utils/errors" "go.goms.io/fleet/pkg/utils/parallelizer" ) @@ -57,6 +59,10 @@ const ( // NotFullyScheduledReason is the reason string of placement condition when the placement policy cannot be fully satisfied. NotFullyScheduledReason = "SchedulingPolicyUnfulfilled" + // Event reasons for scheduling errors. + // SchedulingErrorReason is used when the scheduler encounters an error during scheduling. + SchedulingErrorReason = "SchedulingError" + fullyScheduledMessage = "found all cluster needed as specified by the scheduling policy, found %d cluster(s)" notFullyScheduledMessage = "could not find all clusters needed as specified by the scheduling policy, found %d cluster(s) instead" @@ -539,6 +545,16 @@ func (f *framework) runAllPluginsForPickAllPlacementType( passed, filtered, err := f.runFilterPlugins(ctx, state, policy, clusters) if err != nil { klog.ErrorS(err, "Failed to run filter plugins", "policySnapshot", policyRef) + // Emit an event to inform the user about the scheduling error. + f.eventRecorder.Event(policy, corev1.EventTypeWarning, SchedulingErrorReason, + fmt.Sprintf("Failed to run filter plugins: %v", err)) + // Check if the error has a retry policy configured. + // If the error (or any error in its chain) implements ErrorWithRetryPolicy and indicates + // it's retryable, return it as-is so the scheduler can requeue. + // Otherwise, wrap it as unexpected behavior. + if isRetryable, hasRetryPolicy := fleetErrors.IsRetryable(err); hasRetryPolicy && isRetryable { + return nil, nil, err + } return nil, nil, controller.NewUnexpectedBehaviorError(err) } @@ -1159,6 +1175,16 @@ func (f *framework) runAllPluginsForPickNPlacementType( passed, filtered, err := f.runFilterPlugins(ctx, state, policy, clusters) if err != nil { klog.ErrorS(err, "Failed to run filter plugins", "policySnapshot", policyRef) + // Emit an event to inform the user about the scheduling error. + f.eventRecorder.Event(policy, corev1.EventTypeWarning, SchedulingErrorReason, + fmt.Sprintf("Failed to run filter plugins: %v", err)) + // Check if the error has a retry policy configured. + // If the error (or any error in its chain) implements ErrorWithRetryPolicy and indicates + // it's retryable, return it as-is so the scheduler can requeue. + // Otherwise, wrap it as unexpected behavior. + if isRetryable, hasRetryPolicy := fleetErrors.IsRetryable(err); hasRetryPolicy && isRetryable { + return nil, nil, err + } return nil, nil, controller.NewUnexpectedBehaviorError(err) } diff --git a/pkg/scheduler/framework/framework_test.go b/pkg/scheduler/framework/framework_test.go index c7124916c..a8cfd6fd0 100644 --- a/pkg/scheduler/framework/framework_test.go +++ b/pkg/scheduler/framework/framework_test.go @@ -36,6 +36,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -1323,8 +1324,9 @@ func TestRunAllPluginsForPickAllPlacementType(t *testing.T) { profile.WithFilterPlugin(p) } f := &framework{ - profile: profile, - parallelizer: parallelizer.NewParallelizer(parallelizer.DefaultNumOfWorkers), + profile: profile, + parallelizer: parallelizer.NewParallelizer(parallelizer.DefaultNumOfWorkers), + eventRecorder: record.NewFakeRecorder(10), } ctx := context.Background() @@ -6253,8 +6255,9 @@ func TestRunAllPluginsForPickNPlacementType(t *testing.T) { } f := &framework{ - profile: profile, - parallelizer: parallelizer.NewParallelizer(parallelizer.DefaultNumOfWorkers), + profile: profile, + parallelizer: parallelizer.NewParallelizer(parallelizer.DefaultNumOfWorkers), + eventRecorder: record.NewFakeRecorder(10), } ctx := context.Background() diff --git a/pkg/scheduler/framework/status.go b/pkg/scheduler/framework/status.go index 2b0a67a1f..d5013830d 100644 --- a/pkg/scheduler/framework/status.go +++ b/pkg/scheduler/framework/status.go @@ -140,13 +140,13 @@ func (s *Status) String() string { return strings.Join(desc, ", ") } -// AsError returns a status as an error; it returns nil if the status is of the internalError code. +// AsError returns a status as an error; it returns nil if the status is not of the internalError code. func (s *Status) AsError() error { if !s.IsInteralError() { return nil } - return fmt.Errorf("plugin %s returned an error %s", s.sourcePlugin, s.String()) + return fmt.Errorf("plugin %s returned an error %w", s.sourcePlugin, s.err) } // NewNonErrorStatus returns a Status with a non-error status code. diff --git a/pkg/scheduler/framework/status_test.go b/pkg/scheduler/framework/status_test.go index 896bcba3d..de52abe57 100644 --- a/pkg/scheduler/framework/status_test.go +++ b/pkg/scheduler/framework/status_test.go @@ -17,6 +17,7 @@ limitations under the License. package framework import ( + "errors" "fmt" "strings" "testing" @@ -157,3 +158,79 @@ func TestNilStatusMethods(t *testing.T) { t.Fatalf("String() = %s, want %s", status.String(), wantDesc) } } + +// customError is a custom error type for testing error chain preservation. +type customError struct { + code int +} + +func (e *customError) Error() string { + return fmt.Sprintf("custom error with code %d", e.code) +} + +func TestAsError(t *testing.T) { + testCases := []struct { + name string + status *Status + wantNil bool + wantCode int + }{ + { + name: "nil status returns nil", + status: nil, + wantNil: true, + }, + { + name: "success status returns nil", + status: NewNonErrorStatus(Success, dummyPlugin, "reason1"), + wantNil: true, + }, + { + name: "unschedulable status returns nil", + status: NewNonErrorStatus(ClusterUnschedulable, dummyPlugin, "reason1"), + wantNil: true, + }, + { + name: "internal error preserves error chain", + status: FromError(&customError{code: 503}, dummyPlugin, "reason1", "reason2"), + wantNil: false, + wantCode: 503, + }, + { + name: "internal error with wrapped error preserves full chain", + status: FromError(fmt.Errorf("outer error: %w", &customError{code: 429}), dummyPlugin, "rate limited"), + wantNil: false, + wantCode: 429, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.status.AsError() + + if tc.wantNil { + if err != nil { + t.Fatalf("AsError() = %v, want nil", err) + } + return + } + + if err == nil { + t.Fatalf("AsError() = nil, want non-nil error") + } + + // Verify error message contains plugin name + if !strings.Contains(err.Error(), dummyPlugin) { + t.Errorf("AsError().Error() = %q, want it to contain plugin name %q", err.Error(), dummyPlugin) + } + + // Verify the error chain is preserved using errors.As + var customErr *customError + if !errors.As(err, &customErr) { + t.Errorf("AsError() error chain broken: errors.As() could not find *customError in chain") + } else if customErr.code != tc.wantCode { + t.Errorf("AsError() error chain: customError.code = %d, want %d", customErr.code, tc.wantCode) + } + }) + } +} diff --git a/pkg/utils/errors/errors.go b/pkg/utils/errors/errors.go index 3f04059e5..16236df86 100644 --- a/pkg/utils/errors/errors.go +++ b/pkg/utils/errors/errors.go @@ -54,7 +54,38 @@ const ( ErrCategoryUncategorized ErrCategory = "uncategorized" ) +// ErrorWithRetryPolicy is an interface that errors can implement to indicate whether the +// operation that caused the error can be retried. This allows the control loop to +// make retry decisions based on error semantics rather than inspecting error formats. +type ErrorWithRetryPolicy interface { + error + // IsRetryable returns true if the error is transient and the operation may succeed + // on retry. Returns false if the error is permanent and retrying would not help. + IsRetryable() bool +} + +// IsRetryable checks if the given error (or any error in its chain) has a retry policy +// configured. It traverses the error chain using errors.As to find any error that +// implements ErrorWithRetryPolicy interface. +// +// Returns: +// - (true, true) if an ErrorWithRetryPolicy is found and IsRetryable() returns true +// - (false, true) if an ErrorWithRetryPolicy is found and IsRetryable() returns false +// - (false, false) if no ErrorWithRetryPolicy is found in the chain (no retry policy configured) +func IsRetryable(err error) (isRetryable bool, hasRetryPolicy bool) { + if err == nil { + return false, false + } + + var errWithPolicy ErrorWithRetryPolicy + if errors.As(err, &errWithPolicy) { + return errWithPolicy.IsRetryable(), true + } + return false, false +} + var _ error = &Error{} +var _ ErrorWithRetryPolicy = &Error{} type Error struct { // category is the category of the error. @@ -77,6 +108,29 @@ func (e *Error) categoryWithDefault() ErrCategory { return e.category } +// IsRetryable implements the ErrorWithRetryPolicy interface. +// It determines retryability based on the error category: +// - ErrCategoryTransient: retryable (will self-resolve) +// - ErrCategoryAPIServer: retryable (API server issues are often transient) +// - ErrCategoryUnexpected: not retryable (unknown state, cannot recover) +// - ErrCategoryUser: not retryable (requires user action to fix) +// - ErrCategoryUncategorized: retryable (default to retry when unknown) +func (e *Error) IsRetryable() bool { + switch e.category { + case ErrCategoryTransient: + return true + case ErrCategoryAPIServer: + return true + case ErrCategoryUnexpected: + return false + case ErrCategoryUser: + return false + default: + // ErrCategoryUncategorized or unknown: default to retryable. + return true + } +} + // Error implements the error interface. // // Note the output intentionally does not include the additional attributes, so as to keep the cardinality diff --git a/pkg/utils/errors/errors_test.go b/pkg/utils/errors/errors_test.go index d715640fa..277cbc2e3 100644 --- a/pkg/utils/errors/errors_test.go +++ b/pkg/utils/errors/errors_test.go @@ -578,3 +578,154 @@ func readFromBuffer(t *testing.T, buf *bytes.Buffer) string { outputStr := string(output) return outputStr } + +// mockErrorWithRetryPolicy is a test helper that implements the ErrorWithRetryPolicy interface. +type mockErrorWithRetryPolicy struct { + retryable bool +} + +func (e *mockErrorWithRetryPolicy) Error() string { + return "mock error with retry policy" +} + +func (e *mockErrorWithRetryPolicy) IsRetryable() bool { + return e.retryable +} + +func TestIsRetryableFunction(t *testing.T) { + testCases := []struct { + name string + err error + wantRetryable bool + wantFound bool + }{ + { + name: "nil error", + err: nil, + wantRetryable: false, + wantFound: false, + }, + { + name: "plain error (no retry policy in chain)", + err: fmt.Errorf("plain error"), + wantRetryable: false, + wantFound: false, + }, + { + name: "error with retry policy (retryable=true)", + err: &mockErrorWithRetryPolicy{retryable: true}, + wantRetryable: true, + wantFound: true, + }, + { + name: "error with retry policy (retryable=false)", + err: &mockErrorWithRetryPolicy{retryable: false}, + wantRetryable: false, + wantFound: true, + }, + { + name: "wrapped error with retry policy", + err: fmt.Errorf("wrapped: %w", &mockErrorWithRetryPolicy{retryable: true}), + wantRetryable: true, + wantFound: true, + }, + { + name: "Error with ErrCategoryTransient (retryable)", + err: &Error{category: ErrCategoryTransient}, + wantRetryable: true, + wantFound: true, + }, + { + name: "Error with ErrCategoryUser (not retryable)", + err: &Error{category: ErrCategoryUser}, + wantRetryable: false, + wantFound: true, + }, + { + name: "Error with ErrCategoryUnexpected (not retryable)", + err: &Error{category: ErrCategoryUnexpected}, + wantRetryable: false, + wantFound: true, + }, + { + name: "Error with ErrCategoryAPIServer (retryable)", + err: &Error{category: ErrCategoryAPIServer}, + wantRetryable: true, + wantFound: true, + }, + { + name: "Error with ErrCategoryUncategorized (retryable by default)", + err: &Error{category: ErrCategoryUncategorized}, + wantRetryable: true, + wantFound: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotRetryable, gotFound := IsRetryable(tc.err) + if gotRetryable != tc.wantRetryable { + t.Errorf("IsRetryable() retryable = %v, want %v", gotRetryable, tc.wantRetryable) + } + if gotFound != tc.wantFound { + t.Errorf("IsRetryable() found = %v, want %v", gotFound, tc.wantFound) + } + }) + } +} + +func TestErrorIsRetryable(t *testing.T) { + testCases := []struct { + name string + err *Error + wantRetryable bool + }{ + { + name: "ErrCategoryTransient is retryable", + err: &Error{category: ErrCategoryTransient}, + wantRetryable: true, + }, + { + name: "ErrCategoryAPIServer is retryable", + err: &Error{category: ErrCategoryAPIServer}, + wantRetryable: true, + }, + { + name: "ErrCategoryUnexpected is not retryable", + err: &Error{category: ErrCategoryUnexpected}, + wantRetryable: false, + }, + { + name: "ErrCategoryUser is not retryable", + err: &Error{category: ErrCategoryUser}, + wantRetryable: false, + }, + { + name: "ErrCategoryUncategorized is retryable by default", + err: &Error{category: ErrCategoryUncategorized}, + wantRetryable: true, + }, + { + name: "empty category defaults to retryable", + err: &Error{}, + wantRetryable: true, + }, + { + name: "Error with wrapped error uses its own category", + err: &Error{ + category: ErrCategoryUser, + wrapped: fmt.Errorf("plain error"), + }, + wantRetryable: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotRetryable := tc.err.IsRetryable() + if gotRetryable != tc.wantRetryable { + t.Errorf("IsRetryable() = %v, want %v", gotRetryable, tc.wantRetryable) + } + }) + } +}