Skip to content

Commit

Permalink
Pass tfjob version to wait_for_job
Browse files Browse the repository at this point in the history
  • Loading branch information
Ankush Agarwal committed Jun 6, 2018
1 parent 960790a commit c25eacf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
31 changes: 22 additions & 9 deletions py/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,21 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

logging.info("Created job %s in namespaces %s", name, namespace)
results = tf_job_client.wait_for_job(
api_client, namespace, name, status_callback=tf_job_client.log_status)
api_client, namespace, name, args.tfjob_version, status_callback=tf_job_client.log_status)

if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
if args.tfjob_version == "v1alpha1":
if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
else:
# For v1alpha2 the job succeeded only if the last entry in status.conditions has type "Succeeded" (compared case-insensitively)
if results.get("status", {}).get("conditions", [])[-1].get("type", "").lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in status {3}".format(
trial, name, namespace, results.get("status", {}))
logging.error(t.failure)
break

runtime_id = results.get("spec", {}).get("RuntimeId")
logging.info("Trial %s Job %s in namespace %s runtime ID %s", trial, name,
Expand All @@ -293,8 +301,13 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements
created_pods, created_services = parse_events(events)

num_expected = 0
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
if args.tfjob_version == "v1alpha1":
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
else:
for replicakey in results.get("spec", {}).get("tfReplicaSpecs", {}):
logging.info("replicakey: %s", results.get("spec", {}).get("tfReplicaSpecs", {}).get(replicakey, {}))
num_expected += results.get("spec", {}).get("tfReplicaSpecs", {}).get(replicakey, {}).get("replicas", 0)

creation_failures = []
if len(created_pods) != num_expected:
Expand All @@ -319,7 +332,7 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

wait_for_pods_to_be_deleted(api_client, namespace, pod_selector)

tf_job_client.delete_tf_job(api_client, namespace, name)
tf_job_client.delete_tf_job(api_client, namespace, name, version=args.tfjob_version)

logging.info("Waiting for job %s in namespaces %s to be deleted.", name,
namespace)
Expand Down
10 changes: 8 additions & 2 deletions py/tf_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ def wait_for_job(client,
status_callback(results)

# If we poll the CRD quickly enough, status won't have been set yet.
if results.get("status", {}).get("phase", {}) == "Done":
return results
if version == "v1alpha1":
if results.get("status", {}).get("phase", {}) == "Done":
return results
else:
# For v1alpha2 check for non-empty completionTime
if results.get("status", {}).get("completionTime", ""):
return results


if datetime.datetime.now() + polling_interval > end_time:
raise util.TimeoutError(
Expand Down
44 changes: 42 additions & 2 deletions test/workflows/components/simple_tfjob.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ local parts(namespace, name, image) = {
local actualImage = if image != "" then
image
else defaultTestImage,
job:: {
apiVersion: params.apiVersion,
job:: if params.apiVersion == "kubeflow.org/v1alpha1" then {
apiVersion: "kubeflow.org/v1alpha1",
kind: "TFJob",
metadata: {
name: name,
Expand Down Expand Up @@ -63,6 +63,46 @@ local parts(namespace, name, image) = {
},
],
},
} else {
apiVersion: "kubeflow.org/v1alpha2",
kind: "TFJob",
metadata: {
name: name,
namespace: namespace,
},
spec: {
tfReplicaSpecs: {
PS: {
replicas: 2,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
},
],
},
},
},
Worker: {
replicas: 4,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
command: ["python", "/var/tf_dist_mnist/dist_mnist.py", "--train_steps=100"],
},
],
},
},
},
},
},
},
};

Expand Down

0 comments on commit c25eacf

Please sign in to comment.