diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 133c37b413b3d..6a9616c7c0dcd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1370,6 +1370,7 @@ rebase Rebasing Recency recurse +redeclare redelivery Redhat redis diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst index 4d3a9a526af5b..d20c1da5cf7fe 100644 --- a/providers/apache/spark/docs/operators.rst +++ b/providers/apache/spark/docs/operators.rst @@ -236,7 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conn_id="spark_k8s", deploy_mode="cluster", track_driver_via_k8s_api=True, - reconnect_on_retry=True, + durable=True, ) **Requirements** @@ -246,9 +246,9 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full conflicts with the flag and a ``ValueError`` will be raised at task start. * The Airflow worker must be able to reach the Kubernetes API server and have permission to read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail. -* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is +* Set ``durable=True`` (the default) to enable crash recovery: the driver pod name is persisted to task state before polling begins, so a worker crash and retry reconnects to the - existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always + existing pod instead of submitting a fresh one. Set ``durable=False`` to always submit a fresh driver on retry. * Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 7ceb95b387a5d..8a2129e967300 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -17,12 +17,14 @@ # under the License. from __future__ import annotations +import warnings from collections.abc import Sequence from typing import TYPE_CHECKING, Any, cast import requests from tenacity import retry, stop_after_attempt, wait_fixed +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, @@ -46,6 +48,11 @@ class ResumableJobMixin: # type: ignore[no-redef] external_id_key: str = "remote_job_id" + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: + # Accept durable so the kwarg doesn't leak to BaseOperator; crash recovery is a no-op here. + super().__init__(**kwargs) + self.durable = durable + def execute_resumable(self, context): external_id = self.submit_job(context) self.poll_until_complete(external_id, context) @@ -139,6 +146,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator): omitted, Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for non-Kerberos connections). + :param durable: When ``True`` (the default), the external job ID is persisted to task state + store before polling begins so that a worker crash and retry reconnects to the existing job + instead of submitting a fresh one. Set to ``False`` to always submit a new job on retry. """ # Generic key used across all Spark deployment modes (standalone driver ID, @@ -203,7 +213,6 @@ def __init__( deploy_mode: str | None = None, use_krb5ccache: bool = False, post_submit_commands: list[str] | None = None, - reconnect_on_retry: bool = True, track_driver_via_k8s_api: bool = False, yarn_track_via_rm_api: bool = False, yarn_rm_auth: AuthBase | None = None, @@ -213,8 +222,16 @@ def __init__( openlineage_inject_transport_info: bool = conf.getboolean( "openlineage", "spark_inject_transport_info", fallback=False ), + reconnect_on_retry: bool | None = None, **kwargs: Any, ) -> None: + if reconnect_on_retry is not None: + warnings.warn( + "reconnect_on_retry is renamed to durable.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + kwargs.setdefault("durable", reconnect_on_retry) super().__init__(**kwargs) self.application = application self.conf = conf @@ -252,7 +269,6 @@ def __init__( self._yarn_track_via_rm_api = yarn_track_via_rm_api self._yarn_rm_auth = yarn_rm_auth - self.reconnect_on_retry = reconnect_on_retry self._track_driver_via_k8s_api = track_driver_via_k8s_api self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self._openlineage_inject_transport_info = openlineage_inject_transport_info @@ -272,33 +288,18 @@ def execute(self, context: Context) -> None: if self._track_driver_via_k8s_api: hook._validate_track_driver_via_k8s_api_config() if hook._should_track_driver_status: - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) if hook._should_track_driver_via_k8s_api(): - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) if hook._is_yarn_cluster_mode: - if self.reconnect_on_retry and not hook._yarn_track_via_rm_api: + if self.durable and not hook._yarn_track_via_rm_api: raise ValueError( - "YARN cluster mode with reconnect_on_retry=True requires yarn_track_via_rm_api=True. " + "YARN cluster mode with durable=True requires yarn_track_via_rm_api=True. " "The RM REST API is needed to check application status on retry." ) if hook._yarn_track_via_rm_api: hook._validate_yarn_track_via_rm_api_config() - if self.reconnect_on_retry: - return self.execute_resumable(context) - # reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence. - driver_id = self.submit_job(context) - self.poll_until_complete(driver_id, context) - return self.get_job_result(driver_id, context) + return self.execute_resumable(context) hook.submit(self.application) def submit_job(self, context: Context) -> str | None: @@ -319,7 +320,7 @@ def submit_job(self, context: Context) -> str | None: raise ValueError( "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts" "with the need to exit spark-submit immediately to persist the application ID for tracking. " - "Either remove the explicit conf or set reconnect_on_retry=False." + "Either remove the explicit conf or set durable=False." ) self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false" self._hook.submit(self.application) @@ -445,7 +446,7 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: # Cache only when the pod actually reached Succeeded, the 404/vanished path # returns None for cases like: pod deleted by on_kill or garbage collected after failure) # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. - if terminal_phase == "Succeeded" and self.reconnect_on_retry: + if terminal_phase == "Succeeded" and self.durable: if (task_state_store := context.get("task_state_store")) is not None: task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") return diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 9708aab1a02ac..daa0fa119ca3a 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -18,12 +18,14 @@ from __future__ import annotations import logging +import warnings from datetime import timedelta from unittest import mock from unittest.mock import MagicMock import pytest +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import DagRun, TaskInstance from airflow.models.dag import DAG from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator @@ -590,8 +592,17 @@ def test_submits_fresh_when_task_state_store_unavailable(self): operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-001"] - def test_reconnect_on_retry_false_submits_fresh_and_polls(self): - operator = self._make_operator(reconnect_on_retry=False) + def test_reconnect_on_retry_deprecated_alias(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + operator = self._make_operator(reconnect_on_retry=False) + assert len(w) == 1 + assert issubclass(w[0].category, AirflowProviderDeprecationWarning) + assert "reconnect_on_retry" in str(w[0].message) + assert operator.durable is False + + def test_durable_false_submits_fresh_and_polls(self): + operator = self._make_operator(durable=False) operator._hook = self._make_hook(should_track=True) operator._hook.submit.return_value = "driver-new" task_store = FakeTaskStateStore({"spark_job_id": "driver-old"}) @@ -599,7 +610,7 @@ def test_reconnect_on_retry_false_submits_fresh_and_polls(self): operator.poll_until_complete = lambda external_id, context: polled.append(external_id) operator.execute(context={"task_state_store": task_store}) - # reconnect_on_retry=False: ignores prior driver ID, submits fresh, but still polls + # durable=False: ignores prior driver ID, submits fresh, but still polls operator._hook.submit.assert_called_once_with("test.jar") assert polled == ["driver-new"] @@ -863,8 +874,8 @@ def test_on_kill_sends_authenticated_kill_to_yarn_rm(self): hook._kill_yarn_application.assert_called_once_with("application_1234_0001") def test_yarn_cluster_reconnect_without_rm_api_raises(self): - """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" - operator = self._make_operator(reconnect_on_retry=True) + """durable=True + yarn_track_via_rm_api=False must raise - RM API is required for resume.""" + operator = self._make_operator(durable=True) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -873,8 +884,8 @@ def test_yarn_cluster_reconnect_without_rm_api_raises(self): operator.execute(context={}) def test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self): - """reconnect_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" - operator = self._make_operator(reconnect_on_retry=False) + """durable=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling.""" + operator = self._make_operator(durable=False) hook = self._make_hook(is_yarn_cluster=True) hook._yarn_track_via_rm_api = False operator._hook = hook @@ -1028,6 +1039,7 @@ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self): assert hook._kubernetes_driver_pod == "spark-abc-driver" hook._poll_k8s_driver_via_api.assert_called_once() + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+") def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): operator = self._make_operator(track_driver_via_k8s_api=True) hook = self._make_k8s_hook() @@ -1039,8 +1051,9 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self): assert task_store.get("k8s_driver_status") == "Succeeded" + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+") def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self): - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + operator = self._make_operator(track_driver_via_k8s_api=True, durable=False) hook = self._make_k8s_hook() hook._poll_k8s_driver_via_api.return_value = "Succeeded" operator._hook = hook @@ -1072,9 +1085,9 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self): - """execute() with reconnect_on_retry=True stores the pod ID in task_store before polling.""" - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True) + def test_k8s_execute_persists_pod_id_when_durable(self): + """execute() with durable=True stores the pod ID in task_store before polling.""" + operator = self._make_operator(track_driver_via_k8s_api=True, durable=True) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} @@ -1095,9 +1108,9 @@ def track_poll(external_id, context): not AIRFLOW_V_3_3_PLUS, reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+", ) - def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self): - """execute() with reconnect_on_retry=False does not write spark_job_id to task_store.""" - operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False) + def test_k8s_execute_durable_false_does_not_persist_pod_id(self): + """execute() with durable=False does not write spark_job_id to task_store.""" + operator = self._make_operator(track_driver_via_k8s_api=True, durable=False) hook = self._make_k8s_hook() hook._kubernetes_driver_pod = "spark-abc-driver" hook._connection = {"namespace": "mynamespace"} diff --git a/task-sdk/docs/resumable-job-mixin.rst b/task-sdk/docs/resumable-job-mixin.rst index e78a2e5ce1aea..345fc4c4494fb 100644 --- a/task-sdk/docs/resumable-job-mixin.rst +++ b/task-sdk/docs/resumable-job-mixin.rst @@ -120,7 +120,7 @@ Example from pydantic import JsonValue - class MyBatchOperator(BaseOperator, ResumableJobMixin): + class MyBatchOperator(ResumableJobMixin, BaseOperator): external_id_key = "batch_job_id" @@ -145,6 +145,27 @@ Example def get_job_result(self, external_id: JsonValue, context): return None +.. _sdk-resumable-job-mixin-resume-on-retry: + +Disabling crash recovery per task +---------------------------------- + +Set ``durable=False`` on a task to opt out of crash recovery for that specific instance. +The operator will always submit a fresh job on retry, with no ``task_state_store`` interaction: + +.. code-block:: python + + run_spark = MyBatchOperator( + task_id="run_spark", + durable=False, + ) + +This is useful when the external job is not idempotent and you want Airflow to always submit a +clean run rather than reconnect to a prior submission. + +The default is ``True``. ``durable`` is owned by the mixin — operators do not need to +redeclare it. ``default_args`` injection and ``.partial()`` work automatically. + .. _sdk-resumable-job-mixin-external-id-key: External ID key diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py index 7066d10cced4e..27533dbe8409c 100644 --- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py +++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py @@ -21,6 +21,7 @@ from opentelemetry import trace from airflow.sdk._shared.observability.metrics import stats +from airflow.sdk.bases.operator import BaseOperatorMeta if TYPE_CHECKING: from pydantic import JsonValue @@ -90,6 +91,14 @@ def get_job_result(self, external_id: JsonValue, context: Context) -> Any: # Renaming this on a deployed operator breaks in-flight retries — the old key is already stored. external_id_key: str = "remote_job_id" + # The mixin is not a BaseOperator subclass, but _apply_defaults is only ever called on concrete + # operators that are BaseOperator subclasses. That is a runtime MRO guarantee not visible in the static + # type signature here and hence we need the type ignore. + @BaseOperatorMeta._apply_defaults # type: ignore[type-var] + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.durable = durable + def execute_resumable(self, context: Context) -> Any: """ Core of the resumable execution logic. Call this from execute() when reconnection is supported. @@ -107,6 +116,11 @@ def execute_resumable(self, context: Context) -> Any: Closing this window would require atomic "submit + persist", which is not possible across an external system boundary. """ + if not self.durable: + external_id = self.submit_job(context) + self.poll_until_complete(external_id, context) + return self.get_job_result(external_id, context) + stats_tags = {"operator": type(self).__name__} # The task is team-scoped in multi-team deployments; surface team_name on the # resumable_job metrics via the running task instance's stats tags (omitted when @@ -114,6 +128,7 @@ def execute_resumable(self, context: Context) -> Any: ti = context.get("ti") if ti is not None and (team_name := ti.stats_tags.get("team_name")): stats_tags["team_name"] = team_name + reconnect_to: Any = None already_succeeded_id: Any = None diff --git a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py index e186cc19a367c..8e962583e992d 100644 --- a/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py +++ b/task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py @@ -206,6 +206,35 @@ def get_job_result(self, external_id, context) -> str: assert task_state._store == {} +class TestResumeOnRetryDisabled: + def test_submits_and_polls_without_task_store_interaction(self): + op = ConcreteResumableOperator(task_id="test_task", durable=False) + task_store = FakeTaskState() + op.execute_resumable(make_context(task_store)) + + assert op.submitted_ids == ["job-001"] + assert op.polled_ids == ["job-001"] + assert task_store._store == {}, "task_store must not be written when durable=False" + + def test_does_not_reconnect_when_prior_id_exists(self): + op = ConcreteResumableOperator(task_id="test_task", durable=False) + op._status_map["job-001"] = "RUNNING" + task_store = FakeTaskState({"test_job_id": "job-001"}) + + op.execute_resumable(make_context(task_store)) + + assert op.submitted_ids == ["job-001"], "should submit fresh even with a prior ID stored" + + def test_returns_result(self): + op = ConcreteResumableOperator(task_id="test_task", durable=False) + result = op.execute_resumable(make_context(FakeTaskState())) + assert result == "result-of-job-001" + + def test_default_is_true(self): + op = ConcreteResumableOperator(task_id="test_task") + assert op.durable is True + + class TestExternalIdKey: def test_custom_key_used_for_storage_and_retrieval(self): class CustomKeyOp(ConcreteResumableOperator):