Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pyflyte-exec-alternative as the entrypoint for SageMaker Custom Training Job #163

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins # noqa: F401

__version__ = "0.12.0b1"
__version__ = "0.12.1b0"
84 changes: 84 additions & 0 deletions flytekit/bin/entrypoint_alt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import absolute_import

import importlib as _importlib

import click as _click

import flytekit.common.types.helpers as _type_helpers
from flytekit.common import utils as _utils
from flytekit.common.exceptions import scopes as _scopes
from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
from flytekit.configuration import internal as _internal_config
from flytekit.engines import loader as _engine_loader

SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX = "/opt/ml/input/data"

def build_sdk_type_map_from_typed_interface(interface):


@_scopes.system_entry_point
def _execute_task(task_module, task_name, inputs, output_prefix, test, sagemaker_args):
with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()):
# Load user code
task_module = _importlib.import_module(task_module)
task_def = getattr(task_module, task_name)

if not test:

# Parse the unknown arguments, and create a litealmap out from the task definition
map_of_input_values = {}
# Here we have an assumption that each option key will come with a value right after the key
for i in range(0, len(sagemaker_args), 2):
# Since the sagemaker_args are unprocessed, each of the option keys comes with a leading "--"
# We need to remove them
map_of_input_values[sagemaker_args[i][2:]] = sagemaker_args[i + 1]

map_of_sdk_types = {}
blob_and_schema_local_path_map = {}

for k, v in task_def.interface.inputs.items():
map_of_sdk_types[k] = _type_helpers.get_sdk_type_from_literal_type(v.type)

# We need to do some special handling of the blob-typed inputs, i.e., read them from predefined
# locations in the container
map_of_input_values.update(blob_and_schema_local_path_map)

input_literal_map = _type_helpers.pack_python_string_map_to_literal_map(
map_of_input_values, map_of_sdk_types,
)

# TODO 1. need to handle the case of "collection of blobs" or even "hierarchical collection of blobs"
# TODO 2. replace the blob uris with local
for k, v in task_def.interface.inputs.items():
if v.type.blob is not None or v.type.schema is not None:
blob_and_schema_local_path_map[k] = "{}/{}".format(SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, k)

_engine_loader.get_engine().get_task(task_def).execute(
input_literal_map, context={"output_prefix": output_prefix}
)


@_click.group()
def _pass_through():
pass


# pyflyte-execute-alt is an alternative pyflyte entrypoint specifically designed for SageMaker (currently)
# This entrypoint assumes no --inputs command-line option, and therefore it doesn't accept the input.pb file
# All the inputs will be passed into the entrypoint as unknown arguments
@_pass_through.command("pyflyte-execute-alt", context_settings=dict(ignore_unknown_options=True))
@_click.option("--task-module", required=True)
@_click.option("--task-name", required=True)
@_click.option("--inputs", required=True)
@_click.option("--output-prefix", required=True)
@_click.option("--test", is_flag=True)
@_click.argument("sagemaker_args", nargs=-1, type=_click.UNPROCESSED)
def execute_task_cmd(task_module, task_name, inputs, output_prefix, test, sagemaker_args):
_click.echo(_utils.get_version_message())
_click.echo("sagemaker_args : {}".format(sagemaker_args))
# Note that the unknown arguments are entirely unprocessed, so the leading "--" are still there
_execute_task(task_module, task_name, inputs, output_prefix, test, sagemaker_args)


if __name__ == "__main__":
_pass_through()
10 changes: 10 additions & 0 deletions flytekit/common/types/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,13 @@ def pack_python_std_map_to_literal_map(std_map, type_map):
:raises: flytekit.common.exceptions.user.FlyteTypeException
"""
return _literal_models.LiteralMap(literals={k: v.from_python_std(std_map[k]) for k, v in _six.iteritems(type_map)})


def pack_python_string_map_to_literal_map(str_map, type_map):
"""
:param dict[Text, Text] str_map:
:param dict[Text, flytekit.common.types.base_sdk_types.FlyteSdkType] type_map:
:rtype: flytekit.models.literals.LiteralMap
:raises: flytekit.common.exceptions.user.FlyteTypeException
"""
return _literal_models.LiteralMap(literals={k: v.from_string(str_map[k]) for k, v in _six.iteritems(type_map)})
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
entry_points={
"console_scripts": [
"pyflyte-execute=flytekit.bin.entrypoint:execute_task_cmd",
"pyflyte-execute-alt=flytekit.bin.entrypoint_alt:execute_task_cmd",
"pyflyte=flytekit.clis.sdk_in_container.pyflyte:main",
"flyte-cli=flytekit.clis.flyte_cli.main:_flyte_cli",
]
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/common/task_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,32 @@
@python_task
def add_one(wf_params, a, b):
b.set(a + 1)


@inputs(
train=Types.CSV,
validation=Types.MultiPartBlob,
a=Types.Integer,
b=Types.Float,
c=Types.String,
d=Types.Boolean,
e=Types.Datetime,
)
@outputs(
otrain=Types.CSV,
ovalidation=Types.MultiPartBlob,
oa=Types.Integer,
ob=Types.Float,
oc=Types.String,
od=Types.Boolean,
oe=Types.Datetime,
)
@python_task
def dummy_for_entrypoint_alt(wf_params, train, validation, a, b, c, d, e, otrain, ovalidation, oa, ob, oc, od, oe):
otrain.set(train)
ovalidation.set(validation)
oa.set(a)
ob.set(b)
oc.set(c)
od.set(d)
oe.set(e)
139 changes: 139 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint_alt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import absolute_import

import os

import six
from click.testing import CliRunner
from dateutil import parser
from flyteidl.core import literals_pb2 as _literals_pb2

from flytekit.bin.entrypoint_alt import SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, _execute_task, execute_task_cmd
from flytekit.common import constants as _constants
from flytekit.common import utils as _utils
from flytekit.common.types import helpers as _type_helpers
from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
from flytekit.models import literals as _literal_models
from tests.flytekit.common import task_definitions as _task_defs


def _type_map_from_variable_map(variable_map):
return {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in six.iteritems(variable_map)}


def test_single_step_entrypoint_in_proc():
with _TemporaryConfiguration(
os.path.join(os.path.dirname(__file__), "fake.config"),
internal_overrides={"project": "test", "domain": "development"},
):
raw_args = (
"--train",
"/local/host",
"--validation",
"s3://dummy",
"--a",
"1",
"--b",
"0.5",
"--c",
"val",
"--d",
"0",
"--e",
"20180612T09:55:22Z",
)
with _utils.AutoDeletingTempDir("out") as output_dir:
_execute_task(
task_module=_task_defs.dummy_for_entrypoint_alt.task_module,
task_name=_task_defs.dummy_for_entrypoint_alt.task_function_name,
output_prefix=output_dir.name,
test=False,
sagemaker_args=raw_args,
)
p = _utils.load_proto_from_file(
_literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME),
)

raw_args_map = {}
for i in range(0, len(raw_args), 2):
raw_args_map[raw_args[i][2:]] = raw_args[i + 1]

raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
_literal_models.LiteralMap.from_flyte_idl(p),
_type_map_from_variable_map(_task_defs.dummy_for_entrypoint_alt.interface.outputs),
)

assert len(raw_map) == 7
assert raw_map["otrain"].uri.rstrip("/") == "{}/{}".format(SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, "train")
assert raw_map["ovalidation"].uri.rstrip("/") == "{}/{}".format(
SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, "validation"
)
assert raw_map["oa"] == 1
assert raw_map["ob"] == 0.5
assert raw_map["oc"] == "val"
assert raw_map["od"] is False
assert raw_map["oe"] == parser.parse("20180612T09:55:22Z")


def test_single_step_entrypoint_out_of_proc():
with _TemporaryConfiguration(
os.path.join(os.path.dirname(__file__), "fake.config"),
internal_overrides={"project": "test", "domain": "development"},
):
with _utils.AutoDeletingTempDir("in") as input_dir:
literal_map = _type_helpers.pack_python_std_map_to_literal_map(
{"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs),
)
input_file = os.path.join(input_dir.name, "inputs.pb")
_utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

raw_args = (
"--train",
"s3://dummy",
"--validation",
"s3://dummy",
"--a",
"1",
"--b",
"0.5",
"--c",
"val",
"--d",
"0",
"--e",
"20180612T09:55:22Z",
)

with _utils.AutoDeletingTempDir("out") as output_dir:
cmd = []
cmd.extend(["--task-module", _task_defs.dummy_for_entrypoint_alt.task_module])
cmd.extend(["--task-name", _task_defs.dummy_for_entrypoint_alt.task_function_name])
cmd.extend(["--output-prefix", output_dir.name])
cmd.extend(raw_args)
result = CliRunner().invoke(execute_task_cmd, cmd)

assert result.exit_code == 0
p = _utils.load_proto_from_file(
_literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME),
)

raw_args_map = {}
for i in range(0, len(raw_args), 2):
raw_args_map[raw_args[i][2:]] = raw_args[i + 1]

raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
_literal_models.LiteralMap.from_flyte_idl(p),
_type_map_from_variable_map(_task_defs.dummy_for_entrypoint_alt.interface.outputs),
)

assert len(raw_map) == 7
assert raw_map["otrain"].uri.rstrip("/") == "{}/{}".format(
SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, "train"
)
assert raw_map["ovalidation"].uri.rstrip("/") == "{}/{}".format(
SAGEMAKER_CONTAINER_LOCAL_INPUT_PREFIX, "validation"
)
assert raw_map["oa"] == 1
assert raw_map["ob"] == 0.5
assert raw_map["oc"] == "val"
assert raw_map["od"] is False
assert raw_map["oe"] == parser.parse("20180612T09:55:22Z")