Skip to content

Commit

Permalink
Updates to enable e2e test for v1alpha2 (#629)
Browse files Browse the repository at this point in the history
* Pass tfjob version to wait_for_job

* Handle case where replicas is not set

* Update
  • Loading branch information
Ankush Agarwal authored and k8s-ci-robot committed Jun 7, 2018
1 parent bc12a01 commit 76661c0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 14 deletions.
35 changes: 25 additions & 10 deletions py/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

start = time.time()

try:
try: # pylint: disable=too-many-nested-blocks
# We repeat the test multiple times.
# This ensures that if we delete the job we can create a new job with the
# same name.
Expand All @@ -277,13 +277,22 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

logging.info("Created job %s in namespaces %s", name, namespace)
results = tf_job_client.wait_for_job(
api_client, namespace, name, status_callback=tf_job_client.log_status)
api_client, namespace, name, args.tfjob_version, status_callback=tf_job_client.log_status)

if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
if args.tfjob_version == "v1alpha1":
if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
else:
# For v1alpha2 check for non-empty completionTime
last_condition = results.get("status", {}).get("conditions", [])[-1]
if last_condition.get("type", "").lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in status {3}".format(
trial, name, namespace, results.get("status", {}))
logging.error(t.failure)
break

runtime_id = results.get("spec", {}).get("RuntimeId")
logging.info("Trial %s Job %s in namespace %s runtime ID %s", trial, name,
Expand All @@ -294,8 +303,14 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements
created_pods, created_services = parse_events(events)

num_expected = 0
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
if args.tfjob_version == "v1alpha1":
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
else:
for replicakey in results.get("spec", {}).get("tfReplicaSpecs", {}):
replica_spec = results.get("spec", {}).get("tfReplicaSpecs", {}).get(replicakey, {})
if replica_spec:
num_expected += replica_spec.get("replicas", 1)

creation_failures = []
if len(created_pods) != num_expected:
Expand All @@ -320,7 +335,7 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

wait_for_pods_to_be_deleted(api_client, namespace, pod_selector)

tf_job_client.delete_tf_job(api_client, namespace, name)
tf_job_client.delete_tf_job(api_client, namespace, name, version=args.tfjob_version)

logging.info("Waiting for job %s in namespaces %s to be deleted.", name,
namespace)
Expand Down
10 changes: 8 additions & 2 deletions py/tf_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ def wait_for_job(client,
status_callback(results)

# If we poll the CRD quick enough status won't have been set yet.
if results.get("status", {}).get("phase", {}) == "Done":
return results
if version == "v1alpha1":
if results.get("status", {}).get("phase", {}) == "Done":
return results
else:
# For v1alpha2 check for non-empty completionTime
if results.get("status", {}).get("completionTime", ""):
return results


if datetime.datetime.now() + polling_interval > end_time:
raise util.TimeoutError(
Expand Down
44 changes: 42 additions & 2 deletions test/workflows/components/simple_tfjob.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ local parts(namespace, name, image) = {
local actualImage = if image != "" then
image
else defaultTestImage,
job:: {
apiVersion: params.apiVersion,
job:: if params.apiVersion == "kubeflow.org/v1alpha1" then {
apiVersion: "kubeflow.org/v1alpha1",
kind: "TFJob",
metadata: {
name: name,
Expand Down Expand Up @@ -63,6 +63,46 @@ local parts(namespace, name, image) = {
},
],
},
} else {
apiVersion: "kubeflow.org/v1alpha2",
kind: "TFJob",
metadata: {
name: name,
namespace: namespace,
},
spec: {
tfReplicaSpecs: {
PS: {
replicas: 2,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
},
],
},
},
},
Worker: {
replicas: 4,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
command: ["python", "/var/tf_dist_mnist/dist_mnist.py", "--train_steps=100"],
},
],
},
},
},
},
},
},
};

Expand Down

0 comments on commit 76661c0

Please sign in to comment.