Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67941.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an asset partition sensor for waiting on a specific asset event partition.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_asset_event_by_asset_name_uri(
session: SessionDep,
after: Annotated[UtcDateTime | None, Query(description="The start of the time range")] = None,
before: Annotated[UtcDateTime | None, Query(description="The end of the time range")] = None,
partition_key: Annotated[str | None, Query(description="The partition key of the Asset Event")] = None,
ascending: Annotated[bool, Query(description="Whether to sort results in ascending order")] = True,
limit: Annotated[int | None, Query(description="The maximum number of results to return")] = None,
) -> AssetEventsResponse:
Expand All @@ -102,6 +103,8 @@ def get_asset_event_by_asset_name_uri(
where_clause = and_(where_clause, AssetEvent.timestamp >= after)
if before:
where_clause = and_(where_clause, AssetEvent.timestamp <= before)
if partition_key is not None:
where_clause = and_(where_clause, AssetEvent.partition_key == partition_key)

return _get_asset_events_through_sql_clauses(
join_clause=AssetEvent.asset,
Expand All @@ -118,6 +121,7 @@ def get_asset_event_by_asset_alias(
session: SessionDep,
after: Annotated[UtcDateTime | None, Query(description="The start of the time range")] = None,
before: Annotated[UtcDateTime | None, Query(description="The end of the time range")] = None,
partition_key: Annotated[str | None, Query(description="The partition key of the Asset Event")] = None,
ascending: Annotated[bool, Query(description="Whether to sort results in ascending order")] = True,
limit: Annotated[int | None, Query(description="The maximum number of results to return")] = None,
) -> AssetEventsResponse:
Expand All @@ -126,6 +130,8 @@ def get_asset_event_by_asset_alias(
where_clause = and_(where_clause, AssetEvent.timestamp >= after)
if before:
where_clause = and_(where_clause, AssetEvent.timestamp <= before)
if partition_key is not None:
where_clause = and_(where_clause, AssetEvent.partition_key == partition_key)

return _get_asset_events_through_sql_clauses(
join_clause=AssetEvent.source_aliases,
Expand Down
8 changes: 8 additions & 0 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
DeleteXCom,
DRCount,
ErrorResponse,
GetAssetEventByAsset,
GetAssetEventByAssetAlias,
GetConnection,
GetDagRunState,
GetDRCount,
Expand All @@ -92,6 +94,8 @@
from airflow.sdk.execution_time.request_handlers import (
handle_delete_variable,
handle_delete_xcom,
handle_get_asset_event_by_asset,
handle_get_asset_event_by_asset_alias,
handle_get_connection,
handle_get_dag_run_state,
handle_get_dr_count,
Expand Down Expand Up @@ -600,6 +604,10 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp, dump_opts = handle_get_dr_count(self.client, msg)
elif isinstance(msg, GetDagRunState):
resp, dump_opts = handle_get_dag_run_state(self.client, msg)
elif isinstance(msg, GetAssetEventByAsset):
resp, dump_opts = handle_get_asset_event_by_asset(self.client, msg)
elif isinstance(msg, GetAssetEventByAssetAlias):
resp, dump_opts = handle_get_asset_event_by_asset_alias(self.client, msg)

elif isinstance(msg, GetTICount):
resp, dump_opts = handle_get_ti_count(self.client, msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,33 @@ def make_timestamp(day):
session.commit()


@pytest.fixture
def test_partitioned_asset_events(session, test_asset):
def make_timestamp(day):
return datetime(2021, 1, day, tzinfo=timezone.utc)

common = {
"asset_id": test_asset.id,
"extra": {"foo": "bar"},
"source_dag_id": "foo",
"source_task_id": "bar",
"source_run_id": "custom",
"source_map_index": -1,
}

events = [
AssetEvent(id=101, timestamp=make_timestamp(1), partition_key="2021-01-01", **common),
AssetEvent(id=102, timestamp=make_timestamp(2), partition_key="2021-01-02", **common),
]
session.add_all(events)
session.commit()
yield events

for event in events:
session.delete(event)
session.commit()


@pytest.fixture
def test_asset(session):
asset = AssetModel(
Expand Down Expand Up @@ -90,6 +117,20 @@ def test_asset_alias(session, test_asset_events, test_asset):
session.commit()


@pytest.fixture
def test_partitioned_asset_alias(session, test_partitioned_asset_events, test_asset):
alias = AssetAliasModel(name="partitioned_alias")
alias.asset_events = test_partitioned_asset_events
alias.assets.append(test_asset)
session.add(alias)
session.commit()

yield alias

session.delete(alias)
session.commit()


class TestGetAssetEventByAsset:
@pytest.mark.parametrize(
("uri", "name"),
Expand Down Expand Up @@ -457,6 +498,40 @@ def test_get_by_asset_get_last(self, uri, name, client):
]
}

@pytest.mark.parametrize(
("uri", "name"),
[
(None, "test_get_asset_by_name"),
("s3://bucket/key", None),
("s3://bucket/key", "test_get_asset_by_name"),
],
)
def test_get_by_asset_with_partition_key_filter(self, uri, name, client, test_partitioned_asset_events):
response = client.get(
"/execution/asset-events/by-asset",
params={"name": name, "uri": uri, "partition_key": "2021-01-02"},
)
assert response.status_code == 200
assert response.json()["asset_events"] == [
{
"id": test_partitioned_asset_events[1].id,
"extra": {"foo": "bar"},
"source_task_id": "bar",
"source_dag_id": "foo",
"source_run_id": "custom",
"source_map_index": -1,
"asset": {
"extra": {"foo": "bar"},
"group": "asset",
"name": "test_get_asset_by_name",
"uri": "s3://bucket/key",
},
"created_dagruns": [],
"timestamp": "2021-01-02T00:00:00Z",
"partition_key": "2021-01-02",
},
]


class TestGetAssetEventByAssetAlias:
@pytest.mark.usefixtures("test_asset_alias")
Expand Down Expand Up @@ -521,3 +596,32 @@ def test_get_by_asset(self, client):
},
]
}

def test_get_by_asset_alias_with_partition_key_filter(
self, client, test_partitioned_asset_alias, test_partitioned_asset_events
):
response = client.get(
"/execution/asset-events/by-asset-alias",
params={"name": "partitioned_alias", "partition_key": "2021-01-02"},
)

assert response.status_code == 200
assert response.json()["asset_events"] == [
{
"id": test_partitioned_asset_events[1].id,
"extra": {"foo": "bar"},
"source_task_id": "bar",
"source_dag_id": "foo",
"source_run_id": "custom",
"source_map_index": -1,
"asset": {
"extra": {"foo": "bar"},
"group": "asset",
"name": "test_get_asset_by_name",
"uri": "s3://bucket/key",
},
"created_dagruns": [],
"timestamp": "2021-01-02T00:00:00Z",
"partition_key": "2021-01-02",
},
]
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json.sha256sum
Original file line number Diff line number Diff line change
@@ -1 +1 @@
93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b
ad40a2903ed479a80b300a2094554db993e50e349614eb9226a1aa18a5fbf1cb
2 changes: 2 additions & 0 deletions providers/standard/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ sensors:
- airflow.providers.standard.sensors.python
- airflow.providers.standard.sensors.filesystem
- airflow.providers.standard.sensors.external_task
- airflow.providers.standard.sensors.asset
hooks:
- integration-name: Standard
python-modules:
Expand All @@ -119,6 +120,7 @@ triggers:
- airflow.providers.standard.triggers.file
- airflow.providers.standard.triggers.temporal
- airflow.providers.standard.triggers.hitl
- airflow.providers.standard.triggers.asset

extra-links:
- airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunLink
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_provider_info():
"airflow.providers.standard.sensors.python",
"airflow.providers.standard.sensors.filesystem",
"airflow.providers.standard.sensors.external_task",
"airflow.providers.standard.sensors.asset",
],
}
],
Expand All @@ -96,6 +97,7 @@ def get_provider_info():
"airflow.providers.standard.triggers.file",
"airflow.providers.standard.triggers.temporal",
"airflow.providers.standard.triggers.hitl",
"airflow.providers.standard.triggers.asset",
],
}
],
Expand Down
103 changes: 103 additions & 0 deletions providers/standard/src/airflow/providers/standard/sensors/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import AirflowFailException, BaseSensorOperator, conf
from airflow.providers.standard.triggers.asset import AssetPartitionTrigger

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Asset, Context


class AssetPartitionSensor(BaseSensorOperator):
"""
Wait for an asset event with the given partition key.

:param asset: asset to wait for.
:param partition_key: partition key for the asset event to wait for.
:param deferrable: If waiting for completion, whether to defer the task until done.
"""

template_fields: Sequence[str] = ("partition_key",)
ui_color = "#e6f1f2"

def __init__(
self,
*,
asset: Asset,
partition_key: str,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
self.asset = asset
self.partition_key = partition_key
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.comms import AssetEventsResult, ErrorResponse, GetAssetEventByAsset
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

self.log.info("Poking for asset event: asset=%s, partition_key=%s", self.asset, self.partition_key)
response = SUPERVISOR_COMMS.send(
GetAssetEventByAsset(
name=self.asset.name,
uri=self.asset.uri,
partition_key=self.partition_key,
ascending=False,
limit=1,
)
)
if isinstance(response, ErrorResponse):
raise AirflowRuntimeError(response)
if not isinstance(response, AssetEventsResult):
raise TypeError(f"Unexpected response from supervisor: {type(response).__name__}")
return bool(response.asset_events)

def execute(self, context: Context) -> None:
if not self.deferrable:
super().execute(context=context)
return

if not self.poke(context=context):
self.defer(
timeout=datetime.timedelta(seconds=self.timeout),
trigger=AssetPartitionTrigger(
asset_name=self.asset.name,
asset_uri=self.asset.uri,
partition_key=self.partition_key,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event and event.get("status") == "success":
self.log.info(
"Asset partition event found: asset=%s, partition_key=%s",
self.asset,
self.partition_key,
)
return
message = event.get("message") if event else "Trigger completed without an event"
raise AirflowFailException(message)
Loading
Loading