Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to enable e2e test for v1alpha2 #629

Merged
merged 3 commits into from
Jun 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions py/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

start = time.time()

try:
try: # pylint: disable=too-many-nested-blocks
# We repeat the test multiple times.
# This ensures that if we delete the job we can create a new job with the
# same name.
Expand All @@ -277,13 +277,22 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

logging.info("Created job %s in namespaces %s", name, namespace)
results = tf_job_client.wait_for_job(
api_client, namespace, name, status_callback=tf_job_client.log_status)
api_client, namespace, name, args.tfjob_version, status_callback=tf_job_client.log_status)

if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
if args.tfjob_version == "v1alpha1":
if results.get("status", {}).get("state", {}).lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in state {3}".format(
trial, name, namespace, results.get("status", {}).get("state", None))
logging.error(t.failure)
break
else:
# For v1alpha2 check that the type of the last status condition is "Succeeded"
last_condition = results.get("status", {}).get("conditions", [])[-1]
if last_condition.get("type", "").lower() != "succeeded":
t.failure = "Trial {0} Job {1} in namespace {2} in status {3}".format(
trial, name, namespace, results.get("status", {}))
logging.error(t.failure)
break

runtime_id = results.get("spec", {}).get("RuntimeId")
logging.info("Trial %s Job %s in namespace %s runtime ID %s", trial, name,
Expand All @@ -294,8 +303,14 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements
created_pods, created_services = parse_events(events)

num_expected = 0
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
if args.tfjob_version == "v1alpha1":
for replica in results.get("spec", {}).get("replicaSpecs", []):
num_expected += replica.get("replicas", 0)
else:
for replicakey in results.get("spec", {}).get("tfReplicaSpecs", {}):
replica_spec = results.get("spec", {}).get("tfReplicaSpecs", {}).get(replicakey, {})
if replica_spec:
num_expected += replica_spec.get("replicas", 1)

creation_failures = []
if len(created_pods) != num_expected:
Expand All @@ -320,7 +335,7 @@ def run_test(args): # pylint: disable=too-many-branches,too-many-statements

wait_for_pods_to_be_deleted(api_client, namespace, pod_selector)

tf_job_client.delete_tf_job(api_client, namespace, name)
tf_job_client.delete_tf_job(api_client, namespace, name, version=args.tfjob_version)

logging.info("Waiting for job %s in namespaces %s to be deleted.", name,
namespace)
Expand Down
10 changes: 8 additions & 2 deletions py/tf_job_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ def wait_for_job(client,
status_callback(results)

# If we poll the CRD quickly enough, the status won't have been set yet.
if results.get("status", {}).get("phase", {}) == "Done":
return results
if version == "v1alpha1":
if results.get("status", {}).get("phase", {}) == "Done":
return results
else:
# For v1alpha2 check for non-empty completionTime
if results.get("status", {}).get("completionTime", ""):
return results


if datetime.datetime.now() + polling_interval > end_time:
raise util.TimeoutError(
Expand Down
44 changes: 42 additions & 2 deletions test/workflows/components/simple_tfjob.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ local parts(namespace, name, image) = {
local actualImage = if image != "" then
image
else defaultTestImage,
job:: {
apiVersion: params.apiVersion,
job:: if params.apiVersion == "kubeflow.org/v1alpha1" then {
apiVersion: "kubeflow.org/v1alpha1",
kind: "TFJob",
metadata: {
name: name,
Expand Down Expand Up @@ -63,6 +63,46 @@ local parts(namespace, name, image) = {
},
],
},
} else {
apiVersion: "kubeflow.org/v1alpha2",
kind: "TFJob",
metadata: {
name: name,
namespace: namespace,
},
spec: {
tfReplicaSpecs: {
PS: {
replicas: 2,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
},
],
},
},
},
Worker: {
replicas: 4,
restartPolicy: "Never",
template: {
spec: {
containers: [
{
name: "tensorflow",
image: actualImage,
command: ["python", "/var/tf_dist_mnist/dist_mnist.py", "--train_steps=100"],
},
],
},
},
},
},
},
},
};

Expand Down