From 988b548185a0abff9ed7fe145dd91b4359cdfadc Mon Sep 17 00:00:00 2001 From: justinpakzad <114518232+justinpakzad@users.noreply.github.com> Date: Tue, 16 Jun 2026 20:33:14 -0400 Subject: [PATCH] Add bulk XCom deletion to Execution API Add a DELETE /xcoms/{dag_id}/{run_id} endpoint that deletes all XCom entries for a Dag run, with optional task_id, key, and map_index filters. Returns the count of deleted rows. --- .../api_fastapi/execution_api/routes/xcoms.py | 40 +++++++++- .../execution_api/versions/v2026_06_30.py | 10 +++ .../src/airflow/jobs/triggerer_job_runner.py | 5 ++ .../execution_api/versions/head/test_xcoms.py | 69 +++++++++++++++- .../unit/dag_processing/test_processor.py | 2 + .../tests/unit/jobs/test_triggerer_job.py | 1 + task-sdk/src/airflow/sdk/api/client.py | 22 ++++++ task-sdk/src/airflow/sdk/bases/xcom.py | 34 ++++++++ .../src/airflow/sdk/execution_time/comms.py | 16 ++++ .../sdk/execution_time/request_handlers.py | 7 ++ .../sdk/execution_time/schema/schema.json | 79 +++++++++++++++++++ .../airflow/sdk/execution_time/supervisor.py | 4 + .../execution_time/test_supervisor.py | 35 ++++++++ 13 files changed, 318 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index a7592e4f0f2ec..68aa9ca1e4d5e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -41,9 +41,9 @@ async def has_xcom_access( dag_id: str, run_id: str, - task_id: str, - xcom_key: Annotated[str, Path(alias="key", min_length=1)], request: Request, + task_id: str | None = None, + key: str | None = None, token=CurrentTIToken, ) -> bool: """Check if the task has access to the XCom.""" @@ -53,7 +53,7 @@ async def has_xcom_access( "Checking %s XCom access for xcom from TaskInstance with key '%s' to XCom '%s'", "write" if write else "read", token.id, - xcom_key, + key, ) # The current version of Airflow does not support true @@ -444,3 +444,37 @@ def delete_xcom( session.execute(query) session.commit() return {"message": f"XCom with key: {key} successfully deleted."} + + +@router.delete( + "/{dag_id}/{run_id}", + description="Bulk delete Xcom values.", +) +def bulk_delete_xcoms( + session: SessionDep, + dag_id: str, + run_id: str, + task_id: Annotated[str | None, Query()] = None, + key: Annotated[str | None, Query()] = None, + map_index: Annotated[int | None, Query()] = None, +): + """Bulk delete Xcom values.""" + query = delete(XComModel).where( + XComModel.dag_id == dag_id, + XComModel.run_id == run_id, + ) + + if task_id is not None: + query = query.where(XComModel.task_id == task_id) + + if key is not None: + query = query.where(XComModel.key == key) + + if map_index is not None: + query = query.where(XComModel.map_index == map_index) + + result = session.execute(query) + count = getattr(result, "rowcount", 0) + session.commit() + + return {"count": count} diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index cbd801c0a9b0b..9c3a522db3878 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -127,3 +127,13 @@ class AddTaskAndAssetStateStoreEndpoints(VersionChange): endpoint("/store/asset/by-uri/value", ["DELETE"]).didnt_exist, endpoint("/store/asset/by-uri/clear", ["DELETE"]).didnt_exist, ) + + +class AddXcomBulkDeleteEndpoint(VersionChange): + """Add XCom bulk delete endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + endpoint("xcoms/{dag_id}/{run_id}", ["DELETE"]).didnt_exist, + ) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 97fd669204456..8e583953d0a8a 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -59,6 +59,7 @@ from airflow.observability.metrics import stats_utils from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( + BulkDeleteXCom, CommsDecoder, ConnectionResult, DagRunStateResult, @@ -90,6 +91,7 @@ _RequestFrame, ) from airflow.sdk.execution_time.request_handlers import ( + handle_bulk_delete_xcom, handle_delete_variable, handle_delete_xcom, handle_get_connection, @@ -370,6 +372,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | GetVariableKeys | PutVariable | DeleteXCom + | BulkDeleteXCom | GetXCom | SetXCom | GetTICount @@ -592,6 +595,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp, dump_opts = handle_put_variable(self.client, msg) elif isinstance(msg, DeleteXCom): resp, dump_opts = handle_delete_xcom(self.client, msg) + elif isinstance(msg, BulkDeleteXCom): + resp, dump_opts = handle_bulk_delete_xcom(self.client, msg) elif isinstance(msg, GetXCom): resp, dump_opts = handle_get_xcom(self.client, msg) elif isinstance(msg, SetXCom): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 899435f717a9f..54c2a267626fa 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -57,11 +57,13 @@ async def _( request: Request, dag_id: str = Path(), run_id: str = Path(), - task_id: str = Path(), - xcom_key: str = Path(alias="key"), + task_id: str | None = None, + key: str | None = None, token=CurrentTIToken, ): - await has_xcom_access(dag_id, run_id, task_id, xcom_key, request, token) + await has_xcom_access( + dag_id=dag_id, run_id=run_id, request=request, task_id=task_id, key=key, token=token + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ @@ -609,3 +611,64 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session): ) ).first() assert xcom_ti is not None + + @pytest.mark.parametrize( + ("task_id", "key", "expected_remaining", "expected_deleted"), + [ + pytest.param(None, None, 0, 4, id="all_xcoms_for_run"), + pytest.param("t1", None, 2, 2, id="all_keys_for_task"), + pytest.param(None, "xcom_3", 3, 1, id="specific_key_all_tasks"), + ], + ) + def test_xcom_bulk_delete_endpoint( + self, client, dag_maker, session, task_id, key, expected_remaining, expected_deleted + ): + """Test XCom bulk deletions.""" + + with dag_maker(dag_id="dag"): + EmptyOperator(task_id="t1") + EmptyOperator(task_id="t2") + + dag_run = dag_maker.create_dagrun(run_id="test") + + ti = dag_run.get_task_instance("t1") + ti2 = dag_run.get_task_instance("t2") + + ti.xcom_push(key="xcom_1", value='"value1"', session=session) + ti.xcom_push(key="xcom_2", value='"value2"', session=session) + + ti2.xcom_push(key="xcom_1", value='"value1"', session=session) + ti2.xcom_push(key="xcom_3", value='"value3"', session=session) + session.commit() + + params = {} + if task_id is not None: + params["task_id"] = task_id + if key is not None: + params["key"] = key + response = client.delete(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}", params=params) + + assert response.status_code == 200 + assert response.json() == {"count": expected_deleted} + + xcoms = session.scalars( + select(XComModel).where(XComModel.dag_id == ti.dag_id, XComModel.run_id == ti.run_id) + ).all() + assert len(xcoms) == expected_remaining + + if task_id == "t1" and key is None: + assert not any(xcom.task_id == "t1" for xcom in xcoms) + assert all(xcom.task_id == "t2" for xcom in xcoms) + + remaining_keys = {xcom.key for xcom in xcoms} + assert remaining_keys == {"xcom_1", "xcom_3"} + + elif task_id is None and key == "xcom_3": + assert not any(xcom.key == "xcom_3" for xcom in xcoms) + assert all(xcom.key != "xcom_3" for xcom in xcoms) + + remaining_tasks = {xcom.task_id for xcom in xcoms} + assert remaining_tasks == {"t1", "t2"} + + remaining_keys = {xcom.key for xcom in xcoms} + assert remaining_keys == {"xcom_1", "xcom_2"} diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 2e3d6940cc7eb..c357ff61224a3 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1947,6 +1947,7 @@ def get_type_names(union_type): "AwaitInputTask", "DeferTask", "DeleteXCom", + "BulkDeleteXCom", "GetAssetByName", "GetAssetByUri", "GetAssetsByAlias", @@ -2009,6 +2010,7 @@ def get_type_names(union_type): # AIP-103 task/asset store results — worker-only responses to the above messages. "TaskStateStoreResult", "AssetStateStoreResult", + "XComDeleteCountResult", } supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 923779df0034c..970338c771ea0 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -2117,6 +2117,7 @@ def get_type_names(union_type): "CreateHITLDetailPayload", "PrevSuccessfulDagRunResult", "XComCountResponse", + "XComDeleteCountResult", "XComSequenceIndexResult", "XComSequenceSliceResult", "PreviousDagRunResult", diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 66496d5cac911..0146f065396cf 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -105,6 +105,7 @@ TICount, UpdateHITLDetail, XComCountResponse, + XComDeleteCountResult, ) if TYPE_CHECKING: @@ -656,6 +657,27 @@ def delete( # decouple from the server response string return OKResponse(ok=True) + def delete_all( + self, + dag_id: str, + run_id: str, + task_id: str | None = None, + key: str | None = None, + map_index: int | None = None, + ) -> XComDeleteCountResult: + """Bulk delete XCom values via the API server.""" + params: dict[str, str | int] = {} + + if map_index is not None and map_index >= 0: + params["map_index"] = map_index + if task_id is not None: + params["task_id"] = task_id + if key is not None: + params["key"] = key + + resp = self.client.delete(url=f"xcoms/{dag_id}/{run_id}", params=params) + return XComDeleteCountResult(count=resp.json()["count"]) + def get_sequence_item( self, dag_id: str, diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 9667a608a1d0b..4b00b51433724 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -23,6 +23,7 @@ import structlog from airflow.sdk.execution_time.comms import ( + BulkDeleteXCom, DeleteXCom, GetXCom, GetXComSequenceSlice, @@ -381,3 +382,36 @@ def delete( map_index=map_index, ), ) + + @classmethod + def delete_all( + cls, + dag_id: str, + run_id: str, + task_id: str | None = None, + key: str | None = None, + map_index: int | None = None, + ) -> None: + """ + Bulk delete XCom entries, optionally filtered by task_id, key, or map_index. + + :param dag_id: Dag ID. + :param run_id: Dag run ID for the task. + :param task_id: Optional task ID filter. If provided, only XComs from this task + will be deleted. Pass *None* (default) to delete across all tasks. + :param key: Optional key filter. If provided, only XComs with this key + will be deleted. Pass *None* (default) to delete all keys. + :param map_index: Optional map index filter. If provided, only XComs with this + map index will be deleted. Pass *None* (default) to delete all map indexes. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send( + BulkDeleteXCom( + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + key=key, + map_index=map_index, + ), + ) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index a042d7a1e80b8..2c018a9c40303 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -513,6 +513,11 @@ class XComCountResponse(BaseModel): type: Literal["XComCountResponse"] = "XComCountResponse" +class XComDeleteCountResult(BaseModel): + count: int + type: Literal["XComDeleteCountResult"] = "XComDeleteCountResult" + + class XComSequenceIndexResult(BaseModel): root: JsonValue type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult" @@ -792,6 +797,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: | VariableResult | VariableKeysResult | XComCountResponse + | XComDeleteCountResult | XComResult | XComSequenceIndexResult | XComSequenceSliceResult @@ -921,6 +927,15 @@ class DeleteXCom(BaseModel): type: Literal["DeleteXCom"] = "DeleteXCom" +class BulkDeleteXCom(BaseModel): + dag_id: str + run_id: str + task_id: str | None = None + key: str | None = None + map_index: int | None = None + type: Literal["BulkDeleteXCom"] = "BulkDeleteXCom" + + class GetTaskStateStore(BaseModel): ti_id: UUID key: str @@ -1206,6 +1221,7 @@ class GetDag(BaseModel): | DeleteAssetStateStoreByUri | DeleteTaskStateStore | DeleteXCom + | BulkDeleteXCom | GetAssetByName | GetAssetByUri | GetAssetsByAlias diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index 959be43fe93cf..e5168d00ea416 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -40,6 +40,7 @@ XComSequenceSliceResponse, ) from airflow.sdk.execution_time.comms import ( + BulkDeleteXCom, ConnectionResult, DagRunStateResult, DeleteVariable, @@ -188,6 +189,12 @@ def handle_delete_xcom(client: Client, msg: DeleteXCom) -> tuple[BaseModel | Non return None, {} +def handle_bulk_delete_xcom(client: Client, msg: BulkDeleteXCom) -> tuple[BaseModel | None, dict[str, bool]]: + """Bulk delete XCom values.""" + resp = client.xcoms.delete_all(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) + return resp, {} + + def handle_get_dr_count(client: Client, msg: GetDRCount) -> tuple[BaseModel | None, dict[str, bool]]: """Fetch dag run counts.""" resp = client.dag_runs.get_count( diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 4807d9f53b353..ff1c959d16da0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -416,6 +416,66 @@ "title": "AwaitInputTask", "type": "object" }, + "BulkDeleteXCom": { + "properties": { + "dag_id": { + "title": "Dag Id", + "type": "string" + }, + "run_id": { + "title": "Run Id", + "type": "string" + }, + "task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Task Id" + }, + "key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Key" + }, + "map_index": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Map Index" + }, + "type": { + "const": "BulkDeleteXCom", + "default": "BulkDeleteXCom", + "title": "Type", + "type": "string" + } + }, + "required": [ + "dag_id", + "run_id" + ], + "title": "BulkDeleteXCom", + "type": "object" + }, "BundleInfo": { "description": "Schema for telling task which bundle to run with.", "properties": { @@ -4270,6 +4330,25 @@ "title": "XComCountResponse", "type": "object" }, + "XComDeleteCountResult": { + "properties": { + "count": { + "title": "Count", + "type": "integer" + }, + "type": { + "const": "XComDeleteCountResult", + "default": "XComDeleteCountResult", + "title": "Type", + "type": "string" + } + }, + "required": [ + "count" + ], + "title": "XComDeleteCountResult", + "type": "object" + }, "XComResult": { "description": "Response to ReadXCom request.", "properties": { diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d4617efbbae64..4ab4b4878cdd3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -64,6 +64,7 @@ AssetResult, AssetStateStoreResult, AwaitInputTask, + BulkDeleteXCom, ClearAssetStateStoreByName, ClearAssetStateStoreByUri, ClearTaskStateStore, @@ -133,6 +134,7 @@ ) from airflow.sdk.execution_time.coordinator import get_coordinator_manager from airflow.sdk.execution_time.request_handlers import ( + handle_bulk_delete_xcom, handle_delete_variable, handle_delete_xcom, handle_get_connection, @@ -1726,6 +1728,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: resp, dump_opts = handle_set_xcom(self.client, msg) elif isinstance(msg, DeleteXCom): resp, dump_opts = handle_delete_xcom(self.client, msg) + elif isinstance(msg, BulkDeleteXCom): + resp, dump_opts = handle_bulk_delete_xcom(self.client, msg) elif isinstance(msg, PutVariable): resp, dump_opts = handle_put_variable(self.client, msg) elif isinstance(msg, SetRenderedFields): diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4893258195446..9461d57c3b986 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -77,6 +77,7 @@ AssetsByAliasResult, AssetStateStoreResult, AwaitInputTask, + BulkDeleteXCom, ClearAssetStateStoreByName, ClearAssetStateStoreByUri, ClearTaskStateStore, @@ -154,6 +155,7 @@ VariableKeysResult, VariableResult, XComCountResponse, + XComDeleteCountResult, XComResult, XComSequenceIndexResult, XComSequenceSliceResult, @@ -1871,6 +1873,39 @@ class RequestTestCase: ), test_id="delete_xcom", ), + RequestTestCase( + message=BulkDeleteXCom( + dag_id="test_dag", + run_id="test_run", + ), + client_mock=ClientMock( + method_path="xcoms.delete_all", + args=("test_dag", "test_run", None, None, None), + response=XComDeleteCountResult(count=5), + ), + expected_body={"count": 5, "type": "XComDeleteCountResult"}, + test_id="bulk_delete_xcoms", + ), + RequestTestCase( + message=BulkDeleteXCom(dag_id="test_dag", run_id="test_run", task_id="t1"), + client_mock=ClientMock( + method_path="xcoms.delete_all", + args=("test_dag", "test_run", "t1", None, None), + response=XComDeleteCountResult(count=3), + ), + expected_body={"count": 3, "type": "XComDeleteCountResult"}, + test_id="bulk_delete_xcoms_with_task_id", + ), + RequestTestCase( + message=BulkDeleteXCom(dag_id="test_dag", run_id="test_run", task_id="t1", map_index=0), + client_mock=ClientMock( + method_path="xcoms.delete_all", + args=("test_dag", "test_run", "t1", None, 0), + response=XComDeleteCountResult(count=1), + ), + expected_body={"count": 1, "type": "XComDeleteCountResult"}, + test_id="bulk_delete_xcoms_with_map_index", + ), RequestTestCase( message=RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task"