Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,7 @@ rebase
Rebasing
Recency
recurse
redeclare
redelivery
Redhat
redis
Expand Down
6 changes: 3 additions & 3 deletions providers/apache/spark/docs/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full
conn_id="spark_k8s",
deploy_mode="cluster",
track_driver_via_k8s_api=True,
reconnect_on_retry=True,
durable=True,
)

**Requirements**
Expand All @@ -246,9 +246,9 @@ Python Kubernetes client rather than holding ``spark-submit`` open for the full
conflicts with the flag and a ``ValueError`` will be raised at task start.
* The Airflow worker must be able to reach the Kubernetes API server and have permission to
read and delete pods in the driver's namespace; otherwise pod tracking and cleanup will fail.
* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the driver pod name is
* Set ``durable=True`` (the default) to enable crash recovery: the driver pod name is
persisted to task state before polling begins, so a worker crash and retry reconnects to the
existing pod instead of submitting a fresh one. Set ``reconnect_on_retry=False`` to always
existing pod instead of submitting a fresh one. Set ``durable=False`` to always
submit a fresh driver on retry.
* Pod completion is detected from ``pod.status.phase``. If your driver pods have sidecar
containers (e.g. Istio injection enabled for the driver namespace), the pod phase may not
Expand Down

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import is from airflow.sdk import ResumableJobMixin but the mixin itself lives in task-sdk/src/airflow/sdk/bases/resumablejobmixin.py -- is that import right (as in using the canonical location)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from airflow.sdk import ResumableJobMixin is the right one to use, its exported and documented

Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
# under the License.
from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast

import requests
from tenacity import retry, stop_after_attempt, wait_fixed

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.apache.spark.hooks.spark_submit import _K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
Expand All @@ -46,6 +48,11 @@ class ResumableJobMixin: # type: ignore[no-redef]

external_id_key: str = "remote_job_id"

def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
# Accept durable so the kwarg doesn't leak to BaseOperator; crash recovery is a no-op here.
super().__init__(**kwargs)
self.durable = durable

def execute_resumable(self, context):
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
Expand Down Expand Up @@ -139,6 +146,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
omitted, Kerberos-enabled Spark connections with both ``keytab`` and
``principal`` configured use ``requests-kerberos`` automatically.
Defaults to ``None`` (no auth for non-Kerberos connections).
:param durable: When ``True`` (the default), the external job ID is persisted to task state
store before polling begins so that a worker crash and retry reconnects to the existing job
instead of submitting a fresh one. Set to ``False`` to always submit a new job on retry.
"""

# Generic key used across all Spark deployment modes (standalone driver ID,
Expand Down Expand Up @@ -203,7 +213,6 @@ def __init__(
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
reconnect_on_retry: bool = True,
track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
Expand All @@ -213,8 +222,16 @@ def __init__(
openlineage_inject_transport_info: bool = conf.getboolean(
"openlineage", "spark_inject_transport_info", fallback=False
),
reconnect_on_retry: bool | None = None,
**kwargs: Any,
) -> None:
if reconnect_on_retry is not None:
warnings.warn(
"reconnect_on_retry is renamed to durable.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
kwargs.setdefault("durable", reconnect_on_retry)
super().__init__(**kwargs)
self.application = application
self.conf = conf
Expand Down Expand Up @@ -252,7 +269,6 @@ def __init__(
self._yarn_track_via_rm_api = yarn_track_via_rm_api
self._yarn_rm_auth = yarn_rm_auth

self.reconnect_on_retry = reconnect_on_retry
self._track_driver_via_k8s_api = track_driver_via_k8s_api
self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self._openlineage_inject_transport_info = openlineage_inject_transport_info
Expand All @@ -272,33 +288,18 @@ def execute(self, context: Context) -> None:
if self._track_driver_via_k8s_api:
hook._validate_track_driver_via_k8s_api_config()
if hook._should_track_driver_status:
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence.
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
return self.execute_resumable(context)
if hook._should_track_driver_via_k8s_api():
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
return self.execute_resumable(context)
if hook._is_yarn_cluster_mode:
if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
if self.durable and not hook._yarn_track_via_rm_api:
raise ValueError(
"YARN cluster mode with reconnect_on_retry=True requires yarn_track_via_rm_api=True. "
"YARN cluster mode with durable=True requires yarn_track_via_rm_api=True. "
"The RM REST API is needed to check application status on retry."
)
if hook._yarn_track_via_rm_api:
hook._validate_yarn_track_via_rm_api_config()
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state_store persistence.
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
return self.execute_resumable(context)
hook.submit(self.application)

def submit_job(self, context: Context) -> str | None:
Expand All @@ -319,7 +320,7 @@ def submit_job(self, context: Context) -> str | None:
raise ValueError(
"spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts"
"with the need to exit spark-submit immediately to persist the application ID for tracking. "
"Either remove the explicit conf or set reconnect_on_retry=False."
"Either remove the explicit conf or set durable=False."
)
self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
self._hook.submit(self.application)
Expand Down Expand Up @@ -445,7 +446,7 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
# Cache only when the pod actually reached Succeeded, the 404/vanished path
# returns None for cases like: pod deleted by on_kill or garbage collected after failure)
# and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission.
if terminal_phase == "Succeeded" and self.reconnect_on_retry:
if terminal_phase == "Succeeded" and self.durable:
if (task_state_store := context.get("task_state_store")) is not None:
task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded")
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

import logging
import warnings
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
Expand Down Expand Up @@ -590,16 +592,25 @@ def test_submits_fresh_when_task_state_store_unavailable(self):
operator._hook.submit.assert_called_once_with("test.jar")
assert polled == ["driver-001"]

def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
operator = self._make_operator(reconnect_on_retry=False)
def test_reconnect_on_retry_deprecated_alias(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
operator = self._make_operator(reconnect_on_retry=False)
assert len(w) == 1
assert issubclass(w[0].category, AirflowProviderDeprecationWarning)
assert "reconnect_on_retry" in str(w[0].message)
assert operator.durable is False

def test_durable_false_submits_fresh_and_polls(self):
operator = self._make_operator(durable=False)
operator._hook = self._make_hook(should_track=True)
operator._hook.submit.return_value = "driver-new"
task_store = FakeTaskStateStore({"spark_job_id": "driver-old"})
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)

operator.execute(context={"task_state_store": task_store})
# reconnect_on_retry=False: ignores prior driver ID, submits fresh, but still polls
# durable=False: ignores prior driver ID, submits fresh, but still polls
operator._hook.submit.assert_called_once_with("test.jar")
assert polled == ["driver-new"]

Expand Down Expand Up @@ -863,8 +874,8 @@ def test_on_kill_sends_authenticated_kill_to_yarn_rm(self):
hook._kill_yarn_application.assert_called_once_with("application_1234_0001")

def test_yarn_cluster_reconnect_without_rm_api_raises(self):
"""reconnect_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume."""
operator = self._make_operator(reconnect_on_retry=True)
"""durable=True + yarn_track_via_rm_api=False must raise - RM API is required for resume."""
operator = self._make_operator(durable=True)
hook = self._make_hook(is_yarn_cluster=True)
hook._yarn_track_via_rm_api = False
operator._hook = hook
Expand All @@ -873,8 +884,8 @@ def test_yarn_cluster_reconnect_without_rm_api_raises(self):
operator.execute(context={})

def test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self):
"""reconnect_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling."""
operator = self._make_operator(reconnect_on_retry=False)
"""durable=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling."""
operator = self._make_operator(durable=False)
hook = self._make_hook(is_yarn_cluster=True)
hook._yarn_track_via_rm_api = False
operator._hook = hook
Expand Down Expand Up @@ -1028,6 +1039,7 @@ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self):
assert hook._kubernetes_driver_pod == "spark-abc-driver"
hook._poll_k8s_driver_via_api.assert_called_once()

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+")
def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):
operator = self._make_operator(track_driver_via_k8s_api=True)
hook = self._make_k8s_hook()
Expand All @@ -1039,8 +1051,9 @@ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):

assert task_store.get("k8s_driver_status") == "Succeeded"

@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="task_state_store requires Airflow 3.3+")
def test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self):
operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False)
operator = self._make_operator(track_driver_via_k8s_api=True, durable=False)
hook = self._make_k8s_hook()
hook._poll_k8s_driver_via_api.return_value = "Succeeded"
operator._hook = hook
Expand Down Expand Up @@ -1072,9 +1085,9 @@ def test_k8s_poll_until_complete_tolerates_absent_task_store(self):
not AIRFLOW_V_3_3_PLUS,
reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+",
)
def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self):
"""execute() with reconnect_on_retry=True stores the pod ID in task_store before polling."""
operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=True)
def test_k8s_execute_persists_pod_id_when_durable(self):
"""execute() with durable=True stores the pod ID in task_store before polling."""
operator = self._make_operator(track_driver_via_k8s_api=True, durable=True)
hook = self._make_k8s_hook()
hook._kubernetes_driver_pod = "spark-abc-driver"
hook._connection = {"namespace": "mynamespace"}
Expand All @@ -1095,9 +1108,9 @@ def track_poll(external_id, context):
not AIRFLOW_V_3_3_PLUS,
reason="ResumableJobMixin reconnect requires task_state, available in Airflow 3.3+",
)
def test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self):
"""execute() with reconnect_on_retry=False does not write spark_job_id to task_store."""
operator = self._make_operator(track_driver_via_k8s_api=True, reconnect_on_retry=False)
def test_k8s_execute_durable_false_does_not_persist_pod_id(self):
"""execute() with durable=False does not write spark_job_id to task_store."""
operator = self._make_operator(track_driver_via_k8s_api=True, durable=False)
hook = self._make_k8s_hook()
hook._kubernetes_driver_pod = "spark-abc-driver"
hook._connection = {"namespace": "mynamespace"}
Expand Down
23 changes: 22 additions & 1 deletion task-sdk/docs/resumable-job-mixin.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Example
from pydantic import JsonValue


class MyBatchOperator(BaseOperator, ResumableJobMixin):
class MyBatchOperator(ResumableJobMixin, BaseOperator):

external_id_key = "batch_job_id"

Expand All @@ -145,6 +145,27 @@ Example
def get_job_result(self, external_id: JsonValue, context):
return None

.. _sdk-resumable-job-mixin-resume-on-retry:

Disabling crash recovery per task
----------------------------------

Set ``durable=False`` on a task to opt out of crash recovery for that specific instance.
The operator will always submit a fresh job on retry, with no ``task_state_store`` interaction:

.. code-block:: python

run_spark = MyBatchOperator(
task_id="run_spark",
durable=False,
)

This is useful when the external job is not idempotent and you want Airflow to always submit a
clean run rather than reconnect to a prior submission.

The default is ``True``. ``durable`` is owned by the mixin — operators do not need to
redeclare it. ``default_args`` injection and ``.partial()`` work automatically.

.. _sdk-resumable-job-mixin-external-id-key:

External ID key
Expand Down
15 changes: 15 additions & 0 deletions task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from opentelemetry import trace

from airflow.sdk._shared.observability.metrics import stats
from airflow.sdk.bases.operator import BaseOperatorMeta

if TYPE_CHECKING:
from pydantic import JsonValue
Expand Down Expand Up @@ -90,6 +91,14 @@ def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
# Renaming this on a deployed operator breaks in-flight retries — the old key is already stored.
external_id_key: str = "remote_job_id"

# The mixin is not a BaseOperator subclass, but _apply_defaults is only ever called on concrete
# operators that are BaseOperator subclasses. That is a runtime MRO guarantee not visible in the static
# type signature here and hence we need the type ignore.
@BaseOperatorMeta._apply_defaults # type: ignore[type-var]
def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.durable = durable

def execute_resumable(self, context: Context) -> Any:
"""
Core of the resumable execution logic. Call this from execute() when reconnection is supported.
Expand All @@ -107,13 +116,19 @@ def execute_resumable(self, context: Context) -> Any:
Closing this window would require atomic "submit + persist", which is not possible across
an external system boundary.
"""
if not self.durable:
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)

stats_tags = {"operator": type(self).__name__}
# The task is team-scoped in multi-team deployments; surface team_name on the
# resumable_job metrics via the running task instance's stats tags (omitted when
# not multi-team or the task has no team).
ti = context.get("ti")
if ti is not None and (team_name := ti.stats_tags.get("team_name")):
stats_tags["team_name"] = team_name

reconnect_to: Any = None
already_succeeded_id: Any = None

Expand Down
29 changes: 29 additions & 0 deletions task-sdk/tests/task_sdk/bases/test_resumablejobmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ def get_job_result(self, external_id, context) -> str:
assert task_state._store == {}


class TestResumeOnRetryDisabled:
def test_submits_and_polls_without_task_store_interaction(self):
op = ConcreteResumableOperator(task_id="test_task", durable=False)
task_store = FakeTaskState()
op.execute_resumable(make_context(task_store))

assert op.submitted_ids == ["job-001"]
assert op.polled_ids == ["job-001"]
assert task_store._store == {}, "task_store must not be written when durable=False"

def test_does_not_reconnect_when_prior_id_exists(self):
op = ConcreteResumableOperator(task_id="test_task", durable=False)
op._status_map["job-001"] = "RUNNING"
task_store = FakeTaskState({"test_job_id": "job-001"})

op.execute_resumable(make_context(task_store))

assert op.submitted_ids == ["job-001"], "should submit fresh even with a prior ID stored"

def test_returns_result(self):
op = ConcreteResumableOperator(task_id="test_task", durable=False)
result = op.execute_resumable(make_context(FakeTaskState()))
assert result == "result-of-job-001"

def test_default_is_true(self):
op = ConcreteResumableOperator(task_id="test_task")
assert op.durable is True


class TestExternalIdKey:
def test_custom_key_used_for_storage_and_retrieval(self):
class CustomKeyOp(ConcreteResumableOperator):
Expand Down
Loading