Skip to content

Commit

Permalink
Merge branch 'master' into error_on_retry
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw authored Jan 24, 2023
2 parents cc70280 + b6605bc commit cafad0c
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 12 deletions.
4 changes: 4 additions & 0 deletions docs/source/clients.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.. automodule:: flytekit.clients
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Expected output:
flytekit
configuration
remote
clients
testing
extend
deck
Expand Down
19 changes: 19 additions & 0 deletions flytekit/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
=====================
Clients
=====================
.. currentmodule:: flytekit.clients
This module provides lower level access to a Flyte backend.
.. _clients_module:
.. autosummary::
:template: custom.rst
:toctree: generated/
:nosignatures:
~friendly.SynchronousFlyteClient
~raw.RawSynchronousFlyteClient
"""
7 changes: 7 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.task import Resources as _resources_model
Expand Down Expand Up @@ -119,6 +120,12 @@ def with_overrides(self, *args, **kwargs):
self._metadata._interruptible = kwargs["interruptible"]
if "name" in kwargs:
self._metadata._name = kwargs["name"]
if "task_config" in kwargs:
logger.warning("This override is beta. We may want to revisit this in the future.")
new_task_config = kwargs["task_config"]
if not isinstance(new_task_config, type(self.flyte_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.flyte_entity._task_config = new_task_config
return self


Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,8 @@ def create_and_link_node_from_remote(
extra_inputs = used_inputs ^ set(kwargs.keys())
if len(extra_inputs) > 0:
raise _user_exceptions.FlyteAssertion(
"Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
f"Too many inputs for [{entity.name}] Expected inputs: {typed_interface.inputs.keys()} "
f"- extra inputs: {extra_inputs}"
)

# Detect upstream nodes
Expand Down
23 changes: 14 additions & 9 deletions flytekit/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@

# By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning
logger = logging.getLogger("flytekit")
# Root logger control
flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT"
if os.getenv(flytekit_root_env_var) is not None:
logger.setLevel(int(os.getenv(flytekit_root_env_var)))
else:
logger.setLevel(logging.DEBUG)

# Stop propagation so that configuration is isolated to this file (so that it doesn't matter what the
# global Python root logger is set to).
Expand All @@ -40,22 +34,33 @@

# create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

# Root logger control
# Don't want to import the configuration library since that will cause all sorts of circular imports, let's
# just use the environment variable if it's defined. Decide in the future when we implement better controls
# if we should control with the channel or with the logger level.
# The handler log level controls whether log statements will actually print to the screen
flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT"
level_from_env = os.getenv(LOGGING_ENV_VAR)
if level_from_env is not None:
ch.setLevel(int(level_from_env))
root_level_from_env = os.getenv(flytekit_root_env_var)
if root_level_from_env is not None:
logger.setLevel(int(root_level_from_env))
elif level_from_env is not None:
logger.setLevel(int(level_from_env))
else:
ch.setLevel(logging.WARNING)
logger.setLevel(logging.WARNING)

for log_name, child_logger in child_loggers.items():
env_var = f"{LOGGING_ENV_VAR}_{log_name.upper()}"
level_from_env = os.getenv(env_var)
if level_from_env is not None:
child_logger.setLevel(int(level_from_env))
else:
if child_logger is user_space_logger:
child_logger.setLevel(logging.INFO)
else:
child_logger.setLevel(logging.WARNING)

# create formatter
formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
Expand Down
7 changes: 6 additions & 1 deletion flytekit/remote/lazy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def entity(self) -> T:
"""
with self._mutex:
if self._entity is None:
self._entity = self._getter()
try:
self._entity = self._getter()
except AttributeError as e:
raise RuntimeError(
f"Error downloading the entity {self._name}, (check original exception...)"
) from e
return self._entity

def __getattr__(self, item: str) -> typing.Any:
Expand Down
6 changes: 5 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2
from git import Repo

from flytekit import Literal
from flytekit.clients.friendly import SynchronousFlyteClient
Expand Down Expand Up @@ -127,9 +126,14 @@ def _get_git_repo_url(source_path):
Get git repo URL from remote.origin.url
"""
try:
from git import Repo

return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1]
except ImportError:
remote_logger.warning("Could not import git. is the git executable installed?")
except Exception:
# If the file isn't in the git repo, we can't get the url from git config
remote_logger.debug(f"{source_path} is not a git repo.")
return ""


Expand Down
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import typing
from collections import OrderedDict
from dataclasses import dataclass

import pytest

Expand Down Expand Up @@ -424,3 +425,25 @@ def my_wf(a: str) -> str:
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.name == "foo"


def test_config_override():
@dataclass
class DummyConfig:
name: str

@task(task_config=DummyConfig(name="hello"))
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=DummyConfig("flyte"))

assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte"

with pytest.raises(ValueError):

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)
13 changes: 13 additions & 0 deletions tests/flytekit/unit/remote/test_lazy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ def _getter():
e.compile(ctx)
assert e._entity is not None
assert e.entity == dummy_task


def test_lazy_loading_exception():
def _getter():
raise AttributeError("Error")

e = LazyEntity("x", _getter)
assert e.name == "x"
assert e._entity is None
with pytest.raises(RuntimeError) as exc:
assert e.blah

assert isinstance(exc.value.__cause__, AttributeError)

0 comments on commit cafad0c

Please sign in to comment.