diff --git a/airflow-core/newsfragments/67941.feature.rst b/airflow-core/newsfragments/67941.feature.rst new file mode 100644 index 0000000000000..37aa553a4c353 --- /dev/null +++ b/airflow-core/newsfragments/67941.feature.rst @@ -0,0 +1 @@ +Add an asset partition sensor for waiting on a specific asset event partition. diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py index d0ecc3d3adaf2..c8cbc6c779e39 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py @@ -80,6 +80,7 @@ def get_asset_event_by_asset_name_uri( session: SessionDep, after: Annotated[UtcDateTime | None, Query(description="The start of the time range")] = None, before: Annotated[UtcDateTime | None, Query(description="The end of the time range")] = None, + partition_key: Annotated[str | None, Query(description="The partition key of the Asset Event")] = None, ascending: Annotated[bool, Query(description="Whether to sort results in ascending order")] = True, limit: Annotated[int | None, Query(description="The maximum number of results to return")] = None, ) -> AssetEventsResponse: @@ -102,6 +103,8 @@ def get_asset_event_by_asset_name_uri( where_clause = and_(where_clause, AssetEvent.timestamp >= after) if before: where_clause = and_(where_clause, AssetEvent.timestamp <= before) + if partition_key is not None: + where_clause = and_(where_clause, AssetEvent.partition_key == partition_key) return _get_asset_events_through_sql_clauses( join_clause=AssetEvent.asset, @@ -118,6 +121,7 @@ def get_asset_event_by_asset_alias( session: SessionDep, after: Annotated[UtcDateTime | None, Query(description="The start of the time range")] = None, before: Annotated[UtcDateTime | None, Query(description="The end of the time range")] = None, + partition_key: Annotated[str | None, Query(description="The partition key of the Asset Event")] = None, ascending: Annotated[bool, Query(description="Whether to sort results in ascending order")] = True, limit: Annotated[int | None, Query(description="The maximum number of results to return")] = None, ) -> AssetEventsResponse: @@ -126,6 +130,8 @@ def get_asset_event_by_asset_alias( where_clause = and_(where_clause, AssetEvent.timestamp >= after) if before: where_clause = and_(where_clause, AssetEvent.timestamp <= before) + if partition_key is not None: + where_clause = and_(where_clause, AssetEvent.partition_key == partition_key) return _get_asset_events_through_sql_clauses( join_clause=AssetEvent.source_aliases, diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 3612feb011fc3..56249ac492e01 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -66,6 +66,8 @@ DeleteXCom, DRCount, ErrorResponse, + GetAssetEventByAsset, + GetAssetEventByAssetAlias, GetConnection, GetDagRunState, GetDRCount, @@ -92,6 +94,8 @@ from airflow.sdk.execution_time.request_handlers import ( handle_delete_variable, handle_delete_xcom, + handle_get_asset_event_by_asset, + handle_get_asset_event_by_asset_alias, handle_get_connection, handle_get_dag_run_state, handle_get_dr_count, @@ -600,6 +604,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp, dump_opts = handle_get_dr_count(self.client, msg) elif isinstance(msg, GetDagRunState): resp, dump_opts = handle_get_dag_run_state(self.client, msg) + elif isinstance(msg, GetAssetEventByAsset): + resp, dump_opts = handle_get_asset_event_by_asset(self.client, msg) + elif isinstance(msg, GetAssetEventByAssetAlias): + resp, dump_opts = handle_get_asset_event_by_asset_alias(self.client, msg) elif isinstance(msg, GetTICount): resp, dump_opts = handle_get_ti_count(self.client, msg) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py index e3839f19eafe8..83772e67777fd 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py @@ -54,6 +54,33 @@ def make_timestamp(day): session.commit() +@pytest.fixture +def test_partitioned_asset_events(session, test_asset): + def make_timestamp(day): + return datetime(2021, 1, day, tzinfo=timezone.utc) + + common = { + "asset_id": test_asset.id, + "extra": {"foo": "bar"}, + "source_dag_id": "foo", + "source_task_id": "bar", + "source_run_id": "custom", + "source_map_index": -1, + } + + events = [ + AssetEvent(id=101, timestamp=make_timestamp(1), partition_key="2021-01-01", **common), + AssetEvent(id=102, timestamp=make_timestamp(2), partition_key="2021-01-02", **common), + ] + session.add_all(events) + session.commit() + yield events + + for event in events: + session.delete(event) + session.commit() + + @pytest.fixture def test_asset(session): asset = AssetModel( @@ -90,6 +117,20 @@ def test_asset_alias(session, test_asset_events, test_asset): session.commit() +@pytest.fixture +def test_partitioned_asset_alias(session, test_partitioned_asset_events, test_asset): + alias = AssetAliasModel(name="partitioned_alias") + alias.asset_events = test_partitioned_asset_events + alias.assets.append(test_asset) + session.add(alias) + session.commit() + + yield alias + + session.delete(alias) + session.commit() + + class TestGetAssetEventByAsset: @pytest.mark.parametrize( ("uri", "name"), @@ -457,6 +498,40 @@ def test_get_by_asset_get_last(self, uri, name, client): ] } + @pytest.mark.parametrize( + ("uri", "name"), + [ + (None, "test_get_asset_by_name"), + ("s3://bucket/key", None), + ("s3://bucket/key", "test_get_asset_by_name"), + ], + ) + def test_get_by_asset_with_partition_key_filter(self, uri, name, client, test_partitioned_asset_events): + response = client.get( + "/execution/asset-events/by-asset", + params={"name": name, "uri": uri, "partition_key": "2021-01-02"}, + ) + assert response.status_code == 200 + assert response.json()["asset_events"] == [ + { + "id": test_partitioned_asset_events[1].id, + "extra": {"foo": "bar"}, + "source_task_id": "bar", + "source_dag_id": "foo", + "source_run_id": "custom", + "source_map_index": -1, + "asset": { + "extra": {"foo": "bar"}, + "group": "asset", + "name": "test_get_asset_by_name", + "uri": "s3://bucket/key", + }, + "created_dagruns": [], + "timestamp": "2021-01-02T00:00:00Z", + "partition_key": "2021-01-02", + }, + ] + class TestGetAssetEventByAssetAlias: @pytest.mark.usefixtures("test_asset_alias") @@ -521,3 +596,32 @@ def test_get_by_asset(self, client): }, ] } + + def test_get_by_asset_alias_with_partition_key_filter( + self, client, test_partitioned_asset_alias, test_partitioned_asset_events + ): + response = client.get( + "/execution/asset-events/by-asset-alias", + params={"name": "partitioned_alias", "partition_key": "2021-01-02"}, + ) + + assert response.status_code == 200 + assert response.json()["asset_events"] == [ + { + "id": test_partitioned_asset_events[1].id, + "extra": {"foo": "bar"}, + "source_task_id": "bar", + "source_dag_id": "foo", + "source_run_id": "custom", + "source_map_index": -1, + "asset": { + "extra": {"foo": "bar"}, + "group": "asset", + "name": "test_get_asset_by_name", + "uri": "s3://bucket/key", + }, + "created_dagruns": [], + "timestamp": "2021-01-02T00:00:00Z", + "partition_key": "2021-01-02", + }, + ] diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum index 943fd0fc93e4c..b7ae80d985e24 100644 --- a/generated/provider_dependencies.json.sha256sum +++ b/generated/provider_dependencies.json.sha256sum @@ -1 +1 @@ -93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b +ad40a2903ed479a80b300a2094554db993e50e349614eb9226a1aa18a5fbf1cb diff --git a/providers/standard/provider.yaml b/providers/standard/provider.yaml index 27e4478f04124..b4a83ccc57e4a 100644 --- a/providers/standard/provider.yaml +++ b/providers/standard/provider.yaml @@ -105,6 +105,7 @@ sensors: - airflow.providers.standard.sensors.python - airflow.providers.standard.sensors.filesystem - airflow.providers.standard.sensors.external_task + - airflow.providers.standard.sensors.asset hooks: - integration-name: Standard python-modules: @@ -119,6 +120,7 @@ triggers: - airflow.providers.standard.triggers.file - airflow.providers.standard.triggers.temporal - airflow.providers.standard.triggers.hitl + - airflow.providers.standard.triggers.asset extra-links: - airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunLink diff --git a/providers/standard/src/airflow/providers/standard/get_provider_info.py b/providers/standard/src/airflow/providers/standard/get_provider_info.py index 1f7b2049454d1..d3a2f063fd890 100644 --- a/providers/standard/src/airflow/providers/standard/get_provider_info.py +++ b/providers/standard/src/airflow/providers/standard/get_provider_info.py @@ -75,6 +75,7 @@ def get_provider_info(): "airflow.providers.standard.sensors.python", "airflow.providers.standard.sensors.filesystem", "airflow.providers.standard.sensors.external_task", + "airflow.providers.standard.sensors.asset", ], } ], @@ -96,6 +97,7 @@ def get_provider_info(): "airflow.providers.standard.triggers.file", "airflow.providers.standard.triggers.temporal", "airflow.providers.standard.triggers.hitl", + "airflow.providers.standard.triggers.asset", ], } ], diff --git a/providers/standard/src/airflow/providers/standard/sensors/asset.py b/providers/standard/src/airflow/providers/standard/sensors/asset.py new file mode 100644 index 0000000000000..f85a3aaf8ec21 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/sensors/asset.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import AirflowFailException, BaseSensorOperator, conf +from airflow.providers.standard.triggers.asset import AssetPartitionTrigger + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Asset, Context + + +class AssetPartitionSensor(BaseSensorOperator): + """ + Wait for an asset event with the given partition key. + + :param asset: asset to wait for. + :param partition_key: partition key for the asset event to wait for. + :param deferrable: If waiting for completion, whether to defer the task until done. + """ + + template_fields: Sequence[str] = ("partition_key",) + ui_color = "#e6f1f2" + + def __init__( + self, + *, + asset: Asset, + partition_key: str, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.asset = asset + self.partition_key = partition_key + self.deferrable = deferrable + + def poke(self, context: Context) -> bool: + from airflow.sdk.exceptions import AirflowRuntimeError + from airflow.sdk.execution_time.comms import AssetEventsResult, ErrorResponse, GetAssetEventByAsset + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + self.log.info("Poking for asset event: asset=%s, partition_key=%s", self.asset, self.partition_key) + response = SUPERVISOR_COMMS.send( + GetAssetEventByAsset( + name=self.asset.name, + uri=self.asset.uri, + partition_key=self.partition_key, + ascending=False, + limit=1, + ) + ) + if isinstance(response, ErrorResponse): + raise AirflowRuntimeError(response) + if not isinstance(response, AssetEventsResult): + raise TypeError(f"Unexpected response from supervisor: {type(response).__name__}") + return bool(response.asset_events) + + def execute(self, context: Context) -> None: + if not self.deferrable: + super().execute(context=context) + return + + if not self.poke(context=context): + self.defer( + timeout=datetime.timedelta(seconds=self.timeout), + trigger=AssetPartitionTrigger( + asset_name=self.asset.name, + asset_uri=self.asset.uri, + partition_key=self.partition_key, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event and event.get("status") == "success": + self.log.info( + "Asset partition event found: asset=%s, partition_key=%s", + self.asset, + self.partition_key, + ) + return + message = event.get("message") if event else "Trigger completed without an event" + raise AirflowFailException(message) diff --git a/providers/standard/src/airflow/providers/standard/triggers/asset.py b/providers/standard/src/airflow/providers/standard/triggers/asset.py new file mode 100644 index 0000000000000..a11ec811eae99 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/triggers/asset.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.triggers.base import BaseTrigger, TriggerEvent +else: + from airflow.triggers.base import BaseTrigger, TriggerEvent # type: ignore + + +class AssetPartitionTrigger(BaseTrigger): + """ + Trigger when an asset event exists for the given partition key. + + :param asset_name: name of the asset to wait for. + :param asset_uri: URI of the asset to wait for. + :param partition_key: partition key for the asset event to wait for. + :param poke_interval: polling interval in seconds. + """ + + def __init__( + self, + *, + asset_name: str | None, + asset_uri: str | None, + partition_key: str, + poke_interval: float = 5.0, + ) -> None: + super().__init__() + self.asset_name = asset_name + self.asset_uri = asset_uri + self.partition_key = partition_key + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize AssetPartitionTrigger arguments and classpath.""" + return ( + "airflow.providers.standard.triggers.asset.AssetPartitionTrigger", + { + "asset_name": self.asset_name, + "asset_uri": self.asset_uri, + "partition_key": self.partition_key, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll until the requested asset partition event exists.""" + from airflow.sdk.execution_time.comms import AssetEventsResult, ErrorResponse, GetAssetEventByAsset + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + while True: + response = await SUPERVISOR_COMMS.asend( + GetAssetEventByAsset( + name=self.asset_name, + uri=self.asset_uri, + partition_key=self.partition_key, + ascending=False, + limit=1, + ) + ) + if isinstance(response, ErrorResponse): + yield TriggerEvent( + { + "status": "error", + "message": f"{response.error.value}: {response.detail}", + } + ) + return + if not isinstance(response, AssetEventsResult): + yield TriggerEvent( + { + "status": "error", + "message": f"Unexpected response from supervisor: {type(response).__name__}", + } + ) + return + if response.asset_events: + yield TriggerEvent({"status": "success"}) + return + await asyncio.sleep(self.poke_interval) diff --git a/providers/standard/tests/unit/standard/sensors/test_asset.py b/providers/standard/tests/unit/standard/sensors/test_asset.py new file mode 100644 index 0000000000000..9e7ba9eb5a4bf --- /dev/null +++ b/providers/standard/tests/unit/standard/sensors/test_asset.py @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +pytestmark = pytest.mark.skipif( + not AIRFLOW_V_3_1_PLUS, + reason="Asset partition sensor tests require Airflow 3.1+ Task SDK supervisor comms", +) + +if AIRFLOW_V_3_1_PLUS: + from airflow.providers.common.compat.sdk import AirflowFailException, Asset, TaskDeferred + from airflow.providers.standard.sensors.asset import AssetPartitionSensor + from airflow.providers.standard.triggers.asset import AssetPartitionTrigger + from airflow.sdk import timezone + from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse + from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import ( + AssetEventsResult, + ErrorResponse, + GetAssetEventByAsset, + ) + + +class TestAssetPartitionSensor: + def test_template_fields(self): + assert AssetPartitionSensor.template_fields == ("partition_key",) + + def test_poke_returns_true_when_partition_event_exists(self, monkeypatch): + comms = mock.Mock() + comms.send.return_value = AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + timestamp=timezone.utcnow(), + asset=AssetResponse(name="orders", uri="s3://warehouse/orders", group="asset"), + partition_key="2024-01-01", + created_dagruns=[], + ) + ], + ) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + ) + + assert sensor.poke({}) is True + comms.send.assert_called_once_with( + GetAssetEventByAsset( + name="orders", + uri="s3://warehouse/orders", + partition_key="2024-01-01", + ascending=False, + limit=1, + ) + ) + + def test_poke_returns_false_when_partition_event_is_missing(self, monkeypatch): + comms = mock.Mock() + comms.send.return_value = AssetEventsResult(asset_events=[]) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + ) + + assert sensor.poke({}) is False + + def test_poke_raises_runtime_error_for_supervisor_error(self, monkeypatch): + comms = mock.Mock() + comms.send.return_value = ErrorResponse(error=ErrorType.ASSET_NOT_FOUND) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + ) + + with pytest.raises(AirflowRuntimeError): + sensor.poke({}) + + def test_poke_raises_for_unexpected_supervisor_response(self, monkeypatch): + comms = mock.Mock() + comms.send.return_value = object() + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + ) + + with pytest.raises(TypeError, match="Unexpected response from supervisor"): + sensor.poke({}) + + def test_execute_defers_when_partition_event_is_missing(self, monkeypatch): + comms = mock.Mock() + comms.send.return_value = AssetEventsResult(asset_events=[]) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + sensor.execute({}) + + assert isinstance(exc.value.trigger, AssetPartitionTrigger) + assert exc.value.trigger.asset_name == "orders" + assert exc.value.trigger.asset_uri == "s3://warehouse/orders" + assert exc.value.trigger.partition_key == "2024-01-01" + + def test_execute_complete_raises_for_trigger_error(self): + sensor = AssetPartitionSensor( + task_id="wait_orders", + asset=Asset(name="orders", uri="s3://warehouse/orders"), + partition_key="2024-01-01", + ) + + with pytest.raises(AirflowFailException, match="failed"): + sensor.execute_complete({}, {"status": "error", "message": "failed"}) diff --git a/providers/standard/tests/unit/standard/triggers/test_asset.py b/providers/standard/tests/unit/standard/triggers/test_asset.py new file mode 100644 index 0000000000000..161831cd50053 --- /dev/null +++ b/providers/standard/tests/unit/standard/triggers/test_asset.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +pytestmark = pytest.mark.skipif( + not AIRFLOW_V_3_1_PLUS, + reason="Asset partition trigger tests require Airflow 3.1+ Task SDK supervisor comms", +) + +if AIRFLOW_V_3_1_PLUS: + from airflow.providers.standard.triggers.asset import AssetPartitionTrigger + from airflow.sdk import timezone + from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import AssetEventsResult, ErrorResponse, GetAssetEventByAsset + from airflow.triggers.base import TriggerEvent + + +class TestAssetPartitionTrigger: + def test_serialization(self): + trigger = AssetPartitionTrigger( + asset_name="orders", + asset_uri="s3://warehouse/orders", + partition_key="2024-01-01", + poke_interval=10, + ) + + classpath, kwargs = trigger.serialize() + + assert classpath == "airflow.providers.standard.triggers.asset.AssetPartitionTrigger" + assert kwargs == { + "asset_name": "orders", + "asset_uri": "s3://warehouse/orders", + "partition_key": "2024-01-01", + "poke_interval": 10, + } + + @pytest.mark.asyncio + async def test_run_yields_success_when_partition_event_exists(self, monkeypatch): + comms = mock.Mock() + comms.asend = mock.AsyncMock( + return_value=AssetEventsResult( + asset_events=[ + AssetEventResponse( + id=1, + timestamp=timezone.utcnow(), + asset=AssetResponse(name="orders", uri="s3://warehouse/orders", group="asset"), + partition_key="2024-01-01", + created_dagruns=[], + ) + ], + ) + ) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + trigger = AssetPartitionTrigger( + asset_name="orders", + asset_uri="s3://warehouse/orders", + partition_key="2024-01-01", + poke_interval=0, + ) + + assert await trigger.run().__anext__() == TriggerEvent({"status": "success"}) + comms.asend.assert_awaited_once_with( + GetAssetEventByAsset( + name="orders", + uri="s3://warehouse/orders", + partition_key="2024-01-01", + ascending=False, + limit=1, + ) + ) + + @pytest.mark.asyncio + async def test_run_yields_error_for_supervisor_error(self, monkeypatch): + comms = mock.Mock() + comms.asend = mock.AsyncMock(return_value=ErrorResponse(error=ErrorType.ASSET_NOT_FOUND)) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + trigger = AssetPartitionTrigger( + asset_name="orders", + asset_uri="s3://warehouse/orders", + partition_key="2024-01-01", + poke_interval=0, + ) + + assert await trigger.run().__anext__() == TriggerEvent( + {"status": "error", "message": "ASSET_NOT_FOUND: None"} + ) + + @pytest.mark.asyncio + async def test_run_yields_error_for_unexpected_supervisor_response(self, monkeypatch): + comms = mock.Mock() + comms.asend = mock.AsyncMock(return_value=object()) + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + trigger = AssetPartitionTrigger( + asset_name="orders", + asset_uri="s3://warehouse/orders", + partition_key="2024-01-01", + poke_interval=0, + ) + + assert await trigger.run().__anext__() == TriggerEvent( + {"status": "error", "message": "Unexpected response from supervisor: object"} + ) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 66496d5cac911..a49af47868070 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -861,6 +861,7 @@ def get( alias_name: str | None = None, after: datetime | None = None, before: datetime | None = None, + partition_key: str | None = None, ascending: bool = True, limit: int | None = None, ) -> AssetEventsResponse: @@ -870,6 +871,8 @@ def get( common_params["after"] = after.isoformat() if before: common_params["before"] = before.isoformat() + if partition_key is not None: + common_params["partition_key"] = partition_key common_params["ascending"] = ascending if limit: common_params["limit"] = limit diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index a042d7a1e80b8..a4ac7c721e959 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -1099,6 +1099,7 @@ class GetAssetsByAlias(BaseModel): class GetAssetEventByAsset(BaseModel): name: str | None uri: str | None + partition_key: str | None = None after: AwareDatetime | None = None before: AwareDatetime | None = None limit: int | None = None @@ -1108,6 +1109,7 @@ class GetAssetEventByAsset(BaseModel): class GetAssetEventByAssetAlias(BaseModel): alias_name: str + partition_key: str | None = None after: AwareDatetime | None = None before: AwareDatetime | None = None limit: int | None = None diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 6cbf3dacadb0e..6e2ecaa854073 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -1079,6 +1079,7 @@ class InletEventsAccessor(Sequence["AssetEventResult"]): _before: str | datetime | None _ascending: bool _limit: int | None + _partition_key: str | None _asset_name: str | None _asset_uri: str | None _alias_name: str | None @@ -1093,6 +1094,7 @@ def __init__( self._before = None self._ascending = True self._limit = None + self._partition_key = None def after(self, after: str) -> Self: self._after = after @@ -1114,6 +1116,11 @@ def limit(self, limit: int) -> Self: self._reset_cache() return self + def partition_key(self, partition_key: str | None) -> Self: + self._partition_key = partition_key + self._reset_cache() + return self + @functools.cached_property def _asset_events(self) -> list[AssetEventResult]: from airflow.sdk.execution_time.comms import ( @@ -1129,6 +1136,7 @@ def _asset_events(self) -> list[AssetEventResult]: "before": self._before, "ascending": self._ascending, "limit": self._limit, + "partition_key": self._partition_key, } msg: ToSupervisor diff --git a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py index 959be43fe93cf..e5c9dcb595b6e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/request_handlers.py +++ b/task-sdk/src/airflow/sdk/execution_time/request_handlers.py @@ -40,10 +40,13 @@ XComSequenceSliceResponse, ) from airflow.sdk.execution_time.comms import ( + AssetEventsResult, ConnectionResult, DagRunStateResult, DeleteVariable, DeleteXCom, + GetAssetEventByAsset, + GetAssetEventByAssetAlias, GetConnection, GetDagRunState, GetDRCount, @@ -207,6 +210,37 @@ def handle_get_dag_run_state(client: Client, msg: GetDagRunState) -> tuple[BaseM return dr_resp, {} +def handle_get_asset_event_by_asset( + client: Client, msg: GetAssetEventByAsset +) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch asset events for an asset.""" + asset_event_resp = client.asset_events.get( + uri=msg.uri, + name=msg.name, + partition_key=msg.partition_key, + after=msg.after, + before=msg.before, + ascending=msg.ascending, + limit=msg.limit, + ) + return AssetEventsResult.from_asset_events_response(asset_event_resp), {"exclude_unset": True} + + +def handle_get_asset_event_by_asset_alias( + client: Client, msg: GetAssetEventByAssetAlias +) -> tuple[BaseModel | None, dict[str, bool]]: + """Fetch asset events for an asset alias.""" + asset_event_resp = client.asset_events.get( + alias_name=msg.alias_name, + partition_key=msg.partition_key, + after=msg.after, + before=msg.before, + ascending=msg.ascending, + limit=msg.limit, + ) + return AssetEventsResult.from_asset_events_response(asset_event_resp), {"exclude_unset": True} + + def handle_get_previous_dag_run( client: Client, msg: GetPreviousDagRun ) -> tuple[BaseModel | None, dict[str, bool]]: diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 4807d9f53b353..e70e67e8832f0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -1782,6 +1782,18 @@ ], "title": "Uri" }, + "partition_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Partition Key" + }, "after": { "anyOf": [ { @@ -1845,6 +1857,18 @@ "title": "Alias Name", "type": "string" }, + "partition_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Partition Key" + }, "after": { "anyOf": [ { diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d4617efbbae64..05484269aa2dc 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -60,7 +60,6 @@ from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( - AssetEventsResult, AssetResult, AssetStateStoreResult, AwaitInputTask, @@ -135,6 +134,8 @@ from airflow.sdk.execution_time.request_handlers import ( handle_delete_variable, handle_delete_xcom, + handle_get_asset_event_by_asset, + handle_get_asset_event_by_asset_alias, handle_get_connection, handle_get_dag_run_state, handle_get_dr_count, @@ -1751,28 +1752,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, GetAssetsByAlias): resp = self.client.assets.get_by_alias(alias_name=msg.alias_name) elif isinstance(msg, GetAssetEventByAsset): - asset_event_resp = self.client.asset_events.get( - uri=msg.uri, - name=msg.name, - after=msg.after, - before=msg.before, - ascending=msg.ascending, - limit=msg.limit, - ) - asset_event_result = AssetEventsResult.from_asset_events_response(asset_event_resp) - resp = asset_event_result - dump_opts = {"exclude_unset": True} + resp, dump_opts = handle_get_asset_event_by_asset(self.client, msg) elif isinstance(msg, GetAssetEventByAssetAlias): - asset_event_resp = self.client.asset_events.get( - alias_name=msg.alias_name, - after=msg.after, - before=msg.before, - ascending=msg.ascending, - limit=msg.limit, - ) - asset_event_result = AssetEventsResult.from_asset_events_response(asset_event_resp) - resp = asset_event_result - dump_opts = {"exclude_unset": True} + resp, dump_opts = handle_get_asset_event_by_asset_alias(self.client, msg) elif isinstance(msg, GetPrevSuccessfulDagRun): resp, dump_opts = handle_get_prev_successful_dag_run(self.client, self.id) elif isinstance(msg, GetXComCount): diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 0d6ee44628adf..8a2144d9300e6 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -1177,6 +1177,8 @@ class TestAssetEventOperations: [ ({"name": "this_asset", "uri": "s3://bucket/key"}), ({"alias_name": "this_asset_alias"}), + ({"name": "this_asset", "uri": "s3://bucket/key", "partition_key": "2021-01-02"}), + ({"alias_name": "this_asset_alias", "partition_key": "2021-01-02"}), ], ) def test_by_name_get_success(self, request_params): @@ -1189,6 +1191,8 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert params.get("name") == request_params.get("alias_name") else: return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + if partition_key := request_params.get("partition_key"): + assert params.get("partition_key") == partition_key return httpx.Response( status_code=200, diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 6f09489a88d1e..f182bc900e16a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -809,12 +809,13 @@ def test__get_item__with_filters(self, sample_inlet_evnets_accessor, mock_superv asset=AssetResponse(name="test_uri", uri="test_uri", group="asset"), ) events_result = AssetEventsResult(asset_events=[asset_event_resp]) - mock_supervisor_comms.send.side_effect = [events_result] * 10 + mock_supervisor_comms.send.side_effect = [events_result] * 11 list(sample_inlet_evnets_accessor[TEST_ASSET]) list(sample_inlet_evnets_accessor[TEST_ASSET].after("2024-01-01T00:00:00Z")) list(sample_inlet_evnets_accessor[TEST_ASSET].before("2024-01-01T00:00:00Z")) list(sample_inlet_evnets_accessor[TEST_ASSET].limit(10)) + list(sample_inlet_evnets_accessor[TEST_ASSET].partition_key("2024-01-01")) list( sample_inlet_evnets_accessor[TEST_ASSET] .after("2024-01-01T00:00:00Z") @@ -823,33 +824,33 @@ def test__get_item__with_filters(self, sample_inlet_evnets_accessor, mock_superv ) list(sample_inlet_evnets_accessor[TEST_ASSET].ascending(False).limit(10)) - assert mock_supervisor_comms.send.call_count == 6 + assert mock_supervisor_comms.send.call_count == 7 # test accessing the accessor without list() or [] sample_inlet_evnets_accessor[TEST_ASSET].ascending(False).limit(10) - assert mock_supervisor_comms.send.call_count == 6 + assert mock_supervisor_comms.send.call_count == 7 # test accessing one of the elements res = sample_inlet_evnets_accessor[TEST_ASSET].ascending(False).limit(10)[0] assert res == asset_event_resp - assert mock_supervisor_comms.send.call_count == 7 + assert mock_supervisor_comms.send.call_count == 8 # test evaluating the accessor multiple times with the same filters res = sample_inlet_evnets_accessor[TEST_ASSET].ascending(False).limit(10) assert res[0] == asset_event_resp assert res[0] == asset_event_resp - assert mock_supervisor_comms.send.call_count == 8 + assert mock_supervisor_comms.send.call_count == 9 # test changing one of the filters assert res.after("2024-01-01T00:00:00Z")[0] == asset_event_resp - assert mock_supervisor_comms.send.call_count == 9 + assert mock_supervisor_comms.send.call_count == 10 # test len() assert len(sample_inlet_evnets_accessor[TEST_ASSET].ascending(True).limit(10)) == 1 - assert mock_supervisor_comms.send.call_count == 10 + assert mock_supervisor_comms.send.call_count == 11 calls = mock_supervisor_comms.send.call_args_list assert calls[0][0][0] == GetAssetEventByAsset( @@ -875,6 +876,15 @@ def test__get_item__with_filters(self, sample_inlet_evnets_accessor, mock_superv name="test_uri", uri="test://test/", after=None, before=None, limit=10, ascending=True ) assert calls[4][0][0] == GetAssetEventByAsset( + name="test_uri", + uri="test://test/", + after=None, + before=None, + limit=None, + ascending=True, + partition_key="2024-01-01", + ) + assert calls[5][0][0] == GetAssetEventByAsset( name="test_uri", uri="test://test/", after="2024-01-01T00:00:00Z", @@ -882,7 +892,7 @@ def test__get_item__with_filters(self, sample_inlet_evnets_accessor, mock_superv limit=10, ascending=True, ) - assert calls[5][0][0] == GetAssetEventByAsset( + assert calls[6][0][0] == GetAssetEventByAsset( name="test_uri", uri="test://test/", after=None, before=None, limit=10, ascending=False ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4893258195446..2b6315aa9661c 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1983,6 +1983,7 @@ class RequestTestCase: kwargs={ "uri": "s3://bucket/obj", "name": "test", + "partition_key": None, "after": None, "before": None, "limit": None, @@ -2009,6 +2010,7 @@ class RequestTestCase: before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), limit=5, ascending=False, + partition_key="2024-10-31", ), expected_body={ "asset_events": [ @@ -2026,6 +2028,7 @@ class RequestTestCase: kwargs={ "uri": "s3://bucket/obj", "name": "test", + "partition_key": "2024-10-31", "after": timezone.parse("2024-10-01T12:00:00Z"), "before": timezone.parse("2024-10-15T12:00:00Z"), "limit": 5, @@ -2062,6 +2065,7 @@ class RequestTestCase: kwargs={ "uri": "s3://bucket/obj", "name": None, + "partition_key": None, "after": None, "before": None, "limit": None, @@ -2105,6 +2109,7 @@ class RequestTestCase: kwargs={ "uri": "s3://bucket/obj", "name": None, + "partition_key": None, "after": timezone.parse("2024-10-01T12:00:00Z"), "before": timezone.parse("2024-10-15T12:00:00Z"), "limit": 5, @@ -2141,6 +2146,7 @@ class RequestTestCase: kwargs={ "uri": None, "name": "test", + "partition_key": None, "after": None, "before": None, "limit": None, @@ -2184,6 +2190,7 @@ class RequestTestCase: kwargs={ "uri": None, "name": "test", + "partition_key": None, "after": timezone.parse("2024-10-01T12:00:00Z"), "before": timezone.parse("2024-10-15T12:00:00Z"), "limit": 5, @@ -2219,6 +2226,7 @@ class RequestTestCase: method_path="asset_events.get", kwargs={ "alias_name": "test_alias", + "partition_key": None, "after": None, "before": None, "limit": None, @@ -2244,6 +2252,7 @@ class RequestTestCase: before=datetime(2024, 10, 15, 12, 0, 0, tzinfo=timezone.utc), limit=5, ascending=False, + partition_key="2024-10-31", ), expected_body={ "asset_events": [ @@ -2260,6 +2269,7 @@ class RequestTestCase: method_path="asset_events.get", kwargs={ "alias_name": "test_alias", + "partition_key": "2024-10-31", "after": timezone.parse("2024-10-01T12:00:00Z"), "before": timezone.parse("2024-10-15T12:00:00Z"), "limit": 5,