Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sagemaker script back in #1112

Merged
merged 2 commits into from
Aug 7, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Unit test added
Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
kumare3 committed Jul 28, 2022
commit 7628b50c74f4243e146e3765325eb9eb21dbe795
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import sys
from unittest import mock

from scripts.flytekit_sagemaker_runner import run as _flyte_sagemaker_run

cmd = []
cmd.extend(["--__FLYTE_ENV_VAR_env1__", "val1"])
cmd.extend(["--__FLYTE_ENV_VAR_env2__", "val2"])
cmd.extend(["--__FLYTE_CMD_0_service_venv__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_1_pyflyte-execute__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_2_--task-module__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_3_blah__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_4_--task-name__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_5_bloh__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_6_--output-prefix__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_7_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_8_--inputs__", "__FLYTE_CMD_DUMMY_VALUE__"])
cmd.extend(["--__FLYTE_CMD_9_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"])


@mock.patch.dict("os.environ")
@mock.patch("subprocess.run")
def test(mock_subprocess_run):
_flyte_sagemaker_run(cmd)
assert "env1" in os.environ
assert "env2" in os.environ
assert os.environ["env1"] == "val1"
assert os.environ["env2"] == "val2"
mock_subprocess_run.assert_called_with(
"service_venv pyflyte-execute --task-module blah --task-name bloh "
"--output-prefix s3://fake-bucket --inputs s3://fake-bucket".split(),
stdout=sys.stdout,
stderr=sys.stderr,
encoding="utf-8",
check=True,
)