Skip to content

Commit

Permalink
Merge pull request #2 from flyteorg/add-try-number
Browse files Browse the repository at this point in the history
add try number to execution name
  • Loading branch information
samhita-alla authored Jul 15, 2022
2 parents 79cd08f + fb6fb1b commit cd068e2
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
8 changes: 4 additions & 4 deletions demo/dags/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from flyte_provider.sensors.flyte import FlyteSensor

with DAG(
dag_id="example_sensor",
dag_id="example_flyte",
schedule_interval=None,
start_date=datetime(2021, 1, 1),
dagrun_timeout=timedelta(minutes=60),
catchup=False,
) as dag:
task = FlyteOperator(
task_id="task",
task_id="diabetes_predictions",
flyte_conn_id="flyte_conn",
project="flytesnacks",
domain="development",
launchplan_name="core.flyte_basics.torch_example.wf",
kubernetes_service_account="default",
launchplan_name="ml_training.pima_diabetes.diabetes.diabetes_xgboost_model",
inputs={"test_split_ratio": 0.66, "seed": 5},
)

sensor = FlyteSensor(
Expand Down
6 changes: 4 additions & 2 deletions flyte_provider/operators/flyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,13 @@ def execute(self, context: "Context") -> str:
"""Trigger an execution."""

# create a deterministic execution name
task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:5]
task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:4] + str(
context["task_instance"].try_number
)
self.execution_name = (
task_id
+ re.sub(
r"[\W_]+",
r"[\W_t]+",
"",
context["dag_run"].run_id.split("__")[-1].lower(),
)[: (20 - len(task_id))]
Expand Down
4 changes: 1 addition & 3 deletions tests/hooks/test_flyte.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
from datetime import timedelta
from unittest import mock

import pytest
Expand All @@ -15,7 +14,7 @@
class TestFlyteHook(unittest.TestCase):

flyte_conn_id = "flyte_default"
execution_name = "flyte20220330t133856"
execution_name = "flyt1202203301338565"
conn_type = "flyte"
host = "localhost"
port = "30081"
Expand All @@ -26,7 +25,6 @@ class TestFlyteHook(unittest.TestCase):
kubernetes_service_account = "default"
version = "v1"
inputs = {"name": "hello world"}
timeout = timedelta(seconds=3600)
oauth2_client = {"client_id": "123", "client_secret": "456"}
secrets = [{"group": "secrets", "key": "123"}]
notifications = [{"phases": [1], "email": {"recipients_email": ["[email protected]"]}}]
Expand Down
18 changes: 14 additions & 4 deletions tests/operators/test_flyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from airflow import AirflowException
from airflow.models import Connection
from airflow.models import Connection, TaskInstance
from airflow.models.dagrun import DagRun

from flyte_provider.operators.flyte import FlyteOperator
Expand All @@ -25,7 +25,7 @@ class TestFlyteOperator(unittest.TestCase):
labels = {"key1": "value1"}
version = "v1"
inputs = {"name": "hello world"}
execution_name = "testf20220330t135508"
execution_name = "test1202203301355087"
oauth2_client = {"client_id": "123", "client_secret": "456"}
secrets = [{"group": "secrets", "key": "123"}]
notifications = [{"phases": [1], "email": {"recipients_email": ["[email protected]"]}}]
Expand Down Expand Up @@ -64,7 +64,11 @@ def test_execute(self, mock_get_connection, mock_trigger_execution):
notifications=self.notifications,
)
result = operator.execute(
{"dag_run": DagRun(run_id=self.run_id), "task": operator}
{
"dag_run": DagRun(run_id=self.run_id),
"task": operator,
"task_instance": TaskInstance(task=operator),
}
)

assert result == self.execution_name
Expand Down Expand Up @@ -107,7 +111,13 @@ def test_on_kill_success(
secrets=self.secrets,
notifications=self.notifications,
)
operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator})
operator.execute(
{
"dag_run": DagRun(run_id=self.run_id),
"task": operator,
"task_instance": TaskInstance(task=operator),
}
)
operator.on_kill()

mock_get_connection.has_calls([mock.call(self.flyte_conn_id)] * 2)
Expand Down
2 changes: 1 addition & 1 deletion tests/sensors/test_flyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TestFlyteSensor(unittest.TestCase):
port = "30081"
project = "flytesnacks"
domain = "development"
execution_name = "testf20220330t135508"
execution_name = "test1202203301355081"

@classmethod
def get_connection(cls):
Expand Down

0 comments on commit cd068e2

Please sign in to comment.