diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..e8603ac --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203, E266, E501, W503, E741 +exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,venv/*,src/*,tests/unit/common/protos/*,build +max-complexity=32 +per-file-ignores = + *:F821 + */__init__.py: F401 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..0485337 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,40 @@ +name: Build + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Cache pip + uses: actions/cache@v2 + with: + # This path is specific to Ubuntu + path: ~/.cache/pip + # Look to see if there is a cache hit for the corresponding requirements files + key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('requirements.txt', 'requirements-dev.txt', 'requirements-docs.txt')) }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Lint + run: | + pre-commit run --all-files + - name: Unit tests with pytest + run: | + python -m unittest diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..843b1df --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,36 @@ +name: Publish Airflow Flyte Provider + +on: + release: + types: [published] + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: "3.x" + - name: Install dependencies + run: | + 
python -m pip install --upgrade pip + pip install setuptools wheel twine + - name: Autobump version setup.py + id: bump-setup-py + run: | + # from refs/tags/v1.2.3 get 1.2.3 + VERSION=$(echo $GITHUB_REF | sed 's#.*/v##') + PLACEHOLDER="__version__\ =\ \"0.0.0+dev0\"" + grep "$PLACEHOLDER" "setup.py" + sed -i "s#$PLACEHOLDER#__version__ = \"$VERSION\"#g" "setup.py" + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96db32d --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.DS_Store +__pycache__/ +.coverage +*.egg-info/ +build/ +dist/ +.python-version +*_astro diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4865d11 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.8.0.4 + hooks: + - id: shellcheck diff --git a/README.md b/README.md index dd571fd..a35efef 100644 --- a/README.md +++ b/README.md @@ -1 +1,83 @@ -# airflow-provider-flyte +# Flyte Provider for Apache Airflow + +This package provides an operator, a sensor, and a hook that integrates [Flyte](flyte.org/) into Apache Airflow. +`FlyteOperator` is helpful to trigger a task/workflow in Flyte and `FlyteSensor` enables monitoring a Flyte execution status +for completion. + +## Installation + +Prerequisites: An environment running `apache-airflow`. 
+ +``` +pip install airflow-provider-flyte +``` + +## Configuration + +In the Airflow UI, configure a _Connection_ for Flyte. + +- Host (optional): The FlyteAdmin host. Defaults to localhost. +- Port (optional): The FlyteAdmin port. Defaults to 30081. +- Login (optional): `client_id` +- Password (optional): `client_credentials_secret` +- Extra (optional): Specify the `extra` parameter as JSON dictionary to provide additional parameters. + - `project`: The default project to connect to. + - `domain`: The default domain to connect to. + - `insecure`: Whether to use SSL or not. + - `command`: The command to execute to return a token using an external process. + - `scopes`: List of scopes to request. + - `auth_mode`: The OAuth mode to use. Defaults to pkce flow. + - `env_prefix`: Prefix that will be used to look up injected secrets at runtime. + - `default_dir`: Default directory that will be used to find secrets as individual files. + - `file_prefix`: Prefix for the file in the `default_dir`. + - `statsd_host`: The statsd host. + - `statsd_port`: The statsd port. + - `statsd_disabled`: Whether to send statsd or not. + - `statsd_disabled_tags`: Turn on to reduce cardinality. + - `local_sandbox_path` + - S3 Config: + - `s3_enable_debug` + - `s3_endpoint` + - `s3_retries` + - `s3_backoff` + - `s3_access_key_id` + - `s3_secret_access_key` + - GCS Config: + - `gsutil_parallelism` + +## Modules + +### [Flyte Operator](https://github.com/flyteorg/airflow-provider-flyte/blob/main/flyte_provider/operators/flyte.py) + +The `FlyteOperator` requires a `flyte_conn_id` to fetch all the connection-related +parameters that are useful to instantiate `FlyteRemote`. 
Also, you must give a +`launchplan_name` to trigger a workflow, or `task_name` to trigger a task; you can give a +handful of other values that are optional, such as `project`, `domain`, `max_parallelism`, +`raw_output_data_config`, `kubernetes_service_account`, `labels`, `annotations`, +`secrets`, `notifications`, `disable_notifications`, `oauth2_client`, `version`, and `inputs`. + +Import into your DAG via: + +``` +from flyte_provider.operators.flyte import FlyteOperator +``` + +### [Flyte Sensor](https://github.com/flyteorg/airflow-provider-flyte/blob/main/flyte_provider/sensors/flyte.py) + +If you need to wait for an execution to complete, use `FlyteSensor`. +Monitoring with `FlyteSensor` allows you to trigger downstream processes only when the Flyte executions are complete. + +Import into your DAG via: + +``` +from flyte_provider.sensors.flyte import FlyteSensor +``` + +## Examples + +See the [examples](https://github.com/flyteorg/airflow-provider-flyte/tree/main/flyte_provider/example_dags) directory for an example DAG. + +## Issues + +Please file issues and open pull requests [here](https://github.com/flyteorg/airflow-provider-flyte). +If you hit any roadblock, hit us up on [Slack](https://slack.flyte.org/). 
diff --git a/flyte_provider/__init__.py b/flyte_provider/__init__.py new file mode 100644 index 0000000..7f9c234 --- /dev/null +++ b/flyte_provider/__init__.py @@ -0,0 +1,12 @@ +from typing import Any, Dict + + +def get_provider_info() -> Dict[str, Any]: + return { + "package-name": "airflow-provider-flyte", + "name": "Flyte Airflow Provider", + "description": "A Flyte provider for Apache Airflow.", + "hook-class-names": ["flyte_provider.hooks.flyte.FlyteHook"], + "extra-links": ["flyte_provider.operators.flyte.RegistryLink"], + "versions": ["0.0.1"], + } diff --git a/flyte_provider/example_dags/__init__.py b/flyte_provider/example_dags/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flyte_provider/example_dags/example_flyte.py b/flyte_provider/example_dags/example_flyte.py new file mode 100644 index 0000000..0412664 --- /dev/null +++ b/flyte_provider/example_dags/example_flyte.py @@ -0,0 +1,35 @@ +from datetime import datetime, timedelta + +from airflow import DAG +from flyte_provider.operators.flyte import FlyteOperator +from flytekit.models.core import execution as _execution_model + +_workflow_execution_succeeded = _execution_model.WorkflowExecutionPhase.SUCCEEDED + +with DAG( + dag_id="example_flyte_operator", + schedule_interval=None, + start_date=datetime(2021, 1, 1), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], + catchup=False, +) as dag: + # do not wait for the execution to complete + flyte_execution = FlyteOperator( + task_id="flyte_task", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + notifications=[ + { + "phases": [_workflow_execution_succeeded], + "email": {"recipients_email": ["abc@flyte.org"]}, + } + ], + oauth2_client={"client_id": "123", "client_secret": "456"}, + secrets=[{"group": "secrets", "key": "123"}], + ) diff --git 
a/flyte_provider/example_dags/example_flyte_wait.py b/flyte_provider/example_dags/example_flyte_wait.py new file mode 100644 index 0000000..f0a1832 --- /dev/null +++ b/flyte_provider/example_dags/example_flyte_wait.py @@ -0,0 +1,52 @@ +from datetime import datetime, timedelta + +from airflow import DAG +from flyte_provider.operators.flyte import FlyteOperator +from flyte_provider.sensors.flyte import FlyteSensor + +with DAG( + dag_id="example_flyte_wait", + schedule_interval=None, + start_date=datetime(2021, 1, 1), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], + catchup=False, +) as dag: + # wait for the execution to complete + flyte_execution_start = FlyteOperator( + task_id="flyte_task", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + max_parallelism=2, + raw_output_data_config="s3://flyte-demo/raw_data", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + ) + + flyte_execution_wait = FlyteSensor( + task_id="flyte_sensor_one", + execution_name=flyte_execution_start.output, + project="flytesnacks", + domain="development", + flyte_conn_id="flyte_conn_example", + ) # poke every 60 seconds (default) + + flyte_execution_start >> flyte_execution_wait + + # wait for a long-running execution to complete + flyte_execution_wait_long = FlyteSensor( + task_id="flyte_sensor_two", + execution_name=flyte_execution_start.output, + project="flytesnacks", + domain="development", + flyte_conn_id="flyte_conn_example", + mode="reschedule", + poke_interval=5 * 60, # check every 5 minutes + timeout=86400, # wait for a day + soft_fail=True, # task is skipped if the condition is not met by timeout + ) + + flyte_execution_start >> flyte_execution_wait_long diff --git a/flyte_provider/example_dags/example_flyte_xcom.py b/flyte_provider/example_dags/example_flyte_xcom.py new file mode 100644 index 0000000..c6a40a6 --- /dev/null +++ 
b/flyte_provider/example_dags/example_flyte_xcom.py @@ -0,0 +1,49 @@ +from datetime import datetime, timedelta + +from airflow import DAG +from flyte_provider.operators.flyte import FlyteOperator +from flyte_provider.sensors.flyte import FlyteSensor + +with DAG( + dag_id="example_flyte_xcom", + schedule_interval=None, + start_date=datetime(2021, 1, 1), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], + catchup=False, +) as dag: + + flyte_execution_start = FlyteOperator( + task_id="flyte_task_one", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + max_parallelism=2, + raw_output_data_config="s3://flyte-demo/raw_data", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + ) + + # wait for the execution to complete + flyte_execution_wait = FlyteSensor( + task_id="flyte_sensor", + execution_name=flyte_execution_start.output, + project="flytesnacks", + domain="development", + flyte_conn_id="flyte_conn_example", + ) # poke every 60 seconds (default) + + flyte_execution = FlyteOperator( + task_id="flyte_task_two", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + ) + + flyte_execution_start >> flyte_execution_wait >> flyte_execution diff --git a/flyte_provider/hooks/__init__.py b/flyte_provider/hooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flyte_provider/hooks/flyte.py b/flyte_provider/hooks/flyte.py new file mode 100644 index 0000000..5b3a15f --- /dev/null +++ b/flyte_provider/hooks/flyte.py @@ -0,0 +1,302 @@ +from typing import Any, Dict, List, Optional + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from flytekit.configuration import ( + Config, + DataConfig, + GCSConfig, + PlatformConfig, + S3Config, + SecretsConfig, + 
StatsConfig, +) +from flytekit.exceptions.user import FlyteEntityNotExistException +from flytekit.models.common import ( + Annotations, + EmailNotification, + Labels, + Notification, + PagerDutyNotification, + SlackNotification, +) +from flytekit.models.core import execution as core_execution_models +from flytekit.models.core.identifier import WorkflowExecutionIdentifier +from flytekit.models.security import Identity, OAuth2Client, Secret, SecurityContext +from flytekit.remote.remote import FlyteRemote, Options + + +class FlyteHook(BaseHook): + """ + Interact with the FlyteRemote API. + + :param flyte_conn_id: Required. The name of the Flyte connection to get + the connection information for Flyte. + :param project: Optional. The project under consideration. + :param domain: Optional. The domain under consideration. + """ + + SUCCEEDED = core_execution_models.WorkflowExecutionPhase.SUCCEEDED + FAILED = core_execution_models.WorkflowExecutionPhase.FAILED + TIMED_OUT = core_execution_models.WorkflowExecutionPhase.TIMED_OUT + ABORTED = core_execution_models.WorkflowExecutionPhase.ABORTED + + flyte_conn_id = "flyte_default" + conn_type = "flyte" + hook_name = "Flyte" + + def __init__( + self, + flyte_conn_id: str = flyte_conn_id, + project: Optional[str] = None, + domain: Optional[str] = None, + ) -> None: + super().__init__() + self.flyte_conn_id = flyte_conn_id + self.flyte_conn = self.get_connection(self.flyte_conn_id) + self.project = project or self.flyte_conn.extra_dejson.get("project") + self.domain = domain or self.flyte_conn.extra_dejson.get("domain") + + if not (self.project and self.domain): + raise AirflowException("Please provide a project and domain.") + + def execution_id(self, execution_name: str) -> WorkflowExecutionIdentifier: + """Get the execution id.""" + return WorkflowExecutionIdentifier(self.project, self.domain, execution_name) + + def create_flyte_remote(self) -> FlyteRemote: + """Create a FlyteRemote object.""" + remote = FlyteRemote( + 
config=Config( + platform=PlatformConfig( + endpoint=":".join([self.flyte_conn.host, self.flyte_conn.port]) + if (self.flyte_conn.host and self.flyte_conn.port) + else (self.flyte_conn.host or PlatformConfig.endpoint), + insecure=self.flyte_conn.extra_dejson.get( + "insecure", PlatformConfig.insecure + ), + insecure_skip_verify=self.flyte_conn.extra_dejson.get( + "insecure_skip_verify", + PlatformConfig.insecure_skip_verify, + ), + client_id=self.flyte_conn.login or PlatformConfig.client_id, + client_credentials_secret=self.flyte_conn.password + or PlatformConfig.client_credentials_secret, + command=self.flyte_conn.extra_dejson.get( + "command", PlatformConfig.command + ), + scopes=self.flyte_conn.extra_dejson.get( + "scopes", getattr(PlatformConfig, "scopes", []) + ), + auth_mode=self.flyte_conn.extra_dejson.get( + "auth_mode", PlatformConfig.auth_mode + ), + ), + secrets=SecretsConfig( + env_prefix=self.flyte_conn.extra_dejson.get( + "env_prefix", SecretsConfig.env_prefix + ), + default_dir=self.flyte_conn.extra_dejson.get( + "default_dir", SecretsConfig.default_dir + ), + file_prefix=self.flyte_conn.extra_dejson.get( + "file_prefix", SecretsConfig.file_prefix + ), + ), + stats=StatsConfig( + host=self.flyte_conn.extra_dejson.get( + "statsd_host", StatsConfig.host + ), + port=self.flyte_conn.extra_dejson.get( + "statsd_port", StatsConfig.port + ), + disabled=self.flyte_conn.extra_dejson.get( + "statsd_disabled", StatsConfig.disabled + ), + disabled_tags=self.flyte_conn.extra_dejson.get( + "statsd_disabled_tags", + StatsConfig.disabled_tags, + ), + ), + data_config=DataConfig( + s3=S3Config( + enable_debug=self.flyte_conn.extra_dejson.get( + "s3_enable_debug", S3Config.enable_debug + ), + endpoint=self.flyte_conn.extra_dejson.get( + "s3_endpoint", S3Config.endpoint + ), + retries=self.flyte_conn.extra_dejson.get( + "s3_retries", S3Config.retries + ), + backoff=self.flyte_conn.extra_dejson.get( + "s3_backoff", S3Config.backoff + ), + 
access_key_id=self.flyte_conn.extra_dejson.get( + "s3_access_key_id", S3Config.access_key_id + ), + secret_access_key=self.flyte_conn.extra_dejson.get( + "s3_secret_access_key", + S3Config.secret_access_key, + ), + ), + gcs=GCSConfig( + gsutil_parallelism=self.flyte_conn.extra_dejson.get( + "gsutil_parallelism", + GCSConfig.gsutil_parallelism, + ) + ), + ), + local_sandbox_path=self.flyte_conn.extra_dejson.get( + "local_sandbox_path", Config.local_sandbox_path + ), + ), + ) + return remote + + def trigger_execution( + self, + execution_name: str, + launchplan_name: Optional[str] = None, + task_name: Optional[str] = None, + max_parallelism: Optional[int] = None, + raw_output_data_config: Optional[str] = None, + kubernetes_service_account: Optional[str] = None, + oauth2_client: Optional[Dict[str, str]] = None, + labels: Optional[Dict[str, str]] = None, + annotations: Optional[Dict[str, str]] = None, + secrets: Optional[List[Dict[str, str]]] = None, + notifications: Optional[List[Dict[str, Any]]] = None, + disable_notifications: Optional[bool] = None, + version: Optional[str] = None, + inputs: Dict[str, Any] = {}, + ) -> None: + """ + Trigger an execution. + + :param execution_name: Required. The name of the execution to trigger. + :param launchplan_name: Optional. The name of the launchplan to trigger. + :param task_name: Optional. The name of the task to trigger. + :param max_parallelism: Optional. The maximum number of parallel executions to allow. + :param raw_output_data_config: Optional. Location of offloaded data for things like S3, etc. + :param kubernetes_service_account: Optional. The kubernetes service account to use. + :param oauth2_client: Optional. The OAuth2 client to use. + :param labels: Optional. The labels to use. + :param annotations: Optional. The annotations to use. + :param secrets: Optional. Custom secrets to be applied to the execution resource. + :param notifications: Optional. List of notifications to be applied to the execution resource. 
+ :param disable_notifications: Optional. Whether to disable notifications. + :param version: Optional. The version of the launchplan to trigger. + :param inputs: Optional. The inputs to the launchplan. + """ + remote = self.create_flyte_remote() + try: + if launchplan_name: + flyte_entity = remote.fetch_launch_plan( + name=launchplan_name, + project=self.project, + domain=self.domain, + version=version, + ) + elif task_name: + flyte_entity = remote.fetch_task( + name=task_name, + project=self.project, + domain=self.domain, + version=version, + ) + except FlyteEntityNotExistException as e: + raise AirflowException(f"Failed to fetch entity: {e}") + + try: + remote.execute( + flyte_entity, + inputs=inputs, + project=self.project, + domain=self.domain, + execution_name=execution_name, + options=Options( + raw_output_data_config=raw_output_data_config, + max_parallelism=max_parallelism, + security_context=SecurityContext( + run_as=Identity( + k8s_service_account=kubernetes_service_account, + oauth2_client=OAuth2Client( + client_id=oauth2_client.get("client_id"), + client_secret=oauth2_client.get("client_secret"), + ) + if oauth2_client + else None, + ), + secrets=[ + Secret( + group=secret.get("group"), + key=secret.get("key"), + group_version=secret.get("group_version"), + ) + for secret in secrets + ] + if secrets + else None, + ), + labels=Labels(labels), + annotations=Annotations(annotations), + notifications=[ + Notification( + phases=notification.get("phases"), + email=EmailNotification( + recipients_email=notification.get("email", {}).get( + "recipients_email" + ) + ), + slack=SlackNotification( + recipients_email=notification.get("slack", {}).get( + "recipients_email" + ) + ), + pager_duty=PagerDutyNotification( + recipients_email=notification.get("pager_duty", {}).get( + "recipients_email" + ) + ), + ) + for notification in notifications + ] + if notifications + else None, + disable_notifications=disable_notifications, + ), + ) + except Exception as e: + raise 
AirflowException(f"Failed to trigger execution: {e}") + + def execution_status(self, execution_name: str, remote: FlyteRemote): + phase = remote.client.get_execution( + self.execution_id(execution_name) + ).closure.phase + + if phase == self.SUCCEEDED: + return True + elif phase == self.FAILED: + raise AirflowException(f"Execution {execution_name} failed") + elif phase == self.TIMED_OUT: + raise AirflowException(f"Execution {execution_name} timedout") + elif phase == self.ABORTED: + raise AirflowException(f"Execution {execution_name} aborted") + else: + return False + + def terminate( + self, + execution_name: str, + cause: str, + ) -> None: + """ + Terminate an execution. + + :param execution: Required. The execution to terminate. + :param cause: Required. The cause of the termination. + """ + remote = self.create_flyte_remote() + execution_id = self.execution_id(execution_name) + remote.client.terminate_execution(id=execution_id, cause=cause) diff --git a/flyte_provider/operators/__init__.py b/flyte_provider/operators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flyte_provider/operators/flyte.py b/flyte_provider/operators/flyte.py new file mode 100644 index 0000000..2f3a310 --- /dev/null +++ b/flyte_provider/operators/flyte.py @@ -0,0 +1,258 @@ +import inspect +import re +from dataclasses import fields +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.abstractoperator import AbstractOperator +from airflow.models.taskinstance import TaskInstanceKey +from flytekit.models.common import ( + EmailNotification, + Notification, + PagerDutyNotification, + SlackNotification, +) +from flytekit.models.security import OAuth2Client, Secret + +from flyte_provider.hooks.flyte import FlyteHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class RegistryLink(BaseOperatorLink): 
+ """Link to Registry""" + + name = "Astronomer Registry" + + def get_link(self, operator: AbstractOperator, ti_key: TaskInstanceKey) -> str: + """Get link to registry page.""" + + registry_link = ( + "https://registry.astronomer.io/providers/{provider}/modules/{operator}" + ) + return registry_link.format(provider="flyte", operator="flyteoperator") + + +class FlyteOperator(BaseOperator): + """ + Launch Flyte executions from within Airflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AirflowFlyteOperator` + + :param flyte_conn_id: Required. The connection to Flyte setup, containing metadata. + :param project: Optional. The project to connect to. + :param domain: Optional. The domain to connect to. + :param launchplan_name: Optional. The name of the launchplan to trigger. + :param task_name: Optional. The name of the task to trigger. + :param max_parallelism: Optional. The maximum number of parallel executions to allow. + :param raw_output_data_config: Optional. Location of offloaded data for things like S3, etc. + :param kubernetes_service_account: Optional. The Kubernetes service account to use. + :param oauth2_client: Optional. The OAuth2 client to use. + :param labels: Optional. Custom labels to be applied to the execution resource. + :param annotations: Optional. Custom annotations to be applied to the execution resource. + :param secrets: Optional. Custom secrets to be applied to the execution resource. + :param notifications: Optional. List of notifications to be applied to the execution resource. + :param disable_notifications: Optional. Whether to disable notifications. + :param version: Optional. The version of the launchplan/task to trigger. + :param inputs: Optional. The inputs to the launchplan/task. 
+ """ + + template_fields: Sequence[str] = ("flyte_conn_id",) # mypy fix + + def __init__( + self, + flyte_conn_id: str, + project: Optional[str] = None, + domain: Optional[str] = None, + launchplan_name: Optional[str] = None, + task_name: Optional[str] = None, + max_parallelism: Optional[int] = None, + raw_output_data_config: Optional[str] = None, + kubernetes_service_account: Optional[str] = None, + oauth2_client: Optional[Dict[str, str]] = None, + labels: Dict[str, str] = {}, + annotations: Dict[str, str] = {}, + secrets: Optional[List[Dict[str, str]]] = None, + notifications: Optional[ + List[Dict[str, Union[List[str], Dict[str, List[str]]]]] + ] = None, + disable_notifications: Optional[bool] = None, + version: Optional[str] = None, + inputs: Dict[str, Any] = {}, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.flyte_conn_id = flyte_conn_id + self.project = project + self.domain = domain + self.launchplan_name = launchplan_name + self.task_name = task_name + self.max_parallelism = max_parallelism + self.raw_output_data_config = raw_output_data_config + self.kubernetes_service_account = kubernetes_service_account + self.oauth2_client = oauth2_client + self.labels = labels + self.annotations = annotations + self.secrets = secrets + self.notifications = notifications + self.disable_notifications = disable_notifications + self.version = version + self.inputs = inputs + self.execution_name: str = "" + + if (not (self.task_name or self.launchplan_name)) or ( + self.task_name and self.launchplan_name + ): + raise AirflowException("Either task_name or launchplan_name is required.") + + if oauth2_client: + if not isinstance(oauth2_client, dict): + raise AirflowException( + f"oauth2_client isn't a dict, instead it is of type {type(oauth2_client)}" + ) + if not ( + set( + field.name + for field in fields(OAuth2Client) + if not hasattr(OAuth2Client, field.name) + ) + <= set(oauth2_client.keys()) + <= set(map(lambda x: x.name, fields(OAuth2Client))) + ): + 
raise AirflowException( + "oauth2_client doesn't have all the required keys or the key names do not match." + ) + + if secrets: + if not isinstance(secrets, list): + raise AirflowException( + f"secrets isn't a list, instead it is of type {type(secrets)}" + ) + for secret in secrets: + if not isinstance(secret, dict): + raise AirflowException( + f"secret isn't a dict, instead it is of type {type(secret)}" + ) + if secret and not ( + set( + field.name + for field in fields(Secret) + if not hasattr(Secret, field.name) + ) + <= set(secret.keys()) + <= set(map(lambda x: x.name, fields(Secret))) + ): + raise AirflowException( + "secret doesn't have all the required keys or the key names do not match." + ) + + if notifications: + map_key_class = { + "email": EmailNotification, + "slack": SlackNotification, + "pager_duty": PagerDutyNotification, + } + + if not isinstance(notifications, list): + raise AirflowException( + f"notifications isn't a list, instead it is of type {type(notifications)}" + ) + for notification in notifications: + if not isinstance(notification, dict): + raise AirflowException( + f"notification isn't a dict, instead it is of type {type(notification)}" + ) + if notification and not set( + arg_name + for arg_name, v in inspect.signature( + Notification.__init__ + ).parameters.items() + if v.default is inspect._empty and arg_name != "self" + ) <= set(notification.keys()) <= set( + map( + lambda x: x, + inspect.signature(Notification.__init__).parameters.keys(), + ) + ): + raise AirflowException( + "notification doesn't have all the required keys or the key names do not match." 
+ ) + + for key in notification.keys(): + if key in {"email", "slack", "pager_duty"}: + if (not isinstance(notification[key], dict)) or not ( + set( + arg_name + for arg_name, v in inspect.signature( + map_key_class[key].__init__ + ).parameters.items() + if v.default is inspect._empty and arg_name != "self" + ) + <= set(notification[key].keys()) + <= set( + map( + lambda x: x, + inspect.signature( + map_key_class[key].__init__ + ).parameters.keys(), + ) + ) + ): + raise AirflowException( + f"notification[{key}] isn't a dict/doesn't have all the required keys/the key names do not match." + ) + + def execute(self, context: "Context") -> str: + """Trigger an execution.""" + + # create a deterministic execution name + task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:5] + self.execution_name = ( + task_id + + re.sub( + r"[\W_]+", + "", + context["dag_run"].run_id.split("__")[-1].lower(), + )[: (20 - len(task_id))] + ) + + hook = FlyteHook( + flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain + ) + hook.trigger_execution( + launchplan_name=self.launchplan_name, + task_name=self.task_name, + max_parallelism=self.max_parallelism, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + oauth2_client=self.oauth2_client, + labels=self.labels, + annotations=self.annotations, + secrets=self.secrets, + notifications=self.notifications, + disable_notifications=self.disable_notifications, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + self.log.info("Execution %s submitted", self.execution_name) + + return self.execution_name + + def on_kill(self) -> None: + """Kill the execution.""" + if self.execution_name: + print(f"Killing execution {self.execution_name}") + hook = FlyteHook( + flyte_conn_id=self.flyte_conn_id, + project=self.project, + domain=self.domain, + ) + hook.terminate( + execution_name=self.execution_name, + cause="Killed by 
Airflow", + ) diff --git a/flyte_provider/sensors/__init__.py b/flyte_provider/sensors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flyte_provider/sensors/flyte.py b/flyte_provider/sensors/flyte.py new file mode 100644 index 0000000..b4b00ec --- /dev/null +++ b/flyte_provider/sensors/flyte.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING, Optional, Sequence + +from airflow.sensors.base import BaseSensorOperator + +from flyte_provider.hooks.flyte import FlyteHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class FlyteSensor(BaseSensorOperator): + """ + Check for the status of a Flyte execution. + + :param execution_name: Required. The name of the execution to check. + :param project: Optional. The project to connect to. + :param domain: Optional. The domain to connect to. + :param flyte_conn_id: Required. The name of the Flyte connection to + get the connection information for Flyte. + """ + + template_fields: Sequence[str] = ("execution_name",) # mypy fix + + def __init__( + self, + execution_name: str, + project: Optional[str] = None, + domain: Optional[str] = None, + flyte_conn_id: str = "flyte_default", + **kwargs, + ): + super().__init__(**kwargs) + self.execution_name = execution_name + self.project = project + self.domain = domain + self.flyte_conn_id = flyte_conn_id + + def poke(self, context: "Context") -> bool: + """Check for the status of a Flyte execution.""" + hook = FlyteHook( + flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain + ) + remote = hook.create_flyte_remote() + + if hook.execution_status(self.execution_name, remote): + return True + + self.log.info("Waiting for execution %s to complete", self.execution_name) + return False diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..66d85d1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +flytekit +apache-airflow +pre-commit +pytest diff --git a/setup.py b/setup.py new file mode 100644 
from setuptools import setup

# NOTE: the publish workflow (.github/workflows/publish.yml) greps for the
# exact string `__version__ = "0.0.0+dev0"` and sed-substitutes the release
# tag into it.  The placeholder below must match that grep byte-for-byte,
# otherwise the release job fails before uploading (previous value
# "0.0.0+develop" did not match).
__version__ = "0.0.0+dev0"

# Long description for PyPI comes straight from the README.
with open("README.md", "r") as fh:
    long_description = fh.read()

"""Perform the package airflow-provider-flyte setup."""
setup(
    name="airflow-provider-flyte",
    version=__version__,
    description="Flyte Airflow Provider",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # Registers the provider with Airflow's plugin discovery mechanism.
    entry_points={
        "apache_airflow_provider": [
            "provider_info=flyte_provider.__init__:get_provider_info"
        ]
    },
    license="Apache License 2.0",
    packages=[
        "flyte_provider",
        "flyte_provider.hooks",
        "flyte_provider.sensors",
        "flyte_provider.operators",
    ],
    install_requires=["apache-airflow>=2.0", "flytekit>=1.0.0"],
    setup_requires=["setuptools", "wheel"],
    author="Samhita Alla",
    author_email="samhita@union.ai",
    url="https://flyte.org/",
    classifiers=[
        "Framework :: Apache Airflow",
        "Framework :: Apache Airflow :: Provider",
    ],
    python_requires="~=3.7",
)
"localhost" + port = "30081" + extra = {"project": "flytesnacks", "domain": "development"} + launchplan_name = "core.basic.hello_world.my_wf" + task_name = "core.basic.hello_world.say_hello" + raw_output_data_config = "s3://flyte-demo/raw_data" + kubernetes_service_account = "default" + version = "v1" + inputs = {"name": "hello world"} + timeout = timedelta(seconds=3600) + oauth2_client = {"client_id": "123", "client_secret": "456"} + secrets = [{"group": "secrets", "key": "123"}] + notifications = [{"phases": [1], "email": {"recipients_email": ["abc@flyte.org"]}}] + + @classmethod + def get_mock_connection(cls): + return Connection( + conn_id=cls.flyte_conn_id, + conn_type=cls.conn_type, + host=cls.host, + port=cls.port, + extra=cls.extra, + ) + + @classmethod + def create_remote(cls): + return FlyteRemote( + config=Config( + platform=PlatformConfig( + endpoint=":".join([cls.host, cls.port]), insecure=True + ), + ) + ) + + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote") + def test_trigger_execution_success( + self, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = FlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock() + + mock_remote.execute = mock.MagicMock() + + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + mock_create_flyte_remote.assert_called_once() + + 
@mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote") + def test_trigger_task_execution_success( + self, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = FlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_task = mock.MagicMock() + + mock_remote.execute = mock.MagicMock() + + test_hook.trigger_execution( + task_name=self.task_name, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + mock_create_flyte_remote.assert_called_once() + + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote") + def test_trigger_execution_failed_to_fetch( + self, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = FlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock( + side_effect=FlyteEntityNotExistException + ) + + with pytest.raises(AirflowException): + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + 
mock_create_flyte_remote.assert_called_once() + + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote") + def test_trigger_execution_failed_to_trigger( + self, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = FlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock() + mock_remote.execute = mock.MagicMock(side_effect=FlyteValueException) + + with pytest.raises(AirflowException): + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + mock_create_flyte_remote.assert_called_once() diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/operators/test_flyte.py b/tests/operators/test_flyte.py new file mode 100644 index 0000000..40c0413 --- /dev/null +++ b/tests/operators/test_flyte.py @@ -0,0 +1,152 @@ +import unittest +from unittest import mock + +import pytest +from airflow import AirflowException +from airflow.models import Connection +from airflow.models.dagrun import DagRun + +from flyte_provider.operators.flyte import FlyteOperator + + +class TestFlyteOperator(unittest.TestCase): + + task_id = "test_flyte_operator" + flyte_conn_id = "flyte_default" + run_id = "manual__2022-03-30T13:55:08.715694+00:00" + conn_type = "flyte" + host = "localhost" + port = "30081" + project = "flytesnacks" + domain = "development" + launchplan_name = "core.basic.hello_world.my_wf" + 
raw_output_data_config = "s3://flyte-demo/raw_data" + kubernetes_service_account = "default" + labels = {"key1": "value1"} + version = "v1" + inputs = {"name": "hello world"} + execution_name = "testf20220330t135508" + oauth2_client = {"client_id": "123", "client_secret": "456"} + secrets = [{"group": "secrets", "key": "123"}] + notifications = [{"phases": [1], "email": {"recipients_email": ["abc@flyte.org"]}}] + wrong_notifications = [ + {"phases": [1], "email": {"recipient_email": ["abc@flyte.org"]}} + ] + + @classmethod + def get_connection(cls): + return Connection( + conn_id=cls.flyte_conn_id, + conn_type=cls.conn_type, + host=cls.host, + port=cls.port, + extra={"project": cls.project, "domain": cls.domain}, + ) + + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.trigger_execution") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + def test_execute(self, mock_get_connection, mock_trigger_execution): + mock_get_connection.return_value = self.get_connection() + + operator = FlyteOperator( + task_id=self.task_id, + flyte_conn_id=self.flyte_conn_id, + project=self.project, + domain=self.domain, + launchplan_name=self.launchplan_name, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + labels=self.labels, + version=self.version, + inputs=self.inputs, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + result = operator.execute( + {"dag_run": DagRun(run_id=self.run_id), "task": operator} + ) + + assert result == self.execution_name + mock_get_connection.assert_called_once_with(self.flyte_conn_id) + mock_trigger_execution.assert_called_once_with( + launchplan_name=self.launchplan_name, + task_name=None, + max_parallelism=None, + raw_output_data_config=self.raw_output_data_config, + kubernetes_service_account=self.kubernetes_service_account, + oauth2_client=self.oauth2_client, + labels=self.labels, + annotations={}, + 
secrets=self.secrets, + notifications=self.notifications, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + disable_notifications=None, + ) + + @mock.patch( + "flyte_provider.hooks.flyte.FlyteHook.trigger_execution", return_value=None + ) + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.terminate") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + def test_on_kill_success( + self, mock_get_connection, mock_terminate, mock_trigger_execution + ): + mock_get_connection.return_value = self.get_connection() + + operator = FlyteOperator( + task_id=self.task_id, + flyte_conn_id=self.flyte_conn_id, + project=self.project, + domain=self.domain, + launchplan_name=self.launchplan_name, + inputs=self.inputs, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator}) + operator.on_kill() + + mock_get_connection.has_calls([mock.call(self.flyte_conn_id)] * 2) + mock_trigger_execution.assert_called() + mock_terminate.assert_called_once_with( + execution_name=self.execution_name, cause="Killed by Airflow" + ) + + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.terminate") + @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection") + def test_on_kill_noop(self, mock_get_connection, mock_terminate): + mock_get_connection.return_value = self.get_connection() + + operator = FlyteOperator( + task_id=self.task_id, + flyte_conn_id=self.flyte_conn_id, + project=self.project, + domain=self.domain, + launchplan_name=self.launchplan_name, + inputs=self.inputs, + oauth2_client=self.oauth2_client, + secrets=self.secrets, + notifications=self.notifications, + ) + operator.on_kill() + + mock_get_connection.assert_not_called() + mock_terminate.assert_not_called() + + def test_execute_failure(self): + with pytest.raises(AirflowException): + FlyteOperator( + task_id=self.task_id, + 
class TestFlyteSensor(unittest.TestCase):
    """Tests for FlyteSensor.poke across terminal and non-terminal phases."""

    task_id = "test_flyte_sensor"
    flyte_conn_id = "flyte_default"
    conn_type = "flyte"
    host = "localhost"
    port = "30081"
    project = "flytesnacks"
    domain = "development"
    execution_name = "testf20220330t135508"

    @classmethod
    def get_connection(cls):
        """Airflow Connection carrying project/domain in `extra`."""
        return Connection(
            conn_id=cls.flyte_conn_id,
            conn_type=cls.conn_type,
            host=cls.host,
            port=cls.port,
            extra={"project": cls.project, "domain": cls.domain},
        )

    @classmethod
    def create_remote(cls):
        """Real FlyteRemote object; its client calls are mocked per-test."""
        return FlyteRemote(
            config=Config(
                platform=PlatformConfig(
                    endpoint=":".join([cls.host, cls.port]), insecure=True
                ),
            )
        )

    def _make_sensor(self):
        """Every test builds the sensor with the same arguments."""
        return FlyteSensor(
            task_id=self.task_id,
            execution_name=self.execution_name,
            project=self.project,
            domain=self.domain,
            flyte_conn_id=self.flyte_conn_id,
        )

    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.execution_id")
    def test_poke_done(
        self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
    ):
        """poke() returns True once the execution phase is SUCCEEDED."""
        mock_get_connection.return_value = self.get_connection()

        mock_remote = self.create_remote()
        mock_create_flyte_remote.return_value = mock_remote

        mock_execution_id.return_value = mock.MagicMock()

        # Make remote.client.get_execution(...).closure.phase yield SUCCEEDED.
        mock_get_execution = mock.MagicMock()
        mock_remote.client.get_execution = mock_get_execution
        mock_phase = mock.PropertyMock(return_value=FlyteHook.SUCCEEDED)
        type(mock_get_execution().closure).phase = mock_phase

        sensor = self._make_sensor()
        assert sensor.poke({})
        mock_create_flyte_remote.assert_called_once()
        mock_execution_id.assert_called_with(self.execution_name)

    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.execution_id")
    def test_poke_failed(
        self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
    ):
        """Each terminal failure phase makes poke() raise AirflowException."""
        mock_get_connection.return_value = self.get_connection()

        mock_remote = self.create_remote()
        mock_create_flyte_remote.return_value = mock_remote

        sensor = self._make_sensor()

        mock_execution_id.return_value = mock.MagicMock()

        for phase in [FlyteHook.ABORTED, FlyteHook.FAILED, FlyteHook.TIMED_OUT]:
            mock_get_execution = mock.MagicMock()
            mock_remote.client.get_execution = mock_get_execution
            mock_phase = mock.PropertyMock(return_value=phase)
            type(mock_get_execution().closure).phase = mock_phase

            with pytest.raises(AirflowException):
                sensor.poke({})

        # BUG FIX: the original used `has_calls`, which is not a Mock
        # assertion method — Mock auto-creates it as a child mock, so the
        # checks were silent no-ops.  `assert_has_calls` actually verifies
        # one call per poked phase.
        mock_create_flyte_remote.assert_has_calls([mock.call()] * 3)
        mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 3)

    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.get_connection")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.create_flyte_remote")
    @mock.patch("flyte_provider.hooks.flyte.FlyteHook.execution_id")
    def test_poke_running(
        self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
    ):
        """A RUNNING (non-terminal) phase makes poke() return False."""
        mock_get_connection.return_value = self.get_connection()

        mock_remote = self.create_remote()
        mock_create_flyte_remote.return_value = mock_remote

        mock_execution_id.return_value = mock.MagicMock()

        mock_get_execution = mock.MagicMock()
        mock_remote.client.get_execution = mock_get_execution
        mock_phase = mock.PropertyMock(
            return_value=core_execution_models.WorkflowExecutionPhase.RUNNING
        )
        type(mock_get_execution().closure).phase = mock_phase

        sensor = self._make_sensor()
        assert not sensor.poke({})

        mock_create_flyte_remote.assert_called_once()
        mock_execution_id.assert_called_with(self.execution_name)