From 466df553d0a6df05cab64e5279cadb1032b171f6 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 16 Jun 2026 16:34:29 +0530 Subject: [PATCH 1/6] Add a standard toggle for resumability to ResumableJobMixin --- .../provider_dependencies.json.sha256sum | 2 +- providers/apache/spark/docs/operators.rst | 6 ++-- .../apache/spark/operators/spark_submit.py | 33 +++++-------------- .../spark/operators/test_spark_submit.py | 28 ++++++++-------- task-sdk/docs/resumable-job-mixin.rst | 25 ++++++++++++++ .../airflow/sdk/bases/resumablejobmixin.py | 13 ++++++++ .../task_sdk/bases/test_resumablejobmixin.py | 32 +++++++++++++++++- 7 files changed, 96 insertions(+), 43 deletions(-) diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum index d173961afb128..413fc17032478 100644 --- a/generated/provider_dependencies.json.sha256sum +++ b/generated/provider_dependencies.json.sha256sum @@ -1 +1 @@ -b17f09d421b67d9d3925516c27c0fc4b4fb9f4fa4e4c495ebf3c643b3d12e59c +8609061b1d7c65722ca143c6e54bf569c2b3bb2bfeac9ecc85c97a114a5d83ac diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 4d3a9a526af5b..3a3b9b8a3f051 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -236,7 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conn_id="spark_k8s", deploy_mode="cluster", track_driver_via_k8s_api=True, - reconnect_on_retry=True, + resume_on_retry=True, ) **Requirements** @@ -246,9 +246,9 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conflicts with the flag and a ``ValueError`` will be raised at task start. * The Airflow worker must be able to reach the Kubernetes API server and have permission to read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. -* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is +* Set ``resume_on_retry=True`` (the default) to enable crash recovery: the driver pod name is persisted to task state before polling begins, so a worker crash and retry reconnects to the - existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always + existing pod instead of submitting a fresh one. Set ``resume_on_retry=False`` to always submit a fresh driver on retry. * Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 7ceb95b387a5d..42437a48eb39e 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -203,7 +203,7 @@ def __init__( deploy_mode: str | None = None, use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, - reconnect_on_retry: bool = True, + resume_on_retry: bool = True, track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, @@ -252,7 +252,7 @@ def __init__( self._yarn_track_via_rm_api = yarn_track_via_rm_api self._yarn_rm_auth = yarn_rm_auth - self.reconnect_on_retry = reconnect_on_retry + self.resume_on_retry = resume_on_retry self._track_driver_via_k8s_api = track_driver_via_k8s_api self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -272,33 +272,18 @@ def execute(self, context: Context) -> None: if self._track_driver_via_k8s_api: hook._validate_track_driver_via_k8s_api_config() if hook._should_track_driver_status: - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) if hook._should_track_driver_via_k8s_api(): - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) if hook._is_yarn_cluster_mode: - if self.reconnect_on_retry and not hook._yarn_track_via_rm_api: + if self.resume_on_retry and not hook._yarn_track_via_rm_api: raise ValueError( - "YARN cluster mode with reconnect_on_retry=True requires yarn_track_via_rm_api=True. " + "YARN cluster mode with resume_on_retry=True requires yarn_track_via_rm_api=True. " "The RM REST API is needed to check application status on retry." ) if hook._yarn_track_via_rm_api: hook._validate_yarn_track_via_rm_api_config() - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) hook.submit(self.application) def submit_job(self, context: Context) -> str | None: @@ -319,7 +304,7 @@ def submit_job(self, context: Context) -> str | None: raise ValueError( "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts" "with the need to exit spark-submit immediately to persist the application ID for tracking. " - "Either remove the explicit conf or set reconnect_on_retry=False." + "Either remove the explicit conf or set resume_on_retry=False." ) self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false" self._hook.submit(self.application) @@ -445,7 +430,7 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: # Cache only when the pod actually reached Succeeded, the 404/vanished path # returns None for cases like: pod deleted by on_kill or garbage collected after failure) # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. - if terminal_phase == "Succeeded" and self.reconnect_on_retry: + if terminal_phase == "Succeeded" and self.resume_on_retry: if (task_state_store := context.get("task_state_store")) is not None: task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") return diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 9708aab1a02ac..def6cc91aec49 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -590,8 +590,8 @@ def test_submits_fresh_when_task_state_store_unavailable(self): operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-001"] - def test_reconnect_on_retry_false_submits_fresh_and_polls(self): - operator = self._make_operator(reconnect_on_retry=False) + def test_resume_on_retry_false_submits_fresh_and_polls(self): + operator = self._make_operator(resume_on_retry=False) operator._hook = self._make_hook(should_track=True) operator._hook.submit.return_value = "driver-new" task_store = FakeTaskStateStore({"spark_job_id": "driver-old"}) @@ -599,7 +599,7 @@ def test_reconnect_on_retry_false_submits_fresh_and_polls(self): operator.poll_until_complete = lambda external_id, context: polled.append(external_id) operator.execute(context={"task_state_store": task_store}) - # reconnect_on_retry=False: ignores prior driver ID, submits fresh, but still polls + # resume_on_retry=False: ignores prior driver ID, submits fresh, but still polls operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-new"] @@ -863,8 +863,8 @@ def test_on_kill_sends_authenticated_kill_to_yarn_rm(self): hook._kill_yarn_application.assert_called_once_with("application_1234_0001") def test_yarn_cluster_reconnect_without_rm_api_raises(self): - """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" - operator = self._make_operator(reconnect_on_retry=True) + """resume_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" + operator = self._make_operator(resume_on_retry=True) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -873,8 +873,8 @@ def test_yarn_cluster_reconnect_without_rm_api_raises(self): operator.execute(context={}) def test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self): - """reconnect_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" - operator = self._make_operator(reconnect_on_retry=False) + """resume_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" + operator = self._make_operator(resume_on_retry=False) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -1040,7 +1040,7 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): assert task_store.get("k8s_driver_status") == "Succeeded" def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=False) hook = self._make_k8s_hook() hook._poll_k8s_driver_via_api.return_value = "Succeeded" operator._hook = hook @@ -1072,9 +1072,9 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self): - """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) + def test_k8s_execute_persists_pod_id_when_resume_on_retry(self): + """execute() with resume_on_retry=True stores the pod ID in task_store before polling.""" + operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=True) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} @@ -1095,9 +1095,9 @@ def track_poll(external_id, context): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): - """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + def test_k8s_execute_resume_on_retry_false_does_not_persist_pod_id(self): + """execute() with resume_on_retry=False does not write spark_job_id to task_store.""" + operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=False) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} diff --git a/task-sdk/docs/resumable-job-mixin.rst b/task-sdk/docs/resumable-job-mixin.rst index e78a2e5ce1aea..160b9e3147fe7 100644 --- a/task-sdk/docs/resumable-job-mixin.rst +++ b/task-sdk/docs/resumable-job-mixin.rst @@ -124,6 +124,10 @@ Example external_id_key = "batch_job_id" + def __init__(self, *, resume_on_retry: bool = True, **kwargs): + super().__init__(**kwargs) + self.resume_on_retry = resume_on_retry + def execute(self, context): return self.execute_resumable(context) @@ -145,6 +149,27 @@ Example def get_job_result(self, external_id: JsonValue, context): return None +.. _sdk-resumable-job-mixin-resume-on-retry: + +Disabling crash recovery per task +---------------------------------- + +Set ``resume_on_retry=False`` on a task to opt out of crash recovery for that specific instance. +The operator will always submit a fresh job on retry, with no ``task_state_store`` interaction: + +.. code-block:: python + + run_spark = MyBatchOperator( + task_id="run_spark", + resume_on_retry=False, + ) + +This is useful when the external job is not idempotent and you want Airflow to always submit a +clean run rather than reconnect to a prior submission. + +The default is ``True``. Operators that use this mixin must declare ``resume_on_retry`` as an +``__init__`` parameter so that ``default_args`` injection and ``.partial()`` work correctly. + .. _sdk-resumable-job-mixin-external-id-key: External ID key diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py index 8bfda28a66442..10529e35726a2 100644 --- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py +++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py @@ -59,6 +59,10 @@ class ResumableJobMixin: class MyOperator(ResumableJobMixin, BaseOperator): external_id_key = "my_job_id" + def __init__(self, *, resume_on_retry: bool = True, **kwargs): + super().__init__(**kwargs) + self.resume_on_retry = resume_on_retry + def execute(self, context): return self.execute_resumable(context) @@ -90,6 +94,10 @@ def get_job_result(self, external_id: JsonValue, context: Context) -> Any: # Renaming this on a deployed operator breaks in-flight retries — the old key is already stored. external_id_key: str = "remote_job_id" + # Per-task toggle switch for resumability. When False, execute_resumable() skips all task_state_store interaction + # and submits fresh every time. This class attribute is the fallback default. + resume_on_retry: bool = True + def execute_resumable(self, context: Context) -> Any: """ Core of the resumable execution logic. Call this from execute() when reconnection is supported. @@ -107,6 +115,11 @@ def execute_resumable(self, context: Context) -> Any: Closing this window would require atomic "submit + persist", which is not possible across an external system boundary. """ + if not self.resume_on_retry: + external_id = self.submit_job(context) + self.poll_until_complete(external_id, context) + return self.get_job_result(external_id, context) + operator_tag = {"operator": type(self).__name__} reconnect_to: Any = None already_succeeded_id: Any = None diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py index 11f97616cdf78..2ce0a31f669a0 100644 --- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py +++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py @@ -38,8 +38,9 @@ class ConcreteResumableOperator(ResumableJobMixin, BaseOperator): external_id_key = "test_job_id" - def __init__(self, **kwargs): + def __init__(self, *, resume_on_retry: bool = True, **kwargs): super().__init__(**kwargs) + self.resume_on_retry = resume_on_retry self.submitted_ids: list[str] = [] self.polled_ids: list[str] = [] self._next_id = "job-001" @@ -192,6 +193,35 @@ def get_job_result(self, external_id, context) -> str: assert task_state._store == {} +class TestResumeOnRetryDisabled: + def test_submits_and_polls_without_task_store_interaction(self): + op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + task_store = FakeTaskState() + op.execute_resumable(make_context(task_store)) + + assert op.submitted_ids == ["job-001"] + assert op.polled_ids == ["job-001"] + assert task_store._store == {}, "task_store must not be written when resume_on_retry=False" + + def test_does_not_reconnect_when_prior_id_exists(self): + op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + op._status_map["job-001"] = "RUNNING" + task_store = FakeTaskState({"test_job_id": "job-001"}) + + op.execute_resumable(make_context(task_store)) + + assert op.submitted_ids == ["job-001"], "should submit fresh even with a prior ID stored" + + def test_returns_result(self): + op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + result = op.execute_resumable(make_context(FakeTaskState())) + assert result == "result-of-job-001" + + def test_default_is_true(self): + op = ConcreteResumableOperator(task_id="test_task") + assert op.resume_on_retry is True + + class TestExternalIdKey: def test_custom_key_used_for_storage_and_retrieval(self): class CustomKeyOp(ConcreteResumableOperator): From 32111eeb6c66d65762dc93432806b9817a6369ae Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 17 Jun 2026 14:37:05 +0530 Subject: [PATCH 2/6] Add a standard toggle for resumability to ResumableJobMixin --- .../apache/spark/operators/spark_submit.py | 12 ++++++++++-- .../apache/spark/operators/test_spark_submit.py | 13 +++++++++++++ task-sdk/docs/resumable-job-mixin.rst | 10 +++------- .../src/airflow/sdk/bases/resumablejobmixin.py | 15 ++++++++------- .../task_sdk/bases/test_resumablejobmixin.py | 3 +-- 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 42437a48eb39e..b9c393a635c3b 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -17,12 +17,14 @@ # under the License. from __future__ import annotations +import warnings from collections.abc import Sequence from typing import TYPE_CHECKING, Any, cast import requests from tenacity import retry, stop_after_attempt, wait_fixed +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, @@ -203,7 +205,6 @@ def __init__( deploy_mode: str | None = None, use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, - resume_on_retry: bool = True, track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, @@ -213,8 +214,16 @@ def __init__( openlineage_inject_transport_info: bool = conf.getboolean( "openlineage", "spark_inject_transport_info", fallback=False ), + reconnect_on_retry: bool | None = None, **kwargs: Any, ) -> None: + if reconnect_on_retry is not None: + warnings.warn( + "reconnect_on_retry is renamed to resume_on_retry.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + kwargs.setdefault("resume_on_retry", reconnect_on_retry) super().__init__(**kwargs) self.application = application self.conf = conf @@ -252,7 +261,6 @@ def __init__( self._yarn_track_via_rm_api = yarn_track_via_rm_api self._yarn_rm_auth = yarn_rm_auth - self.resume_on_retry = resume_on_retry self._track_driver_via_k8s_api = track_driver_via_k8s_api self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index def6cc91aec49..aaca4ddf4c416 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -590,6 +590,19 @@ def test_submits_fresh_when_task_state_store_unavailable(self): operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-001"] + def test_reconnect_on_retry_deprecated_alias(self): + import warnings + + from airflow.exceptions import AirflowProviderDeprecationWarning + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + operator = self._make_operator(reconnect_on_retry=False) + assert len(w) == 1 + assert issubclass(w[0].category, AirflowProviderDeprecationWarning) + assert "reconnect_on_retry" in str(w[0].message) + assert operator.resume_on_retry is False + def test_resume_on_retry_false_submits_fresh_and_polls(self): operator = self._make_operator(resume_on_retry=False) operator._hook = self._make_hook(should_track=True) diff --git a/task-sdk/docs/resumable-job-mixin.rst b/task-sdk/docs/resumable-job-mixin.rst index 160b9e3147fe7..f6066736baa83 100644 --- a/task-sdk/docs/resumable-job-mixin.rst +++ b/task-sdk/docs/resumable-job-mixin.rst @@ -120,14 +120,10 @@ Example from pydantic import JsonValue - class MyBatchOperator(BaseOperator, ResumableJobMixin): + class MyBatchOperator(ResumableJobMixin, BaseOperator): external_id_key = "batch_job_id" - def __init__(self, *, resume_on_retry: bool = True, **kwargs): - super().__init__(**kwargs) - self.resume_on_retry = resume_on_retry - def execute(self, context): return self.execute_resumable(context) @@ -167,8 +163,8 @@ The operator will always submit a fresh job on retry, with no ``task_state_store This is useful when the external job is not idempotent and you want Airflow to always submit a clean run rather than reconnect to a prior submission. -The default is ``True``. Operators that use this mixin must declare ``resume_on_retry`` as an -``__init__`` parameter so that ``default_args`` injection and ``.partial()`` work correctly. +The default is ``True``. ``resume_on_retry`` is owned by the mixin — operators do not need to +redeclare it. ``default_args`` injection and ``.partial()`` work automatically. .. _sdk-resumable-job-mixin-external-id-key: diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py index 10529e35726a2..eabfc10d0bee5 100644 --- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py +++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py @@ -21,6 +21,7 @@ from opentelemetry import trace from airflow.sdk._shared.observability.metrics import stats +from airflow.sdk.bases.operator import BaseOperatorMeta if TYPE_CHECKING: from pydantic import JsonValue @@ -59,10 +60,6 @@ class ResumableJobMixin: class MyOperator(ResumableJobMixin, BaseOperator): external_id_key = "my_job_id" - def __init__(self, *, resume_on_retry: bool = True, **kwargs): - super().__init__(**kwargs) - self.resume_on_retry = resume_on_retry - def execute(self, context): return self.execute_resumable(context) @@ -94,9 +91,13 @@ def get_job_result(self, external_id: JsonValue, context: Context) -> Any: # Renaming this on a deployed operator breaks in-flight retries — the old key is already stored. external_id_key: str = "remote_job_id" - # Per-task toggle switch for resumability. When False, execute_resumable() skips all task_state_store interaction - # and submits fresh every time. This class attribute is the fallback default. - resume_on_retry: bool = True + # The mixin is not a BaseOperator subclass, but _apply_defaults is only ever called on concrete + # operators that are BaseOperator subclasses. That is a runtime MRO guarantee not visible in the static + # type signature here and hence we need the type ignore. + @BaseOperatorMeta._apply_defaults # type: ignore[type-var] + def __init__(self, *, resume_on_retry: bool = True, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.resume_on_retry = resume_on_retry def execute_resumable(self, context: Context) -> Any: """ diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py index 2ce0a31f669a0..9132b923bc144 100644 --- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py +++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py @@ -38,9 +38,8 @@ class ConcreteResumableOperator(ResumableJobMixin, BaseOperator): external_id_key = "test_job_id" - def __init__(self, *, resume_on_retry: bool = True, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.resume_on_retry = resume_on_retry self.submitted_ids: list[str] = [] self.polled_ids: list[str] = [] self._next_id = "job-001" From fba7d7717039cd05e12ee4b60441bcb71e302d1b Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 17 Jun 2026 14:42:50 +0530 Subject: [PATCH 3/6] top level import --- .../tests/unit/apache/spark/operators/test_spark_submit.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index aaca4ddf4c416..2603df9ca64ae 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -18,12 +18,14 @@ from __future__ import annotations import logging +import warnings from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import pytest +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import DagRun, TaskInstance from airflow.models.dag import DAG from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator @@ -591,10 +593,6 @@ def test_submits_fresh_when_task_state_store_unavailable(self): assert polled == ["driver-001"] def test_reconnect_on_retry_deprecated_alias(self): - import warnings - - from airflow.exceptions import AirflowProviderDeprecationWarning - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") operator = self._make_operator(reconnect_on_retry=False) From d967ab5354cc78448d0616cb7bdc13d60e808d87 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 18 Jun 2026 13:25:41 +0530 Subject: [PATCH 4/6] changing to durable --- .../apache/spark/operators/spark_submit.py | 17 ++++++---- .../spark/operators/test_spark_submit.py | 32 ++++++++++--------- task-sdk/docs/resumable-job-mixin.rst | 6 ++-- .../airflow/sdk/bases/resumablejobmixin.py | 6 ++-- .../task_sdk/bases/test_resumablejobmixin.py | 10 +++--- 5 files changed, 39 insertions(+), 32 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index b9c393a635c3b..2da9246d02073 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -48,6 +48,11 @@ class ResumableJobMixin: # type: ignore[no-redef] external_id_key: str = "remote_job_id" + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: + # Accept durable so the kwarg doesn't leak to BaseOperator; crash recovery is a no-op here. + super().__init__(**kwargs) + self.durable = durable + def execute_resumable(self, context): external_id = self.submit_job(context) self.poll_until_complete(external_id, context) @@ -219,11 +224,11 @@ def __init__( ) -> None: if reconnect_on_retry is not None: warnings.warn( - "reconnect_on_retry is renamed to resume_on_retry.", + "reconnect_on_retry is renamed to durable.", AirflowProviderDeprecationWarning, stacklevel=2, ) - kwargs.setdefault("resume_on_retry", reconnect_on_retry) + kwargs.setdefault("durable", reconnect_on_retry) super().__init__(**kwargs) self.application = application self.conf = conf @@ -284,9 +289,9 @@ def execute(self, context: Context) -> None: if hook._should_track_driver_via_k8s_api(): return self.execute_resumable(context) if hook._is_yarn_cluster_mode: - if self.resume_on_retry and not hook._yarn_track_via_rm_api: + if self.durable and not hook._yarn_track_via_rm_api: raise ValueError( - "YARN cluster mode with resume_on_retry=True requires yarn_track_via_rm_api=True. " + "YARN cluster mode with durable=True requires yarn_track_via_rm_api=True. " "The RM REST API is needed to check application status on retry." ) if hook._yarn_track_via_rm_api: @@ -312,7 +317,7 @@ def submit_job(self, context: Context) -> str | None: raise ValueError( "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts" "with the need to exit spark-submit immediately to persist the application ID for tracking. " - "Either remove the explicit conf or set resume_on_retry=False." + "Either remove the explicit conf or set durable=False." ) self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false" self._hook.submit(self.application) @@ -438,7 +443,7 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: # Cache only when the pod actually reached Succeeded, the 404/vanished path # returns None for cases like: pod deleted by on_kill or garbage collected after failure) # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. - if terminal_phase == "Succeeded" and self.resume_on_retry: + if terminal_phase == "Succeeded" and self.durable: if (task_state_store := context.get("task_state_store")) is not None: task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") return diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 2603df9ca64ae..daa0fa119ca3a 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -599,10 +599,10 @@ def test_reconnect_on_retry_deprecated_alias(self): assert len(w) == 1 assert issubclass(w[0].category, AirflowProviderDeprecationWarning) assert "reconnect_on_retry" in str(w[0].message) - assert operator.resume_on_retry is False + assert operator.durable is False - def test_resume_on_retry_false_submits_fresh_and_polls(self): - operator = self._make_operator(resume_on_retry=False) + def test_durable_false_submits_fresh_and_polls(self): + operator = self._make_operator(durable=False) operator._hook = self._make_hook(should_track=True) operator._hook.submit.return_value = "driver-new" task_store = FakeTaskStateStore({"spark_job_id": "driver-old"}) @@ -610,7 +610,7 @@ def test_resume_on_retry_false_submits_fresh_and_polls(self): operator.poll_until_complete = lambda external_id, context: polled.append(external_id) operator.execute(context={"task_state_store": task_store}) - # resume_on_retry=False: ignores prior driver ID, submits fresh, but still polls + # durable=False: ignores prior driver ID, submits fresh, but still polls operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-new"] @@ -874,8 +874,8 @@ def test_on_kill_sends_authenticated_kill_to_yarn_rm(self): hook._kill_yarn_application.assert_called_once_with("application_1234_0001") def test_yarn_cluster_reconnect_without_rm_api_raises(self): - """resume_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" - operator = self._make_operator(resume_on_retry=True) + """durable=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" + operator = self._make_operator(durable=True) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -884,8 +884,8 @@ def test_yarn_cluster_reconnect_without_rm_api_raises(self): operator.execute(context={}) def test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self): - """resume_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" - operator = self._make_operator(resume_on_retry=False) + """durable=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" + operator = self._make_operator(durable=False) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -1039,6 +1039,7 @@ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): assert hook._kubernetes_driver_pod == "spark-abc-driver" hook._poll_k8s_driver_via_api.assert_called_once() + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+") def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): operator = self._make_operator(track_driver_via_k8s_api=True) hook = self._make_k8s_hook() @@ -1050,8 +1051,9 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): assert task_store.get("k8s_driver_status") == "Succeeded" + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+") def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): - operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=False) + operator = self._make_operator(track_driver_via_k8s_api=True, durable=False) hook = self._make_k8s_hook() hook._poll_k8s_driver_via_api.return_value = "Succeeded" operator._hook = hook @@ -1083,9 +1085,9 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_persists_pod_id_when_resume_on_retry(self): - """execute() with resume_on_retry=True stores the pod ID in task_store before polling.""" - operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=True) + def test_k8s_execute_persists_pod_id_when_durable(self): + """execute() with durable=True stores the pod ID in task_store before polling.""" + operator = self._make_operator(track_driver_via_k8s_api=True, durable=True) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} @@ -1106,9 +1108,9 @@ def track_poll(external_id, context): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_resume_on_retry_false_does_not_persist_pod_id(self): - """execute() with resume_on_retry=False does not write spark_job_id to task_store.""" - operator = self._make_operator(track_driver_via_k8s_api=True, resume_on_retry=False) + def test_k8s_execute_durable_false_does_not_persist_pod_id(self): + """execute() with durable=False does not write spark_job_id to task_store.""" + operator = self._make_operator(track_driver_via_k8s_api=True, durable=False) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} diff --git a/task-sdk/docs/resumable-job-mixin.rst b/task-sdk/docs/resumable-job-mixin.rst index f6066736baa83..345fc4c4494fb 100644 --- a/task-sdk/docs/resumable-job-mixin.rst +++ b/task-sdk/docs/resumable-job-mixin.rst @@ -150,20 +150,20 @@ Example Disabling crash recovery per task ---------------------------------- -Set ``resume_on_retry=False`` on a task to opt out of crash recovery for that specific instance. +Set ``durable=False`` on a task to opt out of crash recovery for that specific instance. The operator will always submit a fresh job on retry, with no ``task_state_store`` interaction: .. code-block:: python run_spark = MyBatchOperator( task_id="run_spark", - resume_on_retry=False, + durable=False, ) This is useful when the external job is not idempotent and you want Airflow to always submit a clean run rather than reconnect to a prior submission. -The default is ``True``. ``resume_on_retry`` is owned by the mixin — operators do not need to +The default is ``True``. ``durable`` is owned by the mixin — operators do not need to redeclare it. ``default_args`` injection and ``.partial()`` work automatically. .. _sdk-resumable-job-mixin-external-id-key: diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py index eabfc10d0bee5..a9c9bcc913e6a 100644 --- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py +++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py @@ -95,9 +95,9 @@ def get_job_result(self, external_id: JsonValue, context: Context) -> Any: # operators that are BaseOperator subclasses. That is a runtime MRO guarantee not visible in the static # type signature here and hence we need the type ignore. @BaseOperatorMeta._apply_defaults # type: ignore[type-var] - def __init__(self, *, resume_on_retry: bool = True, **kwargs: Any) -> None: + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: super().__init__(**kwargs) - self.resume_on_retry = resume_on_retry + self.durable = durable def execute_resumable(self, context: Context) -> Any: """ @@ -116,7 +116,7 @@ def execute_resumable(self, context: Context) -> Any: Closing this window would require atomic "submit + persist", which is not possible across an external system boundary. """ - if not self.resume_on_retry: + if not self.durable: external_id = self.submit_job(context) self.poll_until_complete(external_id, context) return self.get_job_result(external_id, context) diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py index 9132b923bc144..153557845025f 100644 --- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py +++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py @@ -194,16 +194,16 @@ def get_job_result(self, external_id, context) -> str: class TestResumeOnRetryDisabled: def test_submits_and_polls_without_task_store_interaction(self): - op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + op = ConcreteResumableOperator(task_id="test_task", durable=False) task_store = FakeTaskState() op.execute_resumable(make_context(task_store)) assert op.submitted_ids == ["job-001"] assert op.polled_ids == ["job-001"] - assert task_store._store == {}, "task_store must not be written when resume_on_retry=False" + assert task_store._store == {}, "task_store must not be written when durable=False" def test_does_not_reconnect_when_prior_id_exists(self): - op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + op = ConcreteResumableOperator(task_id="test_task", durable=False) op._status_map["job-001"] = "RUNNING" task_store = FakeTaskState({"test_job_id": "job-001"}) @@ -212,13 +212,13 @@ def test_does_not_reconnect_when_prior_id_exists(self): assert op.submitted_ids == ["job-001"], "should submit fresh even with a prior ID stored" def test_returns_result(self): - op = ConcreteResumableOperator(task_id="test_task", resume_on_retry=False) + op = ConcreteResumableOperator(task_id="test_task", durable=False) result = op.execute_resumable(make_context(FakeTaskState())) assert result == "result-of-job-001" def test_default_is_true(self): op = ConcreteResumableOperator(task_id="test_task") - assert op.resume_on_retry is True + assert op.durable is True class TestExternalIdKey: From 57c514f521b3c4983c72eee700241c8d2338a6dd Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 18 Jun 2026 14:48:50 +0530 Subject: [PATCH 5/6] fixing docs --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 145c091e60054..7e6f2d9b66579 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1369,6 +1369,7 @@ rebase Rebasing Recency recurse +redeclare redelivery Redhat redis From 9a9ba6edb51929b870ca0628c84a5ca1bb53448c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 18 Jun 2026 17:15:03 +0530 Subject: [PATCH 6/6] handling comments from ash (cherry picked from commit 546469d70ec3efc373bfa1d73c2f8d8d79b5cd03) --- providers/apache/spark/docs/operators.rst | 6 +++--- .../providers/apache/spark/operators/spark_submit.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 3a3b9b8a3f051..d20c1da5cf7fe 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -236,7 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conn_id="spark_k8s", deploy_mode="cluster", track_driver_via_k8s_api=True, - resume_on_retry=True, + durable=True, ) **Requirements** @@ -246,9 +246,9 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conflicts with the flag and a ``ValueError`` will be raised at task start. * The Airflow worker must be able to reach the Kubernetes API server and have permission to read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. -* Set ``resume_on_retry=True`` (the default) to enable crash recovery: the driver pod name is +* Set ``durable=True`` (the default) to enable crash recovery: the driver pod name is persisted to task state before polling begins, so a worker crash and retry reconnects to the - existing pod instead of submitting a fresh one. Set ``resume_on_retry=False`` to always + existing pod instead of submitting a fresh one. Set ``durable=False`` to always submit a fresh driver on retry. * Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 2da9246d02073..8a2129e967300 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -146,6 +146,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). + :param durable: When ``True`` (the default), the external job ID is persisted to task state + store before polling begins so that a worker crash and retry reconnects to the existing job + instead of submitting a fresh one. Set to ``False`` to always submit a new job on retry. """ # Generic key used across all Spark deployment modes (standalone driver ID,