Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ctx Context can be used within shell tasks - to access context vars and secrets #832

Merged
merged 3 commits into from
Jan 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,33 @@ class SecretsManager(object):
All configuration values can always be overridden by injecting an environment variable
"""

class _GroupSecrets(object):
"""
This is a dummy class whose sole purpose is to support "attribute" style lookup for secrets
"""

def __init__(self, group: str, sm: typing.Any):
self._group = group
self._sm = sm

def __getattr__(self, item: str) -> str:
"""
Returns the secret that matches "group"."key"
the key, here is the item
"""
return self._sm.get(self._group, item)

def __init__(self):
self._base_dir = str(secrets.SECRETS_DEFAULT_DIR.get()).strip()
self._file_prefix = str(secrets.SECRETS_FILE_PREFIX.get()).strip()
self._env_prefix = str(secrets.SECRETS_ENV_PREFIX.get()).strip()

def __getattr__(self, item: str) -> _GroupSecrets:
"""
returns a new _GroupSecrets objects, that allows all keys within this group to be looked up like attributes
"""
return self._GroupSecrets(item, self)

def get(self, group: str, key: str) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
Expand Down
26 changes: 20 additions & 6 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import datetime
import logging
import os
Expand All @@ -7,6 +6,7 @@
import typing
from dataclasses import dataclass

import flytekit
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
Expand Down Expand Up @@ -38,7 +38,15 @@ def _dummy_task_func():
return None


T = typing.TypeVar("T")
class AttrDict(dict):
"""
Convert a dictionary to an attribute style lookup. Do not use this in regular places, this is used for
namespacing inputs and outputs
"""

def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


class _PythonFStringInterpolizer:
Expand Down Expand Up @@ -73,16 +81,22 @@ def interpolate(
"""
inputs = inputs or {}
outputs = outputs or {}
reused_vars = inputs.keys() & outputs.keys()
if reused_vars:
raise ValueError(f"Variables {reused_vars} in Query cannot be shared between inputs and outputs.")
consolidated_args = collections.ChainMap(inputs, outputs)
inputs = AttrDict(inputs)
outputs = AttrDict(outputs)
consolidated_args = {
"inputs": inputs,
"outputs": outputs,
"ctx": flytekit.current_context(),
}
try:
return self._Formatter().format(tmpl, **consolidated_args)
except KeyError as e:
raise ValueError(f"Variable {e} in Query not found in inputs {consolidated_args.keys()}")


T = typing.TypeVar("T")


class ShellTask(PythonInstanceTask[T]):
""" """

Expand Down
87 changes: 86 additions & 1 deletion tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, look_up_image_info
import os

import py
import pytest

from flytekit.configuration import secrets
from flytekit.core.context_manager import (
ExecutionState,
FlyteContext,
FlyteContextManager,
SecretsManager,
look_up_image_info,
)


class SampleTestClass(object):
Expand Down Expand Up @@ -65,3 +77,76 @@ def test_additional_context():
)
) as exec_ctx_inner:
assert exec_ctx_inner.execution_state.additional_context == {1: "inner", 2: "foo", 3: "baz"}


def test_secrets_manager_default():
with pytest.raises(ValueError):
sec = SecretsManager()
sec.get("group", "key")

with pytest.raises(ValueError):
_ = sec.group.key


def test_secrets_manager_get_envvar():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_env_var("test", "")
with pytest.raises(ValueError):
sec.get_secrets_env_var("", "x")
assert sec.get_secrets_env_var("group", "test") == f"{secrets.SECRETS_ENV_PREFIX.get()}GROUP_TEST"


def test_secrets_manager_get_file():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_file("test", "")
with pytest.raises(ValueError):
sec.get_secrets_file("", "x")
assert sec.get_secrets_file("group", "test") == os.path.join(
secrets.SECRETS_DEFAULT_DIR.get(),
"group",
f"{secrets.SECRETS_FILE_PREFIX.get()}test",
)


def test_secrets_manager_file(tmpdir: py.path.local):
tmp = tmpdir.mkdir("file_test").dirname
os.environ["FLYTE_SECRETS_DEFAULT_DIR"] = tmp
sec = SecretsManager()
f = os.path.join(tmp, "test")
with open(f, "w+") as w:
w.write("my-password")

with pytest.raises(ValueError):
sec.get("test", "")
with pytest.raises(ValueError):
sec.get("", "x")
# Group dir not exists
with pytest.raises(ValueError):
sec.get("group", "test")

g = os.path.join(tmp, "group")
os.makedirs(g)
f = os.path.join(g, "test")
with open(f, "w+") as w:
w.write("my-password")
assert sec.get("group", "test") == "my-password"
assert sec.group.test == "my-password"
del os.environ["FLYTE_SECRETS_DEFAULT_DIR"]


def test_secrets_manager_bad_env():
with pytest.raises(ValueError):
os.environ["TEST"] = "value"
sec = SecretsManager()
sec.get("group", "test")


def test_secrets_manager_env():
sec = SecretsManager()
os.environ[sec.get_secrets_env_var("group", "test")] = "value"
assert sec.get("group", "test") == "value"

os.environ[sec.get_secrets_env_var(group="group", key="key")] = "value"
assert sec.get(group="group", key="key") == "value"
72 changes: 47 additions & 25 deletions tests/flytekit/unit/extras/tasks/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from dataclasses_json import dataclass_json

import flytekit
from flytekit import kwtypes
from flytekit.extras.tasks.shell import OutputLocation, ShellTask
from flytekit.types.directory import FlyteDirectory
Expand Down Expand Up @@ -46,8 +47,8 @@ def test_input_substitution_primitive():
name="test",
script="""
set -ex
cat {f}
echo "Hello World {y} on {j}"
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=str, y=int, j=datetime.datetime),
)
Expand All @@ -62,24 +63,46 @@ def test_input_substitution_files():
t = ShellTask(
name="test",
script="""
cat {f}
echo "Hello World {y} on {j}"
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
)

assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None


def test_input_substitution_files_ctx():
sec = flytekit.current_context().secrets
envvar = sec.get_secrets_env_var("group", "key")
os.environ[envvar] = "value"
assert sec.get("group", "key") == "value"

t = ShellTask(
name="test",
script="""
export EXEC={ctx.execution_id}
export SECRET={ctx.secrets.group.key}
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
debug=True,
)

assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None
del os.environ[envvar]


def test_input_output_substitution_files():
script = "cat {f} > {y}"
script = "cat {inputs.f} > {outputs.y}"
t = ShellTask(
name="test",
debug=True,
script=script,
inputs=kwtypes(f=CSVFile),
output_locs=[
OutputLocation(var="y", var_type=FlyteFile, location="{f}.mod"),
OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.mod"),
],
)

Expand All @@ -101,15 +124,15 @@ def test_input_output_substitution_files():

def test_input_single_output_substitution_files():
script = """
cat {f} >> {z}
echo "Hello World {y} on {j}"
cat {inputs.f} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j}"
"""
t = ShellTask(
name="test",
debug=True,
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc")],
output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc")],
)

assert t.script == script
Expand All @@ -122,14 +145,14 @@ def test_input_single_output_substitution_files():
[
(
"""
cat {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
cat {missing} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"
"""
),
(
"""
cat {f} {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
cat {inputs.f} {missing} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"
"""
),
],
Expand All @@ -141,31 +164,30 @@ def test_input_output_extra_and_missing_variables(script):
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{inputs.y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

with pytest.raises(ValueError, match="missing"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_cannot_reuse_variables_for_both_inputs_and_outputs():
def test_reuse_variables_for_both_inputs_and_outputs():
t = ShellTask(
name="test",
debug=True,
script="""
cat {f} >> {y}
echo "Hello World {y} on {j}"
cat {inputs.f} >> {outputs.y}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="y", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

with pytest.raises(ValueError, match="Variables {'y'} in Query"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_can_use_complex_types_for_inputs_to_f_string_template():
Expand All @@ -177,10 +199,10 @@ class InputArgs:
t = ShellTask(
name="test",
debug=True,
script="""cat {input_args.in_file} >> {input_args.in_file}.tmp""",
script="""cat {inputs.input_args.in_file} >> {inputs.input_args.in_file}.tmp""",
inputs=kwtypes(input_args=InputArgs),
output_locs=[
OutputLocation(var="x", var_type=FlyteFile, location="{input_args.in_file}.tmp"),
OutputLocation(var="x", var_type=FlyteFile, location="{inputs.input_args.in_file}.tmp"),
],
)

Expand All @@ -196,8 +218,8 @@ def test_shell_script():
script_file=script_sh,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{inputs.y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/tasks/testdata/script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

set -ex

cat "{f}" >> "{z}"
echo "Hello World {y} on {j} - output {x}"
cat "{inputs.f}" >> "{outputs.z}"
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"