diff --git a/airflow-core/newsfragments/64294.feature.rst b/airflow-core/newsfragments/64294.feature.rst new file mode 100644 index 0000000000000..7a6f857dacb82 --- /dev/null +++ b/airflow-core/newsfragments/64294.feature.rst @@ -0,0 +1 @@ +Add a new ``max_new_dagruns_per_loop_to_schedule`` configuration to control how many new dagruns are scheduled each scheduler iteration diff --git a/airflow-core/newsfragments/64294.significant.rst b/airflow-core/newsfragments/64294.significant.rst new file mode 100644 index 0000000000000..8a304373eb82d --- /dev/null +++ b/airflow-core/newsfragments/64294.significant.rst @@ -0,0 +1 @@ +``get_running_dag_runs_to_examine`` now returns a ``Sequence[DagRun]`` type instead of ``ScalarResult[DagRun]`` diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 03fd2794b5adc..d67c346438357 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2688,6 +2688,19 @@ scheduler: type: integer default: "20" see_also: ":ref:`scheduler:ha:tunables`" + max_new_dagruns_per_loop_to_schedule: + description: | + When set > 0, the scheduler runs a second query per loop that fetches up to this + many dagruns that have never been examined (``last_scheduling_decision IS NULL``), + in addition to the ``max_dagruns_per_loop_to_schedule`` already-examined ones. + + This prevents starvation of older dagruns when large batches of new dagruns are + created at once (for example, via ``TriggerDagRunOperator``). Note that the total + number of dagruns locked and scheduled per loop becomes the sum of both limits. + example: ~ + version_added: 3.3.0 + type: integer + default: "0" use_job_schedule: description: | Turn off scheduler use of cron intervals by setting this to ``False``. 
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 8f5784267f54f..13f245cbeba23 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -40,6 +40,7 @@ Index, Integer, PrimaryKeyConstraint, + SQLColumnExpression, String, Text, UniqueConstraint, @@ -57,7 +58,15 @@ from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column, relationship, synonym, validates +from sqlalchemy.orm import ( + Mapped, + declared_attr, + joinedload, + mapped_column, + relationship, + synonym, + validates, +) from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql.expression import false, select from sqlalchemy.sql.functions import coalesce @@ -319,6 +328,11 @@ class DagRun(Base, LoggingMixin): "max_dagruns_per_loop_to_schedule", fallback=20, ) + DEFAULT_NEW_DAGRUNS_TO_EXAMINE = airflow_conf.getint( + "scheduler", + "max_new_dagruns_per_loop_to_schedule", + fallback=0, + ) _ti_dag_versions = association_proxy("task_instances", "dag_version") _tih_dag_versions = association_proxy("task_instances_histories", "dag_version") @@ -623,40 +637,77 @@ def active_runs_of_dags( @classmethod @retry_db_transaction - def get_running_dag_runs_to_examine(cls, session: Session) -> ScalarResult[DagRun]: + def get_running_dag_runs_to_examine(cls, session: Session) -> Sequence[DagRun]: """ - Return the next DagRuns that the scheduler should attempt to schedule. + Return the next DagRuns (as a list) that the scheduler should attempt to schedule. - This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE" + This will return zero or more DagRuns that are row-level-locked with a "SELECT ... 
FOR UPDATE" query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. + With max_new_dagruns_per_loop_to_schedule > 0, this runs 2 queries: one for new dagruns (where + last_scheduling_decision is None) and one for old (already examined) dagruns. Cleared / requeued + dagruns will appear in the new dagruns query. :meta private: """ from airflow.models.backfill import BackfillDagRun from airflow.models.dag import DagModel - query = ( - select(cls) - .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql") - .where(cls.state == DagRunState.RUNNING) - .join(DagModel, DagModel.dag_id == cls.dag_id) - .join(BackfillDagRun, BackfillDagRun.dag_run_id == DagRun.id, isouter=True) - .where( - DagModel.is_paused == false(), - DagModel.is_stale == false(), - ) - .order_by( - nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session), - nulls_first(cast("ColumnElement[Any]", cls.last_scheduling_decision), session=session), - cls.run_after, + def _get_dagrun_query( + filters: list[ColumnElement[bool]], order_by: list[SQLColumnExpression[Any]], limit: int + ): + return ( + select(DagRun) + .with_hint(DagRun, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql") + .where(DagRun.state == DagRunState.RUNNING) + .join(DagModel, DagModel.dag_id == cls.dag_id) + .join(BackfillDagRun, BackfillDagRun.dag_run_id == DagRun.id, isouter=True) + .where(*filters) + .order_by(*order_by) + .limit(limit) ) - .limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE) + + filters = [ + DagRun.run_after <= func.now(), + DagModel.is_paused == false(), + DagModel.is_stale == false(), + ] + + order = [ + nulls_first(cast("ColumnElement[Any]", BackfillDagRun.sort_ordinal), session=session), + nulls_first(cast("ColumnElement[Any]", DagRun.last_scheduling_decision), session=session), + DagRun.run_after, + ] + + new_dagruns_to_examine = max(cls.DEFAULT_NEW_DAGRUNS_TO_EXAMINE, 0) + 
dagruns_to_examine = cls.DEFAULT_DAGRUNS_TO_EXAMINE + + old_filters = ( + [*filters, DagRun.last_scheduling_decision.is_not(None)] + if new_dagruns_to_examine > 0 + else filters ) + query = _get_dagrun_query(filters=old_filters, order_by=order, limit=dagruns_to_examine) - query = query.where(DagRun.run_after <= func.now()) + result: list[DagRun] = cast( + "list[DagRun]", + session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)).unique().all(), + ) + + if new_dagruns_to_examine > 0: + new_dagruns_query = _get_dagrun_query( + filters=[*filters, DagRun.last_scheduling_decision.is_(None)], + order_by=order, + limit=new_dagruns_to_examine, + ) + new_dagruns: Sequence[DagRun] = ( + session.scalars(with_row_locks(new_dagruns_query, of=cls, session=session, skip_locked=True)) + .unique() + .all() + ) + + result.extend(new_dagruns) - result = session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)).unique() return result @classmethod diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 1154a19617e3e..ebe866d5a53f0 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -61,9 +61,7 @@ from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown from airflow.sdk.definitions.callback import AsyncCallback -from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference, VariableInterval -from airflow.sdk.definitions.variable import Variable -from airflow.sdk.exceptions import AirflowRuntimeError +from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference from airflow.serialization.definitions.deadline import SerializedReferenceModels from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.settings import 
get_policy_plugin_manager @@ -103,6 +101,35 @@ def dagbag(): return DagBag(include_examples=True) +@pytest.fixture +def create_dagruns(): + def _create_dagruns( + dag_maker, + session, + last_scheduling_decision: datetime.datetime | None = None, + count: int = 20, + ): + dagrun = dag_maker.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + run_after=datetime.datetime(2024, 1, 1), + ) + dagrun.last_scheduling_decision = last_scheduling_decision + session.merge(dagrun) + for _ in range(count - 1): + dagrun = dag_maker.create_dagrun_after( + dagrun, + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + run_after=datetime.datetime(2024, 1, 1), + ) + + dagrun.last_scheduling_decision = last_scheduling_decision + session.merge(dagrun) + + return _create_dagruns + + @pytest.fixture def deadline_test_dag(session): """Fixture that creates and syncs a basic DAG with two tasks.""" @@ -996,6 +1023,93 @@ def test_wait_for_downstream(self, dag_maker, session, prev_ti_state, is_ti_sche schedulable_tis = [ti.task_id for ti in decision.schedulable_tis] assert (upstream.task_id in schedulable_tis) == is_ti_schedulable + @pytest.mark.parametrize( + "new_dagruns_to_examine", + [ + 0, + -1, + ], + ) + def test_get_running_dag_runs_ignores_new_dagruns_to_examine_when_smaller_than_0( + self, + session, + dag_maker, + create_dagruns, + monkeypatch, + new_dagruns_to_examine, + ): + monkeypatch.setattr( + DagRun, + "DEFAULT_NEW_DAGRUNS_TO_EXAMINE", + new_dagruns_to_examine, + ) + + with dag_maker( + dag_id="dummy_dag", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2024, 1, 1), + session=session, + ): + EmptyOperator(task_id="dummy_task") + + create_dagruns(dag_maker, session, None, 10) + + with dag_maker( + dag_id="dummy_dag2", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2024, 1, 1), + session=session, + ): + EmptyOperator(task_id="dummy_task2") + + create_dagruns(dag_maker, session, timezone.utcnow(), 20) + + 
session.flush() + + dagruns = list(DagRun.get_running_dag_runs_to_examine(session=session)) + + assert len([dagrun for dagrun in dagruns if dagrun.last_scheduling_decision is None]) == 10 + + assert len([dagrun for dagrun in dagruns if dagrun.last_scheduling_decision is not None]) == 10 + + def test_get_running_dag_runs_with_max_new_dagruns_to_examine( + self, session, dag_maker, create_dagruns, monkeypatch + ): + monkeypatch.setattr(DagRun, "DEFAULT_NEW_DAGRUNS_TO_EXAMINE", 10) + + with dag_maker( + dag_id="dummy_dag", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2024, 1, 1), + session=session, + ): + EmptyOperator(task_id="dummy_task") + + create_dagruns(dag_maker, session, None) + + with dag_maker( + dag_id="dummy_dag2", + schedule=datetime.timedelta(days=1), + start_date=datetime.datetime(2024, 1, 1), + session=session, + ): + EmptyOperator(task_id="dummy_task2") + + create_dagruns(dag_maker, session, timezone.utcnow()) + + session.flush() + + dagruns = list(DagRun.get_running_dag_runs_to_examine(session=session)) + + assert ( + len([dagrun for dagrun in dagruns if dagrun.last_scheduling_decision is None]) + == DagRun.DEFAULT_NEW_DAGRUNS_TO_EXAMINE + ) + assert ( + len([dagrun for dagrun in dagruns if dagrun.last_scheduling_decision is not None]) + == DagRun.DEFAULT_DAGRUNS_TO_EXAMINE + ) + @pytest.mark.parametrize("state", [DagRunState.QUEUED, DagRunState.RUNNING]) def test_next_dagruns_to_examine_only_unpaused(self, session, state, testing_dag_bundle): """ @@ -1034,9 +1148,10 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state, testing_dag if state == DagRunState.RUNNING: func = DagRun.get_running_dag_runs_to_examine + runs = func(session) else: func = DagRun.get_queued_dag_runs_to_set_running - runs = func(session).all() + runs = func(session).all() assert runs == [dr] @@ -1044,7 +1159,11 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state, testing_dag session.merge(orm_dag) session.commit() - 
runs = func(session).all() + if state == DagRunState.RUNNING: + runs = func(session) + else: + runs = func(session).all() + assert runs == [] @mock.patch("airflow._shared.observability.metrics.stats.timing") @@ -1126,7 +1245,7 @@ def test_emit_scheduling_delay(self, session, schedule, expected, testing_dag_bu session.flush() with mock.patch("airflow._shared.observability.metrics.stats.timing") as stats_mock: - dag_run.update_state(session=session) + dag_run.update_state(session) metric_name = f"dagrun.{dag.dag_id}.first_task_scheduling_delay" @@ -1215,7 +1334,7 @@ def test_emit_first_task_start_delay(self, session, queued_at_offset, expected, session.flush() with mock.patch("airflow._shared.observability.metrics.stats.timing") as stats_mock: - dag_run.update_state(session=session) + dag_run.update_state(session) start_delay_call = call("dagrun.first_task_start_delay", mock.ANY, tags=expected_stat_tags) if expected: @@ -1328,28 +1447,17 @@ def test_dag_run_dag_versions_with_null_created_dag_version(self, dag_maker, ses assert isinstance(dag_run.dag_versions, list) assert len(dag_run.dag_versions) == 0 - @pytest.mark.parametrize( - "interval", - [ - datetime.timedelta(hours=1), - VariableInterval("my_key"), - ], - ) - @mock.patch.object(Variable, "get") @mock.patch.object(Deadline, "prune_deadlines") - def test_dagrun_success_deadline(self, _, mock_get, interval, session, deadline_test_dag): + def test_dagrun_success_deadline(self, _, session, deadline_test_dag): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dag" future_date = datetime.datetime.now() + datetime.timedelta(days=365) - # First value used during resolution - mock_get.return_value = "5" - scheduler_dag = deadline_test_dag( deadline=DeadlineAlert( reference=DeadlineReference.FIXED_DATETIME(future_date), - interval=interval, + interval=datetime.timedelta(hours=1), callback=AsyncCallback(empty_callback_for_deadline), ), on_success_callback=on_success_callable, @@ -1454,73 
+1562,6 @@ def test_dagrun_success_handles_empty_deadline_list(self, mock_prune, dag_maker, mock_prune.assert_not_called() assert dag_run.state == DagRunState.SUCCESS - @mock.patch.object(Variable, "get") - @mock.patch.object(Deadline, "prune_deadlines") - def test_dagrun_deadline_variable_interval_stable(self, _, mock_get, session, deadline_test_dag): - future_date = datetime.datetime.now() + datetime.timedelta(days=365) - - # First value used during resolution. - mock_get.return_value = "60" - - scheduler_dag = deadline_test_dag( - deadline=DeadlineAlert( - reference=DeadlineReference.FIXED_DATETIME(future_date), - interval=VariableInterval("my_key"), - callback=AsyncCallback(empty_callback_for_deadline), - ), - ) - - dag_run = self.create_dag_run( - dag=scheduler_dag, - task_states={"task_1": TaskInstanceState.SUCCESS, "task_2": TaskInstanceState.SUCCESS}, - session=session, - ) - dag_run.dag = scheduler_dag - - # First update resolve interval to "5". - dag_run.update_state(session=session) - - deadline = session.execute(select(Deadline)).scalars().one_or_none() - first_deadline_time = deadline.deadline_time - - # Change Variable value after resolution. - mock_get.return_value = "120" - - # Run again (This should not change existing deadline). 
- dag_run.update_state(session=session) - - deadline = session.execute(select(Deadline)).scalars().one_or_none() - assert deadline.deadline_time == first_deadline_time - - @mock.patch.object(Deadline, "prune_deadlines") - def test_dagrun_deadline_variable_interval_missing_variable_fails(self, _, session, deadline_test_dag): - - mock_err = mock.Mock() - mock_err.error.value = "MISSING_DEADLINE" - mock_err.detail = "missing deadline" - - with mock.patch.object( - Variable, - "get", - side_effect=AirflowRuntimeError(mock_err), - ): - future_date = datetime.datetime.now() + datetime.timedelta(days=365) - - scheduler_dag = deadline_test_dag( - deadline=DeadlineAlert( - reference=DeadlineReference.FIXED_DATETIME(future_date), - interval=VariableInterval("missing_key"), - callback=AsyncCallback(empty_callback_for_deadline), - ), - ) - - with pytest.raises(ValueError, match="not found"): - self.create_dag_run( - dag=scheduler_dag, - task_states={"task_1": TaskInstanceState.SUCCESS}, - session=session, - ) - @pytest.mark.parametrize( ("run_type", "expected_tis"), @@ -3831,23 +3872,3 @@ def test_emit_dagrun_span_with_none_or_empty_carrier(self, dag_maker, session, c assert spans[0].name == f"dag_run.{dr.dag_id}" else: assert len(spans) == 0 - - @pytest.mark.db_test - def test_context_carrier_includes_detail_level_from_conf(self, dag_maker): - """DagRun created with TASK_SPAN_DETAIL_LEVEL_KEY in conf should encode the level in trace state.""" - from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator - - from airflow._shared.observability.traces import ( - TASK_SPAN_DETAIL_LEVEL_KEY, - get_task_span_detail_level, - ) - - with dag_maker("test_tracing_detail_level"): - EmptyOperator(task_id="t1") - dr = dag_maker.create_dagrun(conf={TASK_SPAN_DETAIL_LEVEL_KEY: 2}) - - ctx = TraceContextTextMapPropagator().extract(dr.context_carrier) - from opentelemetry import trace - - span = trace.get_current_span(ctx) - assert get_task_span_detail_level(span) == 
2