Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create TFJob and PyTorchJob from Function APIs in the Training SDK #1659

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
- name: Run tests
run: |
pip install pytest
python3 -m pip install -r sdk/python/requirements.txt; pytest sdk/python/test --log-cli-level=info
python3 -m pip install -e sdk/python; pytest sdk/python/test --log-cli-level=info
777 changes: 777 additions & 0 deletions sdk/python/examples/create-pytorchjob-from-func.ipynb

Large diffs are not rendered by default.

209 changes: 139 additions & 70 deletions sdk/python/kubeflow/training/api/mpi_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .mpi_job_watch import watch as mpijob_watch

logging.basicConfig(format='%(message)s')
logging.basicConfig(format="%(message)s")
logging.getLogger().setLevel(logging.INFO)


Expand All @@ -39,7 +39,8 @@ def wrap_log_stream(q, stream):
return
except Exception as e:
raise RuntimeError(
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e
)


def get_log_queue_pool(streams):
Expand All @@ -52,8 +53,13 @@ def get_log_queue_pool(streams):


class MPIJobClient(object):
def __init__(self, config_file=None, context=None, # pylint: disable=too-many-arguments
client_configuration=None, persist_config=True):
def __init__(
self,
config_file=None,
context=None, # pylint: disable=too-many-arguments
client_configuration=None,
persist_config=True,
):
"""
MPIJob client constructor
:param config_file: kubeconfig file, defaults to ~/.kube/config
Expand All @@ -66,7 +72,8 @@ def __init__(self, config_file=None, context=None, # pylint: disable=too-many-a
config_file=config_file,
context=context,
client_configuration=client_configuration,
persist_config=persist_config)
persist_config=persist_config,
)
else:
config.load_incluster_config()

Expand All @@ -86,20 +93,24 @@ def create(self, mpijob, namespace=None):

try:
outputs = self.custom_api.create_namespaced_custom_object(
constants.MPIJOB_GROUP,
constants.KUBEFLOW_GROUP,
constants.MPIJOB_VERSION,
namespace,
constants.MPIJOB_PLURAL,
mpijob)
mpijob,
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CustomObjectsApi->create_namespaced_custom_object:\
%s\n" % e)
%s\n"
% e
)

return outputs

def get(self, name=None, namespace=None, watch=False,
timeout_seconds=600): # pylint: disable=inconsistent-return-statements
def get(
self, name=None, namespace=None, watch=False, timeout_seconds=600
): # pylint: disable=inconsistent-return-statements
"""
Get the mpijob
:param name: existing mpijob name, if not defined, the get all mpijobs in the namespace.
Expand All @@ -114,17 +125,17 @@ def get(self, name=None, namespace=None, watch=False,
if name:
if watch:
mpijob_watch(
name=name,
namespace=namespace,
timeout_seconds=timeout_seconds)
name=name, namespace=namespace, timeout_seconds=timeout_seconds
)
else:
thread = self.custom_api.get_namespaced_custom_object(
constants.MPIJOB_GROUP,
constants.KUBEFLOW_GROUP,
constants.MPIJOB_VERSION,
namespace,
constants.MPIJOB_PLURAL,
name,
async_req=True)
async_req=True,
)

mpijob = None
try:
Expand All @@ -134,24 +145,28 @@ def get(self, name=None, namespace=None, watch=False,
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CustomObjectsApi->get_namespaced_custom_object:\
%s\n" % e)
%s\n"
% e
)
except Exception as e:
raise RuntimeError(
"There was a problem to get MPIJob {0} in namespace {1}. Exception: \
{2} ".format(name, namespace, e))
{2} ".format(
name, namespace, e
)
)
return mpijob
else:
if watch:
mpijob_watch(
namespace=namespace,
timeout_seconds=timeout_seconds)
mpijob_watch(namespace=namespace, timeout_seconds=timeout_seconds)
else:
thread = self.custom_api.list_namespaced_custom_object(
constants.MPIJOB_GROUP,
constants.KUBEFLOW_GROUP,
constants.MPIJOB_VERSION,
namespace,
constants.MPIJOB_PLURAL,
async_req=True)
async_req=True,
)

mpijobs = None
try:
Expand All @@ -161,11 +176,16 @@ def get(self, name=None, namespace=None, watch=False,
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CustomObjectsApi->list_namespaced_custom_object:\
%s\n" % e)
%s\n"
% e
)
except Exception as e:
raise RuntimeError(
"There was a problem to list MPIJobs in namespace {0}. \
Exception: {1} ".format(namespace, e))
Exception: {1} ".format(
namespace, e
)
)
return mpijobs

def patch(self, name, mpijob, namespace=None):
Expand All @@ -181,16 +201,19 @@ def patch(self, name, mpijob, namespace=None):

try:
outputs = self.custom_api.patch_namespaced_custom_object(
constants.MPIJOB_GROUP,
constants.KUBEFLOW_GROUP,
constants.MPIJOB_VERSION,
namespace,
constants.MPIJOB_PLURAL,
name,
mpijob)
mpijob,
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CustomObjectsApi->patch_namespaced_custom_object:\
%s\n" % e)
%s\n"
% e
)

return outputs

Expand All @@ -206,23 +229,29 @@ def delete(self, name, namespace=None):

try:
return self.custom_api.delete_namespaced_custom_object(
group=constants.MPIJOB_GROUP,
group=constants.KUBEFLOW_GROUP,
version=constants.MPIJOB_VERSION,
namespace=namespace,
plural=constants.MPIJOB_PLURAL,
name=name,
body=client.V1DeleteOptions())
body=client.V1DeleteOptions(),
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CustomObjectsApi->delete_namespaced_custom_object:\
%s\n" % e)

def wait_for_job(self, name, # pylint: disable=inconsistent-return-statements
namespace=None,
timeout_seconds=600,
polling_interval=30,
watch=False,
status_callback=None):
%s\n"
% e
)

def wait_for_job(
self,
name, # pylint: disable=inconsistent-return-statements
namespace=None,
timeout_seconds=600,
polling_interval=30,
watch=False,
status_callback=None,
):
"""Wait for the specified job to finish.

:param name: Name of the TfJob.
Expand All @@ -240,24 +269,27 @@ def wait_for_job(self, name, # pylint: disable=inconsistent-return-statements

if watch:
mpijob_watch(
name=name,
namespace=namespace,
timeout_seconds=timeout_seconds)
name=name, namespace=namespace, timeout_seconds=timeout_seconds
)
else:
return self.wait_for_condition(
name,
["Succeeded", "Failed"],
namespace=namespace,
timeout_seconds=timeout_seconds,
polling_interval=polling_interval,
status_callback=status_callback)

def wait_for_condition(self, name,
expected_condition,
namespace=None,
timeout_seconds=600,
polling_interval=30,
status_callback=None):
status_callback=status_callback,
)

def wait_for_condition(
self,
name,
expected_condition,
namespace=None,
timeout_seconds=600,
polling_interval=30,
status_callback=None,
):
"""Waits until any of the specified conditions occur.

:param name: Name of the job.
Expand Down Expand Up @@ -296,7 +328,9 @@ def wait_for_condition(self, name,

raise RuntimeError(
"Timeout waiting for MPIJob {0} in namespace {1} to enter one of the "
"conditions {2}.".format(name, namespace, expected_condition), mpijob)
"conditions {2}.".format(name, namespace, expected_condition),
mpijob,
)

def get_job_status(self, name, namespace=None):
"""Returns MPIJob status, such as Running, Failed or Succeeded.
Expand Down Expand Up @@ -332,8 +366,14 @@ def is_job_succeeded(self, name, namespace=None):
mpijob_status = self.get_job_status(name, namespace=namespace)
return mpijob_status.lower() == "succeeded"

def get_pod_names(self, name, namespace=None, master=False, # pylint: disable=inconsistent-return-statements
replica_type=None, replica_index=None):
def get_pod_names(
self,
name,
namespace=None,
master=False, # pylint: disable=inconsistent-return-statements
replica_type=None,
replica_index=None,
):
"""
Get pod names of MPIJob.
:param name: mpijob name
Expand All @@ -348,29 +388,40 @@ def get_pod_names(self, name, namespace=None, master=False, # pylint: disable=i
if namespace is None:
namespace = utils.get_default_target_namespace()

labels = utils.get_job_labels(name, master=master,
replica_type=replica_type,
replica_index=replica_index)
labels = utils.get_job_labels(
name, master=master, replica_type=replica_type, replica_index=replica_index
)
try:
resp = self.core_api.list_namespaced_pod(
namespace, label_selector=utils.to_selector(labels))
namespace, label_selector=utils.to_selector(labels)
)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e
)

pod_names = []
for pod in resp.items:
if pod.metadata and pod.metadata.name:
pod_names.append(pod.metadata.name)

if not pod_names:
logging.warning("Not found Pods of the MPIJob %s with the labels %s.", name, labels)
logging.warning(
"Not found Pods of the MPIJob %s with the labels %s.", name, labels
)
else:
return set(pod_names)

def get_logs(self, name, namespace=None, master=True,
replica_type=None, replica_index=None,
follow=False, container="mpi"):
def get_logs(
self,
name,
namespace=None,
master=True,
replica_type=None,
replica_index=None,
follow=False,
container="mpi",
):
"""
Get training logs of the MPIJob.
By default only get the logs of Pod that has labels 'job-role: master'.
Expand All @@ -389,16 +440,27 @@ def get_logs(self, name, namespace=None, master=True,
if namespace is None:
namespace = utils.get_default_target_namespace()

pod_names = list(self.get_pod_names(name, namespace=namespace,
master=master,
replica_type=replica_type,
replica_index=replica_index))
pod_names = list(
self.get_pod_names(
name,
namespace=namespace,
master=master,
replica_type=replica_type,
replica_index=replica_index,
)
)
if pod_names:
if follow:
log_streams = []
for pod in pod_names:
log_streams.append(k8s_watch.Watch().stream(self.core_api.read_namespaced_pod_log,
name=pod, namespace=namespace, container=container))
log_streams.append(
k8s_watch.Watch().stream(
self.core_api.read_namespaced_pod_log,
name=pod,
namespace=namespace,
container=container,
)
)
finished = [False for _ in log_streams]

# create thread and queue per stream, for non-blocking iteration
Expand All @@ -424,11 +486,18 @@ def get_logs(self, name, namespace=None, master=True,
else:
for pod in pod_names:
try:
pod_logs = self.core_api.read_namespaced_pod_log(pod, namespace, container=container)
pod_logs = self.core_api.read_namespaced_pod_log(
pod, namespace, container=container
)
logging.info("The logs of Pod %s:\n %s", pod, pod_logs)
except client.rest.ApiException as e:
raise RuntimeError(
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
"Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n"
% e
)
else:
raise RuntimeError("Not found Pods of the MPIJob {} "
"in namespace {}".format(name, namespace))
raise RuntimeError(
"Not found Pods of the MPIJob {} "
"in namespace {}".format(name, namespace)
)

Loading