Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
async def has_xcom_access(
dag_id: str,
run_id: str,
task_id: str,
xcom_key: Annotated[str, Path(alias="key", min_length=1)],
request: Request,
task_id: str | None = None,
key: str | None = None,
token=CurrentTIToken,
) -> bool:
"""Check if the task has access to the XCom."""
Expand All @@ -53,7 +53,7 @@ async def has_xcom_access(
"Checking %s XCom access for xcom from TaskInstance with key '%s' to XCom '%s'",
"write" if write else "read",
token.id,
xcom_key,
key,
)

# The current version of Airflow does not support true
Expand Down Expand Up @@ -444,3 +444,37 @@ def delete_xcom(
session.execute(query)
session.commit()
return {"message": f"XCom with key: {key} successfully deleted."}


@router.delete(
"/{dag_id}/{run_id}",
description="Bulk delete Xcom values.",
)
def bulk_delete_xcoms(
session: SessionDep,
dag_id: str,
run_id: str,
task_id: Annotated[str | None, Query()] = None,
key: Annotated[str | None, Query()] = None,
map_index: Annotated[int | None, Query()] = None,
):
"""Bulk delete Xcom values."""
query = delete(XComModel).where(
XComModel.dag_id == dag_id,
XComModel.run_id == run_id,
)

if task_id is not None:
query = query.where(XComModel.task_id == task_id)

if key is not None:
query = query.where(XComModel.key == key)

if map_index is not None:
query = query.where(XComModel.map_index == map_index)

result = session.execute(query)
count = getattr(result, "rowcount", 0)
session.commit()

return {"count": count}
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,13 @@ class AddTaskAndAssetStateStoreEndpoints(VersionChange):
endpoint("/store/asset/by-uri/value", ["DELETE"]).didnt_exist,
endpoint("/store/asset/by-uri/clear", ["DELETE"]).didnt_exist,
)


class AddXcomBulkDeleteEndpoint(VersionChange):
"""Add XCom bulk delete endpoint."""

description = __doc__

instructions_to_migrate_to_previous_version = (
endpoint("xcoms/{dag_id}/{run_id}", ["DELETE"]).didnt_exist,
)
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from airflow.observability.metrics import stats_utils
from airflow.sdk.api.datamodels._generated import HITLDetailResponse
from airflow.sdk.execution_time.comms import (
BulkDeleteXCom,
CommsDecoder,
ConnectionResult,
DagRunStateResult,
Expand Down Expand Up @@ -90,6 +91,7 @@
_RequestFrame,
)
from airflow.sdk.execution_time.request_handlers import (
handle_bulk_delete_xcom,
handle_delete_variable,
handle_delete_xcom,
handle_get_connection,
Expand Down Expand Up @@ -370,6 +372,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| GetVariableKeys
| PutVariable
| DeleteXCom
| BulkDeleteXCom
| GetXCom
| SetXCom
| GetTICount
Expand Down Expand Up @@ -592,6 +595,8 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp, dump_opts = handle_put_variable(self.client, msg)
elif isinstance(msg, DeleteXCom):
resp, dump_opts = handle_delete_xcom(self.client, msg)
elif isinstance(msg, BulkDeleteXCom):
resp, dump_opts = handle_bulk_delete_xcom(self.client, msg)
elif isinstance(msg, GetXCom):
Comment thread
justinpakzad marked this conversation as resolved.
resp, dump_opts = handle_get_xcom(self.client, msg)
elif isinstance(msg, SetXCom):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ async def _(
request: Request,
dag_id: str = Path(),
run_id: str = Path(),
task_id: str = Path(),
xcom_key: str = Path(alias="key"),
task_id: str | None = None,
key: str | None = None,
token=CurrentTIToken,
):
await has_xcom_access(dag_id, run_id, task_id, xcom_key, request, token)
await has_xcom_access(
dag_id=dag_id, run_id=run_id, request=request, task_id=task_id, key=key, token=token
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
Expand Down Expand Up @@ -609,3 +611,64 @@ def test_xcom_delete_endpoint(self, client, create_task_instance, session):
)
).first()
assert xcom_ti is not None

@pytest.mark.parametrize(
("task_id", "key", "expected_remaining", "expected_deleted"),
[
pytest.param(None, None, 0, 4, id="all_xcoms_for_run"),
pytest.param("t1", None, 2, 2, id="all_keys_for_task"),
pytest.param(None, "xcom_3", 3, 1, id="specific_key_all_tasks"),
],
)
def test_xcom_bulk_delete_endpoint(
self, client, dag_maker, session, task_id, key, expected_remaining, expected_deleted
):
"""Test XCom bulk deletions."""

with dag_maker(dag_id="dag"):
EmptyOperator(task_id="t1")
EmptyOperator(task_id="t2")

dag_run = dag_maker.create_dagrun(run_id="test")

ti = dag_run.get_task_instance("t1")
ti2 = dag_run.get_task_instance("t2")

ti.xcom_push(key="xcom_1", value='"value1"', session=session)
ti.xcom_push(key="xcom_2", value='"value2"', session=session)

ti2.xcom_push(key="xcom_1", value='"value1"', session=session)
ti2.xcom_push(key="xcom_3", value='"value3"', session=session)
session.commit()

params = {}
if task_id is not None:
params["task_id"] = task_id
if key is not None:
params["key"] = key
response = client.delete(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}", params=params)

assert response.status_code == 200
assert response.json() == {"count": expected_deleted}

xcoms = session.scalars(
select(XComModel).where(XComModel.dag_id == ti.dag_id, XComModel.run_id == ti.run_id)
).all()
assert len(xcoms) == expected_remaining

if task_id == "t1" and key is None:
assert not any(xcom.task_id == "t1" for xcom in xcoms)
assert all(xcom.task_id == "t2" for xcom in xcoms)

remaining_keys = {xcom.key for xcom in xcoms}
assert remaining_keys == {"xcom_1", "xcom_3"}

elif task_id is None and key == "xcom_3":
assert not any(xcom.key == "xcom_3" for xcom in xcoms)
assert all(xcom.key != "xcom_3" for xcom in xcoms)

remaining_tasks = {xcom.task_id for xcom in xcoms}
assert remaining_tasks == {"t1", "t2"}

remaining_keys = {xcom.key for xcom in xcoms}
assert remaining_keys == {"xcom_1", "xcom_2"}
2 changes: 2 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,7 @@ def get_type_names(union_type):
"AwaitInputTask",
"DeferTask",
"DeleteXCom",
"BulkDeleteXCom",
"GetAssetByName",
"GetAssetByUri",
"GetAssetsByAlias",
Expand Down Expand Up @@ -2009,6 +2010,7 @@ def get_type_names(union_type):
# AIP-103 task/asset store results — worker-only responses to the above messages.
"TaskStateStoreResult",
"AssetStateStoreResult",
"XComDeleteCountResult",
}

supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager
Expand Down
1 change: 1 addition & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2117,6 +2117,7 @@ def get_type_names(union_type):
"CreateHITLDetailPayload",
"PrevSuccessfulDagRunResult",
"XComCountResponse",
"XComDeleteCountResult",
"XComSequenceIndexResult",
"XComSequenceSliceResult",
"PreviousDagRunResult",
Expand Down
22 changes: 22 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
TICount,
UpdateHITLDetail,
XComCountResponse,
XComDeleteCountResult,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -656,6 +657,27 @@ def delete(
# decouple from the server response string
return OKResponse(ok=True)

def delete_all(
self,
dag_id: str,
run_id: str,
task_id: str | None = None,
key: str | None = None,
map_index: int | None = None,
) -> XComDeleteCountResult:
"""Bulk delete XCom values via the API server."""
params: dict[str, str | int] = {}

if map_index is not None and map_index >= 0:
params["map_index"] = map_index
if task_id is not None:
params["task_id"] = task_id
if key is not None:
params["key"] = key

resp = self.client.delete(url=f"xcoms/{dag_id}/{run_id}", params=params)
return XComDeleteCountResult(count=resp.json()["count"])

def get_sequence_item(
self,
dag_id: str,
Expand Down
34 changes: 34 additions & 0 deletions task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import structlog

from airflow.sdk.execution_time.comms import (
BulkDeleteXCom,
DeleteXCom,
GetXCom,
GetXComSequenceSlice,
Expand Down Expand Up @@ -381,3 +382,36 @@ def delete(
map_index=map_index,
),
)

@classmethod
def delete_all(
cls,
dag_id: str,
run_id: str,
task_id: str | None = None,
key: str | None = None,
map_index: int | None = None,
) -> None:
"""
Bulk delete XCom entries, optionally filtered by task_id, key, or map_index.

:param dag_id: Dag ID.
:param run_id: Dag run ID for the task.
:param task_id: Optional task ID filter. If provided, only XComs from this task
will be deleted. Pass *None* (default) to delete across all tasks.
:param key: Optional key filter. If provided, only XComs with this key
will be deleted. Pass *None* (default) to delete all keys.
:param map_index: Optional map index filter. If provided, only XComs with this
map index will be deleted. Pass *None* (default) to delete all map indexes.
"""
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

SUPERVISOR_COMMS.send(
BulkDeleteXCom(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
key=key,
map_index=map_index,
),
)
16 changes: 16 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ class XComCountResponse(BaseModel):
type: Literal["XComCountResponse"] = "XComCountResponse"


class XComDeleteCountResult(BaseModel):
count: int
type: Literal["XComDeleteCountResult"] = "XComDeleteCountResult"


class XComSequenceIndexResult(BaseModel):
root: JsonValue
type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"
Expand Down Expand Up @@ -792,6 +797,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult:
| VariableResult
| VariableKeysResult
| XComCountResponse
| XComDeleteCountResult
| XComResult
| XComSequenceIndexResult
| XComSequenceSliceResult
Expand Down Expand Up @@ -921,6 +927,15 @@ class DeleteXCom(BaseModel):
type: Literal["DeleteXCom"] = "DeleteXCom"


class BulkDeleteXCom(BaseModel):
dag_id: str
run_id: str
task_id: str | None = None
key: str | None = None
map_index: int | None = None
type: Literal["BulkDeleteXCom"] = "BulkDeleteXCom"


class GetTaskStateStore(BaseModel):
ti_id: UUID
key: str
Expand Down Expand Up @@ -1206,6 +1221,7 @@ class GetDag(BaseModel):
| DeleteAssetStateStoreByUri
| DeleteTaskStateStore
| DeleteXCom
| BulkDeleteXCom
| GetAssetByName
| GetAssetByUri
| GetAssetsByAlias
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/request_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
XComSequenceSliceResponse,
)
from airflow.sdk.execution_time.comms import (
BulkDeleteXCom,
ConnectionResult,
DagRunStateResult,
DeleteVariable,
Expand Down Expand Up @@ -188,6 +189,12 @@ def handle_delete_xcom(client: Client, msg: DeleteXCom) -> tuple[BaseModel | Non
return None, {}


def handle_bulk_delete_xcom(client: Client, msg: BulkDeleteXCom) -> tuple[BaseModel | None, dict[str, bool]]:
"""Bulk delete XCom values."""
resp = client.xcoms.delete_all(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
return resp, {}


def handle_get_dr_count(client: Client, msg: GetDRCount) -> tuple[BaseModel | None, dict[str, bool]]:
"""Fetch dag run counts."""
resp = client.dag_runs.get_count(
Expand Down
Loading
Loading