diff --git a/py/deploy.py b/py/deploy.py new file mode 100755 index 0000000000..2c8eb273ad --- /dev/null +++ b/py/deploy.py @@ -0,0 +1,213 @@ +#!/usr/bin/python +"""Deploy/manage K8s clusters and the operator. + +This binary is primarily intended for use in managing resources for our tests. +""" + +import argparse +import logging +import os +import subprocess +import tempfile +import time + +from kubernetes import client as k8s_client + +from googleapiclient import discovery +from google.cloud import storage # pylint: disable=no-name-in-module + +from py import test_util +from py import util + +def setup(args): + """Setup a GKE cluster for TensorFlow jobs. + + Args: + args: Command line arguments that control the setup process. + """ + gke = discovery.build("container", "v1") + + project = args.project + cluster_name = args.cluster + zone = args.zone + chart = args.chart + machine_type = "n1-standard-8" + + # TODO(jlewi): Should make these command line arguments. + use_gpu = False + if use_gpu: + accelerator = "nvidia-tesla-k80" + accelerator_count = 1 + else: + accelerator = None + accelerator_count = 0 + + cluster_request = { + "cluster": { + "name": cluster_name, + "description": "A GKE cluster for TF.", + "initialNodeCount": 1, + "nodeConfig": { + "machineType": machine_type, + "oauthScopes": [ + "https://www.googleapis.com/auth/cloud-platform", + ], + }, + # TODO(jlewi): Stop pinning GKE version once 1.8 becomes the default. + "initialClusterVersion": "1.8.1-gke.1", + } + } + + if bool(accelerator) != (accelerator_count > 0): + raise ValueError("If accelerator is set accelerator_count must be > 0") + + if accelerator: + # TODO(jlewi): Stop enabling Alpha once GPUs make it out of Alpha + cluster_request["cluster"]["enableKubernetesAlpha"] = True + + cluster_request["cluster"]["nodeConfig"]["accelerators"] = [ + { + "acceleratorCount": accelerator_count, + "acceleratorType": accelerator, + }, + ] + + util.create_cluster(gke, project, zone, cluster_request) + + util.configure_kubectl(project, zone, cluster_name) + + util.load_kube_config() + # Create an API client object to talk to the K8s master. + api_client = k8s_client.ApiClient() + + util.setup_cluster(api_client) + + if chart.startswith("gs://"): + remote = chart + chart = os.path.join(tempfile.gettempdir(), os.path.basename(chart)) + gcs_client = storage.Client(project=project) + bucket_name, path = util.split_gcs_uri(remote) + + bucket = gcs_client.get_bucket(bucket_name) + blob = bucket.blob(path) + logging.info("Downloading %s to %s", remote, chart) + blob.download_to_filename(chart) + + t = test_util.TestCase() + try: + start = time.time() + util.run(["helm", "install", chart, "-n", "tf-job", "--wait", "--replace", + "--set", "rbac.install=true,cloud=gke"]) + except subprocess.CalledProcessError as e: + t.failure = "helm install failed;\n" + e.output + finally: + t.time = time.time() - start + t.name = "helm-tfjob-install" + t.class_name = "GKE" + test_util.create_junit_xml_file([t], args.junit_path, gcs_client) + +def test(args): + """Run the tests.""" + gcs_client = storage.Client(project=args.project) + project = args.project + cluster_name = args.cluster + zone = args.zone + util.configure_kubectl(project, zone, cluster_name) + + t = test_util.TestCase() + try: + start = time.time() + util.run(["helm", "test", "tf-job"]) + except subprocess.CalledProcessError as e: + t.failure = "helm test failed;\n" + e.output + finally: + t.time = time.time() - start + t.name = "e2e-test" + t.class_name = "GKE" + test_util.create_junit_xml_file([t], args.junit_path, gcs_client) + +def teardown(args): + """Teardown the resources.""" + gke = discovery.build("container", "v1") + + project = args.project + cluster_name = args.cluster + zone = args.zone + util.delete_cluster(gke, cluster_name, project, zone) + +def add_common_args(parser): + """Add common command line arguments to a parser. + + Args: + parser: The parser to add command line arguments to. + """ + parser.add_argument( + "--project", + default=None, + type=str, + help=("The project to use.")) + parser.add_argument( + "--cluster", + default=None, + type=str, + help=("The name of the cluster.")) + parser.add_argument( + "--zone", + default="us-east1-d", + type=str, + help=("The zone for the cluster.")) + + parser.add_argument( + "--junit_path", + default="", + type=str, + help="Where to write the junit xml file with the results.") + +def main(): # pylint: disable=too-many-locals + logging.getLogger().setLevel(logging.INFO) # pylint: disable=too-many-locals + # create the top-level parser + parser = argparse.ArgumentParser( + description="Setup clusters for testing.") + subparsers = parser.add_subparsers() + + ############################################################################# + # setup + # + parser_setup = subparsers.add_parser( + "setup", + help="Setup a cluster for testing.") + + parser_setup.set_defaults(func=setup) + add_common_args(parser_setup) + + parser_setup.add_argument( + "--chart", + type=str, + required=True, + help="The path for the helm chart.") + + ############################################################################# + # test + # + parser_test = subparsers.add_parser( + "test", + help="Run the tests.") + + parser_test.set_defaults(func=test) + add_common_args(parser_test) + + ############################################################################# + # teardown + # + parser_teardown = subparsers.add_parser( + "teardown", + help="Teardown the cluster.") + parser_teardown.set_defaults(func=teardown) + add_common_args(parser_teardown) + + # parse the args and call whatever function was selected + args = parser.parse_args() + args.func(args) + +if __name__ == "__main__": + main() diff --git a/py/release.py b/py/release.py index bce77cb3ee..ac0eafd4c0 100755 --- a/py/release.py +++ b/py/release.py @@ -1,7 +1,7 @@ #!/usr/bin/python -"""Release a new Docker image and helm package. +"""Build a new Docker image and helm package. -This script should be run from the root directory of the repo. +This module assumes py is a top level python package. """ import argparse @@ -29,6 +29,7 @@ def get_latest_green_presubmit(gcs_client): + """Find the commit corresponding to the latest passing postsubmit.""" bucket = gcs_client.get_bucket(RESULTS_BUCKET) blob = bucket.blob(os.path.join(JOB_NAME, "latest_green.json")) contents = blob.download_as_string() @@ -110,13 +111,11 @@ def create_latest(bucket, sha, target): blob.upload_from_string(json.dumps(data)) -def build_operator_image(root_dir, registry, output_path=None, project=None, - should_push=True): +def build_operator_image(root_dir, registry, project=None, should_push=True): """Build the main docker image for the TfJob CRD. Args: root_dir: Root directory of the repository. registry: The registry to use. - output_path: Path to write build information for. project: If set it will be built using GCB. Returns: build_info: Dictionary containing information about the build. @@ -183,18 +182,14 @@ def build_operator_image(root_dir, registry, output_path=None, project=None, util.run(["gcloud", "docker", "--", "push", latest_image]) logging.info("Pushed image: %s", latest_image) - output = {"image": image, - "commit": commit, - } - if output_path: - logging.info("Writing build information to %s", output_path) - with open(output_path, mode='w') as hf: - yaml.dump(output, hf) - + output = { + "image": image, + "commit": commit, + } return output def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, - gcb_project=None): + gcb_project=None, build_info_path=None): """Build and push the artifacts. Args: @@ -205,6 +200,8 @@ def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, Set to none to only build locally. gcb_project: The project to use with GCB to build docker images. If set to none uses docker to build. + build_info_path: (Optional): GCS location to write YAML file containing + information about the build. """ # Update the GOPATH to the temporary directory. env = os.environ.copy() @@ -215,13 +212,7 @@ def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, if not os.path.exists(bin_dir): os.makedirs(bin_dir) - build_info_file = os.path.join(bin_dir, "build_info.yaml") - - build_info = build_operator_image(src_dir, registry, project=gcb_project, - output_path=build_info_file) - - with open(build_info_file) as hf: - build_info = yaml.load(hf) + build_info = build_operator_image(src_dir, registry, project=gcb_project) # Copy the chart to a temporary directory because we will modify some # of its YAML files. @@ -243,7 +234,7 @@ def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, logging.info("Delete previous build: %s", m) os.unlink(m) - util.run(["helm", "package", "--destination=" + bin_dir, + util.run(["helm", "package", "--save=false", "--destination=" + bin_dir, "./tf-job-operator-chart"], cwd=chart_build_dir) matches = glob.glob(os.path.join(bin_dir, "tf-job-operator-chart*.tgz")) @@ -262,12 +253,14 @@ def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, ] if publish_path: - gcs_client = storage.Client() + gcs_client = storage.Client(project=gcb_project) bucket_name, base_path = util.split_gcs_uri(publish_path) bucket = gcs_client.get_bucket(bucket_name) for t in targets: blob = bucket.blob(os.path.join(base_path, t)) - gcs_path = util.to_gcs_uri(bucket_name, t) + gcs_path = util.to_gcs_uri(bucket_name, blob.name) + if not t.startswith("latest"): + build_info["helm_chart"] = gcs_path if blob.exists() and not t.startswith("latest"): logging.warn("%s already exists", gcs_path) continue @@ -277,11 +270,59 @@ def build_and_push_artifacts(go_dir, src_dir, registry, publish_path=None, create_latest(bucket, build_info["commit"], util.to_gcs_uri(bucket_name, targets[0])) + # Always write to the bin dir. + paths = [os.path.join(bin_dir, "build_info.yaml")] + + if build_info_path: + paths.append(build_info_path) + + write_build_info(build_info, paths, project=gcb_project) + +def write_build_info(build_info, paths, project=None): + """Write the build info files. + """ + gcs_client = None + + contents = yaml.dump(build_info) + + for p in paths: + logging.info("Writing build information to %s", p) + if p.startswith("gs://"): + if not gcs_client: + gcs_client = storage.Client(project=project) + bucket_name, path = util.split_gcs_uri(p) + bucket = gcs_client.get_bucket(bucket_name) + blob = bucket.blob(path) + blob.upload_from_string(contents) + + else: + with open(p, mode='w') as hf: + hf.write(contents) + +def build_and_push(go_dir, src_dir, args): + if args.dryrun: + logging.info("dryrun...") + # In dryrun mode we want to produce the build info file because this + # is needed to test xcoms with Airflow. + if args.build_info_path: + paths = [args.build_info_path] + build_info = { + "image": "gcr.io/dryrun/dryrun:latest", + "commit": "1234abcd", + "helm_package": "gs://dryrun/dryrun.latest.", + } + write_build_info(build_info, paths, project=args.project) + return + build_and_push_artifacts(go_dir, src_dir, registry=args.registry, + publish_path=args.releases_path, + gcb_project=args.project, + build_info_path=args.build_info_path) + def build_local(args): """Build the artifacts from the local copy of the code.""" go_dir = None src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - build_and_push_artifacts(go_dir, src_dir, args.registry) + build_and_push(go_dir, src_dir, args) def build_postsubmit(args): """Build the artifacts from a postsubmit.""" @@ -294,7 +335,7 @@ def build_postsubmit(args): util.clone_repo(src_dir, util.MASTER_REPO_OWNER, util.MASTER_REPO_NAME, args.commit) - build_and_push_artifacts(go_dir, src_dir, args.registry) + build_and_push(go_dir, src_dir, args) def build_pr(args): """Build the artifacts from a postsubmit.""" @@ -309,7 +350,7 @@ def build_pr(args): util.MASTER_REPO_NAME, args.commit, branches=branches) - build_and_push_artifacts(go_dir, src_dir, args.registry) + build_and_push(go_dir, src_dir, args) def build_lastgreen(args): # pylint: disable=too-many-locals """Find the latest green postsubmit and build the artifacts. @@ -337,9 +378,7 @@ def build_lastgreen(args): # pylint: disable=too-many-locals _, sha = util.clone_repo(src_dir, util.MASTER_REPO_OWNER, util.MASTER_REPO_NAME, sha) - build_and_push_artifacts(go_dir, src_dir, registry=args.registry, - publish_path=args.releases_path, - gcb_project=args.project) + build_and_push(go_dir, src_dir, args) def add_common_args(parser): """Add a set of common parser arguments.""" @@ -357,25 +396,33 @@ def add_common_args(parser): help=("If specified use Google Container Builder and this project to " "build artifacts.")) + parser.add_argument( + "--releases_path", + default=None, + required=False, + type=str, + help="The GCS location where artifacts should be pushed.") + + parser.add_argument( + "--build_info_path", + default="", + type=str, + help="(Optional). The GCS location to write build info to.") + + parser.add_argument("--dryrun", dest="dryrun", action="store_true", + help="Do a dry run.") + parser.add_argument("--no-dryrun", dest="dryrun", action="store_false", + help="Don't do a dry run.") + parser.set_defaults(dryrun=False) + def main(): # pylint: disable=too-many-locals logging.getLogger().setLevel(logging.INFO) # pylint: disable=too-many-locals - this_dir = os.path.dirname(__file__) - version_file = os.path.join(this_dir, "version.json") - if os.path.exists(version_file): - # Print out version information so we know what container we ran in. - with open(version_file) as hf: - version = json.load(hf) - logging.info("Image info:\n%s", json.dumps(version, indent=2, - sort_keys=True)) - else: - logging.warn("Could not find file: %s", version_file) - # create the top-level parser parser = argparse.ArgumentParser( description="Build the release artifacts.") subparsers = parser.add_subparsers() - ############################################################################ + ############################################################################# # local # # Create the parser for the "local" mode. @@ -411,13 +458,6 @@ def main(): # pylint: disable=too-many-locals add_common_args(parser_lastgreen) - parser_lastgreen.add_argument( - "--releases_path", - default=None, - required=True, - type=str, - help="The GCS location where artifacts should be pushed.") - ############################################################################ # Pull Request parser_pr = subparsers.add_parser( @@ -440,13 +480,6 @@ def main(): # pylint: disable=too-many-locals help="Optional a particular commit to checkout and build.") parser_postsubmit.set_defaults(func=build_postsubmit) - parser_pr.add_argument( - "--releases_path", - default=None, - required=False, - type=str, - help="The GCS location where artifacts should be pushed.") - parser_pr.set_defaults(func=build_pr) # parse the args and call whatever function was selected diff --git a/py/test_util.py b/py/test_util.py new file mode 100644 index 0000000000..c84e3d94cb --- /dev/null +++ b/py/test_util.py @@ -0,0 +1,59 @@ +import logging +from xml.etree import ElementTree + +import six + +from py import util + +class TestCase(object): + def __init__(self): + self.class_name = None + self.name = None + self.time = None + # String describing the failure. + self.failure = None + + +def create_junit_xml_file(test_cases, output_path, gcs_client=None): + """Create a JUnit XML file. + + Args: + test_cases: List of test case objects. + output_path: Path to write the XML + gcs_client: GCS client to use if output is GCS. + """ + total_time = 0 + failures = 0 + for c in test_cases: + total_time += c.time + + if c.failure: + failures += 1 + attrib = {"failures": "{0}".format(failures), "tests": "{0}".format(len(test_cases)), + "time": "{0}".format(total_time)} + root = ElementTree.Element("testsuite", attrib) + + for c in test_cases: + attrib = { + "classname": c.class_name, + "name": c.name, + "time": "{0}".format(c.time), + } + if c.failure: + attrib["failure"] = c.failure + e = ElementTree.Element("testcase", attrib) + + root.append(e) + + t = ElementTree.ElementTree(root) + logging.info("Creationg %s", output_path) + if output_path.startswith("gs://"): + b = six.StringIO() + t.write(b) + + bucket_name, path = util.split_gcs_uri(output_path) + bucket = gcs_client.get_bucket(bucket_name) + blob = bucket.blob(path) + blob.upload_from_string(b.getvalue()) + else: + t.write(output_path) diff --git a/py/test_util_test.py b/py/test_util_test.py new file mode 100644 index 0000000000..e28072a9ac --- /dev/null +++ b/py/test_util_test.py @@ -0,0 +1,38 @@ +from __future__ import print_function + +import tempfile +import unittest + +from py import test_util + +class XMLTest(unittest.TestCase): + def test_write_xml(self): + with tempfile.NamedTemporaryFile(delete=False) as hf: + pass + + success = test_util.TestCase() + success.class_name = "some_test" + success.name = "first" + success.time = 10 + + failure = test_util.TestCase() + failure.class_name = "some_test" + failure.name = "first" + failure.time = 10 + failure.failure = "failed for some reason." + + test_util.create_junit_xml_file([success, failure], hf.name) + with open(hf.name) as hf: + output = hf.read() + print(output) + expected = ("""""" + """""" + """""") + + self.assertEquals(expected, output) + + +if __name__ == "__main__": + unittest.main() diff --git a/py/util.py b/py/util.py old mode 100644 new mode 100755 index e68b248689..a10ea34de0 --- a/py/util.py +++ b/py/util.py @@ -11,8 +11,14 @@ import urllib import yaml +import google.auth +import google.auth.transport +import google.auth.transport.requests + from googleapiclient import errors from kubernetes import client as k8s_client +from kubernetes.config import kube_config +from kubernetes.client import configuration from kubernetes.client import rest # Default name for the repo organization and name. @@ -21,7 +27,7 @@ MASTER_REPO_NAME = "k8s" -def run(command, cwd=None, env=None): +def run(command, cwd=None, env=None, use_print=False, dryrun=False): """Run a subprocess. Any subprocess output is emitted through the logging modules. @@ -32,11 +38,30 @@ def run(command, cwd=None, env=None): env = os.environ try: + if dryrun: + command_str = ("Dryrun: Command:\n{0}\nCWD:\n{1}\n" + "Environment:\n{2}").format(" ".join(command), cwd, env) + if use_print: + print(command_str) + else: + logging.info(command_str) + return output = subprocess.check_output(command, cwd=cwd, env=env, stderr=subprocess.STDOUT).decode("utf-8") - logging.info("Subprocess output:\n%s", output) + + if use_print: + # With Airflow use print to bypass logging module. + print("Subprocess output:\n") + print(output) + else: + logging.info("Subprocess output:\n%s", output) except subprocess.CalledProcessError as e: - logging.info("Subprocess output:\n%s", e.output) + if use_print: + # With Airflow use print to bypass logging module. + print("Subprocess output:\n") + print(e.output) + else: + logging.info("Subprocess output:\n%s", e.output) raise def run_and_output(command, cwd=None, env=None): @@ -354,3 +379,48 @@ def split_gcs_uri(gcs_uri): bucket = m.group(1) path = m.group(2) return bucket, path + +def _refresh_credentials(): + # I tried userinfo.email scope that was insufficient; got unauthorized errors. + credentials, _ = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"]) + request = google.auth.transport.requests.Request() + credentials.refresh(request) + return credentials + +# TODO(jlewi): This is a work around for +# https://github.com/kubernetes-incubator/client-python/issues/339. +# Consider getting rid of this and adopting the solution to that issue. +def load_kube_config(config_file=None, context=None, + client_configuration=configuration, + persist_config=True, + get_google_credentials=_refresh_credentials, + **kwargs): + """Loads authentication and cluster information from kube-config file + and stores them in kubernetes.client.configuration. + + :param config_file: Name of the kube-config file. + :param context: set the active context. If is set to None, current_context + from config file will be used. + :param client_configuration: The kubernetes.client.ConfigurationObject to + set configs to. + :param persist_config: If True, config file will be updated when changed + (e.g GCP token refresh). + """ + + if config_file is None: + config_file = os.path.expanduser(kube_config.KUBE_CONFIG_DEFAULT_LOCATION) + + config_persister = None + if persist_config: + def _save_kube_config(config_map): + with open(config_file, 'w') as f: + yaml.safe_dump(config_map, f, default_flow_style=False) + config_persister = _save_kube_config + + kube_config._get_kube_config_loader_for_yaml_file( # pylint: disable=protected-access + config_file, active_context=context, + client_configuration=client_configuration, + config_persister=config_persister, + get_google_credentials=get_google_credentials, + **kwargs).load_and_set()