Skip to content

Commit

Permalink
Merge WasbBlobAsyncSensor to WasbBlobSensor (#30488)
Browse files Browse the repository at this point in the history
  • Loading branch information
phanikumv authored Apr 17, 2023
1 parent 6a6455a commit 6b5db07
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 47 deletions.
90 changes: 47 additions & 43 deletions airflow/providers/microsoft/azure/sensors/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

Expand All @@ -38,6 +39,8 @@ class WasbBlobSensor(BaseSensorOperator):
:param wasb_conn_id: Reference to the :ref:`wasb connection <howto/connection:wasb>`.
:param check_options: Optional keyword arguments that
`WasbHook.check_for_blob()` takes.
:param deferrable: Run sensor in the deferrable mode.
:param public_read: whether an anonymous public read access should be used. Default is False
"""

template_fields: Sequence[str] = ("container_name", "blob_name")
Expand All @@ -49,6 +52,8 @@ def __init__(
blob_name: str,
wasb_conn_id: str = "wasb_default",
check_options: dict | None = None,
public_read: bool = False,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -58,57 +63,32 @@ def __init__(
self.container_name = container_name
self.blob_name = blob_name
self.check_options = check_options
self.public_read = public_read
self.deferrable = deferrable

def poke(self, context: Context):
    """
    Check once whether the target blob exists in the container.

    :param context: task context passed in by the sensor machinery; it is not read here.
    :return: whatever ``WasbHook.check_for_blob`` returns for the configured
        container and blob, called with ``check_options`` as extra keyword arguments.
    """
    self.log.info("Poking for blob: %s\n in wasb://%s", self.blob_name, self.container_name)
    # Forward public_read so anonymous access is honoured in poke mode as well;
    # the deferrable path already hands the same flag to its trigger.
    hook = WasbHook(wasb_conn_id=self.wasb_conn_id, public_read=self.public_read)
    return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options)


class WasbBlobAsyncSensor(WasbBlobSensor):
"""
Polls asynchronously for the existence of a blob in a WASB container.
:param container_name: name of the container in which the blob should be searched for
:param blob_name: name of the blob to check existence for
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
:param poke_interval: polling period in seconds to check for the status
:param public_read: whether an anonymous public read access should be used. Default is False
:param timeout: Time, in seconds before the task times out and fails.
"""

def __init__(
self,
*,
container_name: str,
blob_name: str,
wasb_conn_id: str = "wasb_default",
public_read: bool = False,
poke_interval: float = 5.0,
**kwargs: Any,
):
self.container_name = container_name
self.blob_name = blob_name
self.poke_interval = poke_interval
super().__init__(container_name=container_name, blob_name=blob_name, **kwargs)
self.wasb_conn_id = wasb_conn_id
self.public_read = public_read

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches
a failure state or success state
"""Defers trigger class to poll for state of the job run until
it reaches a failure state or success state
"""
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=WasbBlobSensorTrigger(
container_name=self.container_name,
blob_name=self.blob_name,
wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)
if not self.deferrable:
super().execute(context=context)
else:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=WasbBlobSensorTrigger(
container_name=self.container_name,
blob_name=self.blob_name,
wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, str]) -> None:
"""
Expand All @@ -124,6 +104,30 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> None:
raise AirflowException("Did not receive valid event from the triggerer")


class WasbBlobAsyncSensor(WasbBlobSensor):
    """
    Deprecated deferrable flavour of ``WasbBlobSensor``.

    Instantiating this class emits a ``DeprecationWarning`` and then builds a
    ``WasbBlobSensor`` with ``deferrable=True``. Every keyword argument is
    forwarded unchanged to the parent constructor.

    :param container_name: name of the container in which the blob should be searched for
    :param blob_name: name of the blob to check existence for
    :param wasb_conn_id: the connection identifier for connecting to Azure WASB
    :param poke_interval: polling period in seconds to check for the status
    :param public_read: whether an anonymous public read access should be used. Default is False
    :param timeout: Time, in seconds before the task times out and fails.
    """

    def __init__(self, **kwargs: Any) -> None:
        deprecation_message = (
            "Class `WasbBlobAsyncSensor` is deprecated and "
            "will be removed in a future release. "
            "Please use `WasbBlobSensor` and "
            "set `deferrable` attribute to `True` instead"
        )
        # stacklevel=2 attributes the warning to the code that instantiates the sensor.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        # The async sensor is always deferrable; passing `deferrable` in kwargs
        # still raises TypeError for the duplicate keyword, as before.
        super().__init__(deferrable=True, **kwargs)


class WasbPrefixSensor(BaseSensorOperator):
"""
Waits for blobs matching a prefix to arrive on Azure Blob Storage.
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/microsoft/azure/sensors/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.microsoft.azure.sensors.wasb import (
WasbBlobAsyncSensor,
WasbBlobSensor,
WasbPrefixSensor,
)
Expand Down Expand Up @@ -120,11 +119,12 @@ def create_context(self, task, dag=None):
"logical_date": execution_date,
}

SENSOR = WasbBlobAsyncSensor(
SENSOR = WasbBlobSensor(
task_id="wasb_blob_async_sensor",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
timeout=5,
deferrable=True,
)

def test_wasb_blob_sensor_async(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobAsyncSensor, WasbBlobSensor
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
from airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs import AzureBlobStorageToGCSOperator

# Ignore missing args provided by default_args
Expand All @@ -46,7 +46,7 @@

wait_for_blob = WasbBlobSensor(task_id="wait_for_blob")

wait_for_blob_async = WasbBlobAsyncSensor(task_id="wait_for_blob_async")
wait_for_blob_async = WasbBlobSensor(task_id="wait_for_blob_async", deferrable=True)

transfer_files_to_gcs = AzureBlobStorageToGCSOperator(
task_id="transfer_files_to_gcs",
Expand Down

0 comments on commit 6b5db07

Please sign in to comment.