From f57edc0b2d60e941d5b7f6aa63f39f794bd6498e Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 28 Feb 2018 10:30:20 +0100 Subject: [PATCH] Initial public release --- .gitignore | 1 + LICENSE | 21 + README.md | 439 +++++++++++++++++++ client/launcher.py | 309 +++++++++++++ requirements.txt | 7 + server/config/templates/default.json | 55 +++ server/config/templates/ec2-template.json | 29 ++ server/config/templates/ssh-template.json | 25 ++ server/config/templates/torque-template.json | 23 + server/main.py | 178 ++++++++ server/nmtwizard/__init__.py | 0 server/nmtwizard/common.py | 317 +++++++++++++ server/nmtwizard/config.py | 86 ++++ server/nmtwizard/redis_database.py | 65 +++ server/nmtwizard/service.py | 105 +++++ server/nmtwizard/task.py | 146 ++++++ server/nmtwizard/worker.py | 166 +++++++ server/services/__init__.py | 0 server/services/ec2.py | 149 +++++++ server/services/ssh.py | 181 ++++++++ server/services/torque.py | 220 ++++++++++ server/settings.ini | 13 + server/worker.py | 36 ++ 23 files changed, 2571 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 client/launcher.py create mode 100644 requirements.txt create mode 100644 server/config/templates/default.json create mode 100644 server/config/templates/ec2-template.json create mode 100644 server/config/templates/ssh-template.json create mode 100644 server/config/templates/torque-template.json create mode 100644 server/main.py create mode 100644 server/nmtwizard/__init__.py create mode 100644 server/nmtwizard/common.py create mode 100644 server/nmtwizard/config.py create mode 100644 server/nmtwizard/redis_database.py create mode 100644 server/nmtwizard/service.py create mode 100644 server/nmtwizard/task.py create mode 100644 server/nmtwizard/worker.py create mode 100644 server/services/__init__.py create mode 100644 server/services/ec2.py create mode 100644 server/services/ssh.py create mode 100644 server/services/torque.py create mode 100644 server/settings.ini create mode 100644 server/worker.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0d20b648 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.pyc diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..1f2f7a9c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017-present The OpenNMT Authors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..6cae7b66 --- /dev/null +++ b/README.md @@ -0,0 +1,439 @@ +# nmt-wizard + +`nmt-wizard` is a Docker-based task launcher and monitor on a variety of remote platforms (called *services*) such as SSH servers, Torque clusters, or EC2 instances. Each *service* is providing access to compute *resources*. The launcher is meant to be used with `nmt-wizard-docker` images, but without a strong dependency. + +The project provides: + +* a RESTful server that queues incoming requests in a Redis database; +* a client to the REST server providing a simple textual visualization interface; +* workers that launch and manage tasks on the requested service and updates their status. + +Once launched, the tasks are sending progress and updated information to the launcher. + +## Services configuration + +Service configurations are provided by the system administrators. A service is declared and configured by a JSON file in the `config` directory. The REST server and worker automatically discover existing configuration files, provided that their filename ends with `.json`. A special `default.json` file defines parameters shared by all services. + +The configuration file has the following structure: + +``` +{ + "name": "my-service", // The short name the user will select. + "description": "My service", // Display name of the service. + "module": "services.XXX", // Name of the Python module managing the service. + "variables": { + "key1": [ "value1", "value2" ], + ... + }, + "docker": { + "registries": { // Docker registries: ECS, Docker Hub. + "aws": { + "type": "aws", + "credentials": { + "AWS_ACCESS_KEY_ID": "XXXXX", + "AWS_SECRET_ACCESS_KEY": "XXXXX" + }, + "uri": "XXXXX.dkr.ecr.eu-west-3.amazonaws.com", + "region": "eu-west-3" + }, + "dockerhub": { + "type": "dockerhub", + "uri": "" + }, + "mydockerprivate": { + "type": "dockerprivate", + "uri": "", + "credentials": { + "password": "XXXXX", + "username": "XXXXX" + } + } + }, + "mount": [ // Volumes to mount when running the Docker image. + "/home/devling/corpus:/root/corpus", + "/home/devling/models:/root/models" + ], + "envvar": { // Environment variables to set when running the Docker image. + } + }, + "skey1": "svalue1", // Service specific configurations. + ..., + "disabled": [01], // Boolean field to disable/enable the service. + "storages": { // Storage configuration as described in single-training-docker. + }, + "callback_url": "http://LAUNCHER_URL", + "callback_interval": 60 +} +``` + +where `variables` is a list of possible options for the service. The structure of these options is specific to each service. These options are transformed into simple key/LIST,FIELDS by the `describe` route to enable simple and generic UI selection of multiple variants. + +Template files are provided in `config/templates` and can be used as a basis for configuring services. + +## Server configuration + +* The Redis database must be configured to enable keyspace event as followed: + +```bash +redis-cli config set notify-keyspace-events Klgx +``` + +* The REST server and worker are configured by `settings.ini`. The `LAUNCHER_MODE` environment variable (defaulting to `Production`) can be set to select different set of options in development or production. + +## Using the launcher + +### Worker + +The first component to launch is the worker that should always be running. It handles: + +* the launch of tasks +* the termination of tasks +* the update of active resources + +```bash +cd server && python worker.py +``` + +For performance, multiple workers might be running simultaneously. In that case, a longer refresh should be defined. + +### Server + +The server has the following HTTP routes: + +* `list_services`: returns available services +* `describe`: returns user selectable options for the service +* `check`: checks availability of a given service with provided user options +* `launch`: launches a task on a given service with provided user options +* `status`: checks the status of a task +* `list_tasks`: returns the list of tasks in the database +* `del`: delete a task from the database +* `terminate`: terminates the process and/or instance associated with a task +* `beat`: provides a `beat` back to the launcher to notify the task activity and announce the next beat to expect +* `file`: sets or returns a file associated to a task + +The server uses Flask. See the [Flask documentation](http://flask.pocoo.org/docs/0.12/deploying/) to deploy it for production. For development, it can be run as follows (single thread): + +```bash +cd app && FLASK_APP=main.py flask run [--host=0.0.0.0] +``` + +Here are the are the available routes. Also see the next section + +#### `GET /list_services` + +Lists available services. + +* **Arguments:** None +* **Input:** None +* **Output:** A dictionary of service name to description (JSON) +* **Example:** + +``` +$ curl -X GET 'http://127.0.0.1:5000/list_services' +{ + "demogpu02": "OVH extra training server", + "ec2": "Instance on AWS EC2", + "localhost": "test local environment", + "ssaling04": "GPU training server" +} +``` + +#### `GET /describe/` + +Returns possible options for a service as a [JSON Form](https://github.com/joshfire/jsonform). This can be used to easily implement a GUI to select options the target service. + +* **Arguments:** + * `service_name`: the service name +* **Input:** None +* **Output:** A JSON form (or an empty dictionary if the service has no possible options). +* **Example:** + +``` +$ curl -X GET 'http://127.0.0.1:5000/describe/ec2' +{ + "launchTemplate": { + "description": "The name of the EC2 launch template to use", + "enum": [ + "SingleTrainingDev" + ], + "title": "EC2 Launch Template", + "type": "string" + } +} +``` + +#### `GET /check/` + +Checks if the service is available and can be used with the provided options. In case of success, it returns information about the service and the corresponding resource. + +* **Arguments:** + * `service_name`: the service name +* **Input:** The selected service options (see `describe/`) (JSON) +* **Output:** + * On invalid option, a HTTP 400 code with the error message (JSON) + * On server error, a HTTP 500 code with the error message (JSON) + * On success, an optional message with details about the service (JSON) +* **Example:** + +``` +$ curl -X GET http://127.0.0.1:5000/check/ec2 +{ + "message": "missing launchTemplateName option", +} +$ curl -X GET -d '{"launchTemplateName": "InvalidLaunchTemplate"}' \ + -H "Content-Type: application/json" 'http://127.0.0.1:5000/check/ec2' +{ + "message": "An error occurred (InvalidLaunchTemplateId.NotFound) when calling the RunInstances operation: LaunchTemplate null not found" +} +$ curl -X GET -d '{"launchTemplateName": "SingleTrainingDev"}' \ + -H "Content-Type: application/json" 'http://127.0.0.1:5000/check/ec2' +{ + "message": "" +} +``` + +#### `POST /launch/` + +Launches a Docker-based task on the specified service. In case of success, it returns a task identifier that can be used to monitor the task using the `status` or `terminate` routes. + +* **Arguments:** + * `service_name`: the service name +* **Input:** the input is either a simple json body or a multi-part request with `content` field containing JSON task configuration. The other fields of the multi-part requests are binary files to be uploaded on the remote service at task-launch time. + +The task configuration (JSON) + +``` +$ cat body.json +{ + "docker": { + "registry": "dockerhub" + "image": "opennmt/opennmt-lua", + "tag": "latest", + "command": [ + ... + ] + }, + "wait_after_launch": 2, + "trainer_id": "OpenNMT", + "options": { + "launchTemplateName": "SingleTrainingDev" + } +} +``` + +`docker.tag` and `wait_after_launch` are optional. +* **Output:** + * On invalid task configuration, a HTTP 400 code with an error message (JSON) + * On success, a task identifier (string) +* **Example:** + +``` +$ curl -X POST -d @invalid_body.json -H "Content-Type: application/json" \ + http://127.0.0.1:5000/launch/ec2 +{ + "message": "missing trainer_id field" +} +$ curl -X POST -d @body.json -H "Content-Type: application/json" \ + 'http://127.0.0.1:5000/launch/ec2' +"130d4400-9aad-4654-b124-d258cbe4b1e3" +$ curl -X POST -d content=@body.json -F input.txt=@input.txt 'http://127.0.0.1:5000/launch/ec2' +"1f877e53-5a25-44de-b115-7f6d3e386e70" +``` + +#### `GET /list_tasks/` + +Lists available services. + +* **Arguments:** + * `pattern`: pattern for the tasks to match. See [KEYS pattern](https://redis.io/commands/keys) for syntax. +* **Input:** None +* **Output:** A list of tasks matching the pattern with minimal information (`task_id`, `queued_time`, `status`, `service`, `message`) +* **Example:** + +``` +$ curl -X GET 'http://127.0.0.1:5000/list_tasks/jean_*' +[ + { + "message": "completed", + "queued_time": "1519652594.957615", + "status": "stopped", + "service": "ec2", + "task_id": "jean_5af69495-3304-4118-bd6c-37d0e6" + }, + { + "message": "error", + "queued_time": "1519652097.672299", + "status": "stopped", + "service": "mysshgpu", + "task_id": "jean_99b822bc-51ac-4049-ba39-980541" + } +] +``` + +#### `GET /del_tasks/` + +Lists available services. + +* **Arguments:** + * `pattern`: pattern for the tasks to match - only stopped tasks will be deleted. See [KEYS pattern](https://redis.io/commands/keys) for syntax. +* **Input:** None +* **Output:** list of deleted tasks +* **Example:** + +``` +$ curl -X GET 'http://127.0.0.1:5000/del_tasks/jean_*' +[ + "jean_5af69495-3304-4118-bd6c-37d0e6", + "jean_99b822bc-51ac-4049-ba39-980541" +] +``` + +#### `GET /status/` + +Returns the status of a task. + +* **Arguments:** + * `task_id`: the task ID returned by `/launch/` +* **Input:** None +* **Output:** + * On invalid `task_id`, a HTTP 404 code dictionary with an error message (JSON) + * On success, a dictionary with the task status (JSON) +* **Example:** + +``` +curl -X GET http://127.0.0.1:5000/status/unknwon-task-id +{ + "message": "task unknwon-task-id unknown" +} +curl -X GET http://127.0.0.1:5000/status/130d4400-9aad-4654-b124-d258cbe4b1e3 +{ + "allocated_time": "1519148201.9924579", + "content": "{\"docker\": {\"command\": [], \"registry\": \"dockerhub\", \"image\": \"opennmt/opennmt-lua\", \"tag\": \"latest\"}, \"service\": \"ec2\", \"wait_after_launch\": 2, \"trainer_id\": \"OpenNMT\", \"options\": {\"launchTemplateName\": \"SingleTrainingDev\"}}", + "message": "unknown registry", + "queued_time": "1519148144.483396", + "resource": "SingleTrainingDev", + "service": "ec2", + "status": "stopped", + "stopped_time": "1519148201.9977396", + "ttl": null +} +``` + +(Here the task was quickly stopped due to an incorrect Docker registry.) + +The main fields are: + +* `status`: (timestamp for each status can be found in `_time`) + * `queued`, + * `allocated`, + * `running`, + * `terminating`, + * `stopped` (additional information can be found in `message` field); +* `service`: name of the service the task is running on; +* `resource`: name of the resource the task is using; +* `content`: the actual task definition; +* `update_time`: if the task is sending beat requests; +* `ttl` if a time to live was passed in the beat request. + +#### `GET /terminate/(?phase=status)` + +Terminates a task. If the task is already stopped, it does nothing. Otherwise, it changes the status of the task to `terminating` (actual termination is asynchronous) and returns a success message. + +* **Arguments:** + * `task_id`: the task identifier returned by `/launch/` + * (optionnal) `phase`: indicate if the termination command is corresponding to an error or natural completion (`completed`) +* **Input**: None +* **Output**: + * On invalid `task_id`, a HTTP 404 code with an error message (JSON) + * On success, a HTTP 200 code with a message (JSON) + +``` +curl -X GET http://127.0.0.1:5000/terminate/130d4400-9aad-4654-b124-d258cbe4b1e3 +{ + "message": "130d4400-9aad-4654-b124-d258cbe4b1e3 already stopped" +} +``` + +#### `GET /del/` + +Deletes a task. If the task is not stopped, it does nothing. + +* **Arguments:** + * `task_id`: the task identifier returned by `/launch/` +* **Input**: None +* **Output**: + * On invalid `task_id`, a HTTP 404 code with an error message (JSON) + * On success, a HTTP 200 code with a message (JSON) + +#### `GET /beat/(?duration=XXX&container_id=CID)` + +Notifies a *beat* back to the launcher. Tasks should invoke this route wih a specific interval to notify that they are still alive and working. This makes it easier for the launcher to identify and handle dead tasks. + +* **Arguments** + * `task_id`: the task identifier returned by `/launch/` + * (optional) `duration`: if no beat is received for this task after this duration the task is assumed to be dead + * (optional) `container_id`: the ID of the Docker container +* **Input:** None +* **Output:** + * On invalid `duration`, a HTTP 400 code with an error message (JSON) + * On invalid `task_id`, a HTTP 404 code with an error message (JSON) + * On success, a HTTP 200 code + +#### `POST /file//` + +Registers a file for a task - typically used for log, or posting translation output using http storage. + +* **Arguments** + * `task_id`: the task identifier returned by `/launch/` + * `filename`: a filename +* **Input:** None +* **Output:** + * On invalid `task_id`, a HTTP 404 code with an error message (JSON) + * On success, a HTTP 200 code + +#### `GET /file//` + +Retrieves file attached to a tasK + +* **Arguments** + * `task_id`: the task identifier returned by `/launch/` + * `filename`: a filename +* **Input:** None +* **Output:** + * On invalid `task_id`, a HTTP 404 code with an error message (JSON) + * On missing files, a HTTP 404 code with an error message (JSON) + * On success, the actual file + +### Launcher + +The launcher is a simple client to the REST server. See: + +```bash +python client/launcher.py -h +``` + +**Notes:** + +* The address of the launcher REST service is provided either by the environment variable `LAUNCHER_URL` or the command line parameter `-u URL`. +* By default, the request response are formatted in text-table for better readibility, the option `-j` displays raw JSON response +* The `trainer_id` field to the `launch` command is either coming from `--trainer_id` option or using `LAUNCHER_TID` environment variable. Also, by default, the same environment variable is used as a default value of the `prefix` parameter of the `lt` command. +* By default, the command parameter are expected as inline values, but can also be obtained from a file, in that case, the corresponding option will take the value `@FILEPATH`. +* File identified as local files, are transfered to the launcher using `TMP_DIR` on the remote server + +## Development + +### Redis database + +The Redis database contains the following fields: + +| Field | Type | Description | +| --- | --- | --- | +| `active` | list | Active tasks | +| `beat:` | int | Specific ttl-key for a given task | +| `lock:` | value | Temporary lock on a resource or task | +| `queued:` | list | Tasks waiting for a resource | +| `resource::` | list | Tasks using this resource | +| `task:` | dict |
  • status: [queued, allocated, running, terminating, stopped]
  • job: json of jobid (if status>=waiting)
  • service:the name of the service
  • resource: the name of the resource - or auto before allocating one message: error message (if any), ‘completed’ if successfully finished
  • container_id: container in which the task run send back by docker notifier
  • (queued|allocated|running|updated|stopped)_time: time for each event
| +| `files:` | dict | files associated to a task, "log" is generated when training is complete | +| `queue:` | str | expirable timestamp on the task - is used to regularily check status | +| `work` | list | Tasks to process | diff --git a/client/launcher.py b/client/launcher.py new file mode 100644 index 00000000..0c8731bd --- /dev/null +++ b/client/launcher.py @@ -0,0 +1,309 @@ +from __future__ import print_function + +import argparse +import json +import sys +import os +import logging +import requests +from datetime import datetime + +def getjson(config): + if config is None: + return None + if not config.startswith('@'): + return json.loads(config) + with open(config[1:]) as data: + return json.load(data) + +def find_files_parameters(config, files): + for k in config: + v = config[k] + if isinstance(v, unicode) and v.startswith('/') and os.path.exists(v): + basename = os.path.basename(v) + files[basename] = (basename, open(v, 'rb')) + config[k] = "${TMP_DIR}/%s" % basename + logger.debug('found local file: %s -> ${TMP_DIR}/%s', v, basename) + elif isinstance(v, dict): + find_files_parameters(v, files) + +def confirm(prompt=None, resp=False): + """prompts for yes or no response from the user. Returns True for yes and + False for no. + """ + + if prompt is None: + prompt = 'Confirm' + + if resp: + prompt = '%s [%s]|%s: ' % (prompt, 'y', 'n') + else: + prompt = '%s [%s]|%s: ' % (prompt, 'n', 'y') + + while True: + ans = raw_input(prompt) + if not ans: + return resp + if ans not in ['y', 'Y', 'n', 'N']: + print('please enter y or n.') + continue + if ans == 'y' or ans == 'Y': + return True + if ans == 'n' or ans == 'N': + return False + +parser = argparse.ArgumentParser() +parser.add_argument('-u', '--url', + help="url to the launcher") +parser.add_argument('-l', '--log-level', default='INFO', + help="log-level (INFO|WARN|DEBUG|FATAL|ERROR)") +parser.add_argument('-j', '--json', action='store_true', + help="display output in json format from rest server (default text)") +subparsers = parser.add_subparsers(help='command help', dest='cmd') +parser_list_services = subparsers.add_parser('ls', + help='list available services') +parser_describe = subparsers.add_parser('describe', + help='list available options for the service') +parser_describe.add_argument('-s', '--service', help="service name") +parser_check = subparsers.add_parser('check', + help='check that service associated to provided options is operational') +parser_check.add_argument('-s', '--service', help="service name") +parser_check.add_argument('-o', '--options', default='{}', + help="options selected to run the service") +parser_launch = subparsers.add_parser('launch', + help='launch a task on the service associated to provided options') +parser_launch.add_argument('-s', '--service', help="service name") +parser_launch.add_argument('-o', '--options', default='{}', + help="options selected to run the service") +parser_launch.add_argument('-w', '--wait_after_launch', default=2, type=int, + help=('if not 0, wait for this number of seconds after launch ' + 'to check that launch is ok - by default wait for 2 seconds')) +parser_launch.add_argument('-r', '--docker_registry', default='dockerhub', + help='docker registry (as configured on server side) - default is `dockerhub`') +parser_launch.add_argument('-i', '--docker_image', required=True, + help='Docker image') +parser_launch.add_argument('-t', '--docker_tag', default="latest", + help='Docker image tag (default is latest)') +parser_launch.add_argument('-T', '--trainer_id', default=os.getenv('LAUNCHER_TID', None), + help='trainer id, used as a prefix to generated models (default ENV[LAUNCHER_TID])') +parser_launch.add_argument('docker_command', type=str, nargs='*', + help='Docker command') +parser_list_tasks = subparsers.add_parser('lt', + help='list tasks matching prefix pattern') +parser_list_tasks.add_argument('-p', '--prefix', default=os.getenv('LAUNCHER_TID', ''), + help='prefix for the tasks to list (default ENV[LAUNCHER_TID])') +parser_del_tasks = subparsers.add_parser('dt', + help='delete tasks matching prefix pattern') +parser_del_tasks.add_argument('-p', '--prefix', required=True, + help='prefix for the tasks to delete') +parser_status = subparsers.add_parser('status', help='get status of a task') +parser_status.add_argument('-k', '--task_id', + help="task identifier", required=True) +parser_terminate = subparsers.add_parser('terminate', help='terminate a task') +parser_terminate.add_argument('-k', '--task_id', + help="task identifier", required=True) +parser_file = subparsers.add_parser('file', help='get file associated to a task') +parser_file.add_argument('-k', '--task_id', + help="task identifier", required=True) +parser_file.add_argument('-f', '--filename', + help="filename to retrieve - for instance log", required=True) + +args = parser.parse_args() + +logging.basicConfig(stream=sys.stdout, level=args.log_level) +logger = logging.getLogger() + +if args.url is None: + args.url = os.getenv('LAUNCHER_URL') + if args.url is None: + logger.error('missing launcher_url') + sys.exit(1) + +r = requests.get(os.path.join(args.url, "list_services")) +if r.status_code != 200: + logger.error('incorrect result from \'list_services\' service: %s', r.text) +serviceList = r.json() + +if args.cmd == "ls": + result = serviceList + if not args.json: + print("%-20s\t%s" % ("SERVICE NAME", "DESCRIPTION")) + for k in result: + print("%-20s\t%s" % (k, result[k])) + sys.exit(0) +elif args.cmd == "lt": + r = requests.get(os.path.join(args.url, "list_tasks", args.prefix + '*')) + if r.status_code != 200: + logger.error('incorrect result from \'list_tasks\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + print("%-32s\t%-20s\t%-16s\t%-10s\t%s" % ("TASK_ID", "LAUNCH DATE", "IMAGE", "STATUS", "MESSAGE")) + for k in sorted(result, key=lambda k: int(k["queued_time"])): + date = datetime.fromtimestamp(int(k["queued_time"])).isoformat(' ') + print("%-32s\t%-20s\t%-16s\t%-10s\t%s" % ( + k["task_id"], date, k["image"], k["status"], k.get("message"))) + sys.exit(0) +elif args.cmd == "describe": + if args.service not in serviceList: + logger.fatal("ERROR: service '%s' not defined", args.service) + sys.exit(1) + r = requests.get(os.path.join(args.url, "describe", args.service)) + if r.status_code != 200: + logger.error('incorrect result from \'describe\' service: %s', r.text) + sys.exit(1) + result = r.json() +elif args.cmd == "check": + if args.service not in serviceList: + logger.fatal("ERROR: service '%s' not defined", args.service) + sys.exit(1) + r = requests.get(os.path.join(args.url, "check", args.service), json=getjson(args.options)) + if r.status_code != 200: + logger.error('incorrect result from \'check\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + print(result["message"]) + sys.exit(0) +elif args.cmd == "launch": + if args.service not in serviceList: + logger.fatal("ERROR: service '%s' not defined", args.service) + sys.exit(1) + + # for multi-part file sending + files = {} + docker_command = [] + + for c in args.docker_command: + orgc = c + if c.startswith("@"): + with open(c[1:], "rt") as f: + c = f.read() + if os.path.exists(c): + basename = os.path.basename(c) + files[basename] = (basename, open(c, 'rb')) + c = "${TMP_DIR}/%s" % basename + # if json, explore for values to check local path values + if c.startswith('{'): + try: + cjson = json.loads(c) + except ValueError as err: + logger.fatal("Invalid JSON parameter in %s: %s", orgc, str(err)) + sys.exit(1) + find_files_parameters(cjson, files) + c = json.dumps(cjson) + docker_command.append(c) + + if args.service not in serviceList: + logger.fatal("ERROR: service '%s' not defined", args.service) + sys.exit(1) + + content = { + "docker": { + "registry": args.docker_registry, + "image": args.docker_image, + "tag": args.docker_tag, + "command": docker_command + }, + "wait_after_launch": args.wait_after_launch, + "trainer_id": args.trainer_id, + "options": getjson(args.options) + } + + launch_url = os.path.join(args.url, "launch", args.service) + r = None + if len(files) > 0: + r = requests.post(launch_url, files=files, data = {'content': json.dumps(content)}) + else: + r = requests.post(launch_url, json=content) + if r.status_code != 200: + logger.error('incorrect result from \'launch\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + print(result) + sys.exit(0) +elif args.cmd == "status": + r = requests.get(os.path.join(args.url, "status", args.task_id)) + if r.status_code != 200: + logger.error('incorrect result from \'status\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + times = [] + current_time = int(result["current_time"]) + result.pop("current_time", None) + for k in result: + if k.endswith('_time'): + times.append(k) + sorted_times = sorted(times, key=lambda k: int(result[k])) + last_update = '' + if sorted_times: + upd = current_time - int(result[sorted_times[-1]]) + last_update = " - updated %d seconds ago" % upd + print("TASK %s - status %s (%s)%s" % ( + args.task_id, result.get('status'), result.get('message'), last_update)) + if "service" in result: + print("SERVICE %s - RESOURCE %s - CONTAINER %s" % ( + result['service'], result.get('resource'), result.get('container_id'))) + print("ATTACHED FILES: %s" % ', '.join(result['files'])) + print("TIMELINE:") + last = -1 + delay = [] + for k in sorted_times: + if k != "updated_time": + current = int(result[k]) + delta = current-last if last != -1 else 0 + delay.append("(%ds)" % delta) + last = current + delay.append('') + idx = 1 + for k in sorted_times: + if k != "updated_time": + current = int(result[k]) + date = datetime.fromtimestamp(current).isoformat(' ') + print("\t%-12s\t%s\t%s" % (k[:-5], date, delay[idx])) + idx += 1 + content = result["content"] + content = json.loads(content) + print("CONTENT") + print(json.dumps(content, indent=True)) + sys.exit(0) +elif args.cmd == "dt": + r = requests.get(os.path.join(args.url, "list_tasks", args.prefix + '*')) + if r.status_code != 200: + logger.error('incorrect result from \'list_tasks\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + print('Delete %d tasks:' % len(result)) + print("\t%-32s\t%-20s\t%-16s\t%-10s\t%s" % ("TASK_ID", "LAUNCH DATE", "IMAGE", "STATUS", "MESSAGE")) + for k in sorted(result, key=lambda k: int(k["queued_time"])): + date = datetime.fromtimestamp(int(k["queued_time"])).isoformat(' ') + print("\t%-32s\t%-20s\t%-16s\t%-10s\t%s" % ( + k["task_id"], date, k["image"], k["status"], k.get("message"))) + if confirm(): + for k in result: + r = requests.get(os.path.join(args.url, "del", k["task_id"])) + if r.status_code != 200: + logger.error('incorrect result from \'delete_task\' service: %s', r.text) + sys.exit(1) + sys.exit(0) +elif args.cmd == "terminate": + r = requests.get(os.path.join(args.url, "terminate", args.task_id)) + if r.status_code != 200: + logger.error('incorrect result from \'terminate\' service: %s', r.text) + sys.exit(1) + result = r.json() + if not args.json: + print(result["message"]) + sys.exit(0) +elif args.cmd == "file": + r = requests.get(os.path.join(args.url, "file", args.task_id, args.filename)) + if r.status_code != 200: + logger.error('incorrect result from \'log\' service: %s', r.text) + sys.exit(1) + print(r.text) + sys.exit(0) + +print(json.dumps(result)) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..4c19a27e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +boto3 +enum34 +flask +ovh +paramiko +redis +requests \ No newline at end of file diff --git a/server/config/templates/default.json b/server/config/templates/default.json new file mode 100644 index 00000000..2acc9ed1 --- /dev/null +++ b/server/config/templates/default.json @@ -0,0 +1,55 @@ +{ + "docker": { + "registries": { + "aws": { + "type": "aws", + "credentials": { + "AWS_ACCESS_KEY_ID": "XXXXX", + "AWS_SECRET_ACCESS_KEY": "XXXXX" + }, + "uri": "XXXXX.dkr.ecr.eu-west-3.amazonaws.com", + "region": "eu-west-3" + }, + "dockerhub": { + "type": "dockerhub", + "uri": "" + }, + "mydockerprivate": { + "type": "dockerprivate", + "uri": "", + "credentials": { + "password": "XXXXX", + "username": "XXXXX" + } + } + } + }, + "storages" : { + "s3_models": { + "type": "s3", + "bucket": "model-catalog", + "aws_credentials": { + "access_key_id": "XXXXX", + "secret_access_key": "XXXXX", + "region_name": "eu-west-3" + } + }, + "s3_test": { + "type": "s3", + "bucket": "pn9-testing", + "aws_credentials": { + "access_key_id": "XXXXX", + "secret_access_key": "XXXXX", + "region_name": "eu-west-3" + } + }, + "myserver": { + "type": "ssh", + "server": "myserver_url", + "user": "XXXXX", + "password": "XXXXX" + } + }, + "callback_url": "http://XXXXX:5000", + "callback_interval": 60 +} \ No newline at end of file diff --git a/server/config/templates/ec2-template.json b/server/config/templates/ec2-template.json new file mode 100644 index 00000000..5775ec6b --- /dev/null +++ b/server/config/templates/ec2-template.json @@ -0,0 +1,29 @@ +{ + "name": "ec2", + "description": "Instance on AWS EC2", + "module": "services.ec2", + "awsAccessKeyId": "XXXXX", + "awsSecretAccessKey": "XXXXX", + "awsRegion": "eu-west-3", + "privateKeysDirectory": "credentials", + "amiUsername": "ec2-user", + "logDir": "/home/ec2-user", + "sshConnectionDelay": 10, + "maxSshConnectionRetry": 3, + "maxInstancePerTemplate": 5, + "corpus": { + "bucket": "pn9-training", + "mount": "/home/ec2-user/corpus", + "credentials": { + "AWS_ACCESS_KEY_ID": "XXXXX", + "AWS_SECRET_ACCESS_KEY": "XXXXX" + }, + "region": "eu-west-3" + }, + "docker": { + "mount": [ + "/home/ec2-user/corpus:/root/corpus" + ] + }, + "disabled": 1 +} diff --git a/server/config/templates/ssh-template.json b/server/config/templates/ssh-template.json new file mode 100644 index 00000000..fc7f172f --- /dev/null +++ b/server/config/templates/ssh-template.json @@ -0,0 +1,25 @@ +{ + "name": "ssaling04", + "description": "GPU training server", + "visibility": [ "JS" ], + "module": "services.ssh", + "variables": { + "server_pool": [ + { + "host": "ssaling04", + "gpus": [1,2,3,4], + "login": "devling", + "log_dir": "inftraining_logs" + } + ] + }, + "privateKey": "credentials/id_rsa", + "docker": { + "mount": [ + "/mnt/nmt-corpus-pn9:/root/corpus" + ], + "envvar": { + } + }, + "disabled": 1 +} \ No newline at end of file diff --git a/server/config/templates/torque-template.json b/server/config/templates/torque-template.json new file mode 100644 index 00000000..e093cd92 --- /dev/null +++ b/server/config/templates/torque-template.json @@ -0,0 +1,23 @@ +{ + "name": "ssaling-cluster", + "description": "ssaling GPU clusters", + "module": "services.torque", + "maxInstance": 100, + "variables": { + "master_node": "ssaling01", + "torque_install_path": "/usr/local/torque/bin", + "log_dir": "inftraining_logs" + }, + "privateKey": "credentials/id_rsa", + "docker": { + "mount": [ + "/mnt/nmt-corpus-pn9:/root/corpus" + ], + "envvar": { + "AWS_ACCESS_KEY_ID": "XXXXX", + "AWS_SECRET_ACCESS_KEY": "XXXXX", + "AWS_DEFAULT_REGION": "eu-west-3" + } + }, + "disabled": 0 +} \ No newline at end of file diff --git a/server/main.py b/server/main.py new file mode 100644 index 00000000..995c1b58 --- /dev/null +++ b/server/main.py @@ -0,0 +1,178 @@ +import uuid +import os +import logging +import flask +import json + +from six.moves import configparser + +from nmtwizard import common, config, task +from nmtwizard.redis_database import RedisDatabase + +ch = logging.StreamHandler() +ch.setLevel(logging.ERROR) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +ch.setFormatter(formatter) + +config.add_log_handler(ch) +common.add_log_handler(ch) + +cfg = configparser.ConfigParser() +cfg.read('settings.ini') +MODE = os.getenv('LAUNCHER_MODE', 'Production') + +redis_password = None +if cfg.has_option(MODE, 'redis_password'): + redis_password = cfg.get(MODE, 'redis_password') + +redis = RedisDatabase(cfg.get(MODE, 'redis_host'), + cfg.getint(MODE, 'redis_port'), + cfg.get(MODE, 'redis_db'), + redis_password) +services = config.load_services(cfg.get(MODE, 'config_dir')) + + +def _get_service(service): + """Wrapper to fail on invalid service.""" + if service not in services: + response = flask.jsonify(message="invalid service name: %s" % service) + flask.abort(flask.make_response(response, 404)) + return services[service] + + +app = flask.Flask(__name__) + +@app.route("/list_services", methods=["GET"]) +def list_services(): + return flask.jsonify({k: services[k].display_name for k in services}) + +@app.route("/describe/", methods=["GET"]) +def describe(service): + service_module = _get_service(service) + return flask.jsonify(service_module.describe()) + +@app.route("/check/", methods=["GET"]) +def check(service): + service_options = flask.request.get_json() if flask.request.is_json else None + if service_options is None: + service_options = {} + service_module = _get_service(service) + try: + details = service_module.check(service_options) + except ValueError as e: + flask.abort(flask.make_response(flask.jsonify(message=str(e)), 400)) + except Exception as e: + flask.abort(flask.make_response(flask.jsonify(message=str(e)), 500)) + else: + return flask.jsonify(message=details) + +@app.route("/launch/", methods=["POST"]) +def launch(service): + content = None + files = {} + if flask.request.is_json: + content = flask.request.get_json() + else: + content = flask.request.form.get('content') + if content is not None: + content = json.loads(content) + for k in flask.request.files: + files[k] = flask.request.files[k].read() + if content is None: + flask.abort(flask.make_response(flask.jsonify(message="missing content in request"), 400)) + service_module = _get_service(service) + content["service"] = service + task_id = str(uuid.uuid4()) + if 'trainer_id' in content and content['trainer_id']: + task_id = (content['trainer_id']+'_'+task_id)[0:35] + # Sanity check on content. + if 'options' not in content or not isinstance(content['options'], dict): + flask.abort(flask.make_response(flask.jsonify(message="invalid options field"), 400)) + if 'docker' not in content: + flask.abort(flask.make_response(flask.jsonify(message="missing docker field"), 400)) + resource = service_module.get_resource_from_options(content["options"]) + task.create(redis, task_id, resource, service, content, files) + return flask.jsonify(task_id) + +@app.route("/status/", methods=["GET"]) +def status(task_id): + if not task.exists(redis, task_id): + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + response = task.info(redis, task_id, []) + return flask.jsonify(response) + +@app.route("/del/", methods=["GET"]) +def del_task(task_id): + response = task.delete(redis, task_id) + if isinstance(response, list) and not response[0]: + flask.abort(flask.make_response(flask.jsonify(message=response[1]), 400)) + return flask.jsonify(message="deleted %s" % task_id) + +@app.route("/list_tasks/", methods=["GET"]) +def list_tasks(pattern): + ltask = [] + for task_key in task.scan_iter(redis, pattern): + task_id = task.id(task_key) + info = task.info(redis, task_id, ["queued_time", "service", "content", "status", "message"]) + content = json.loads(info["content"]) + info["image"] = content['docker']['image'] + del info['content'] + info['task_id'] = task_id + ltask.append(info) + return flask.jsonify(ltask) + +@app.route("/terminate/", methods=["GET"]) +def terminate(task_id): + with redis.acquire_lock(task_id): + current_status = task.info(redis, task_id, "status") + if current_status is None: + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + elif current_status == "stopped": + return flask.jsonify(message="%s already stopped" % task_id) + phase = flask.request.args.get('phase') + task.terminate(redis, task_id, phase=phase) + return flask.jsonify(message="terminating %s" % task_id) + +@app.route("/beat/", methods=["GET"]) +def beat(task_id): + duration = flask.request.args.get('duration') + try: + if duration is not None: + duration = int(duration) + except ValueError: + flask.abort(flask.make_response(flask.jsonify(message="invalid duration value"), 400)) + container_id = flask.request.args.get('container_id') + if not task.exists(redis, task_id): + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + task.beat(redis, task_id, duration, container_id) + return flask.jsonify(200) + +@app.route("/log/", methods=["GET"]) +def get_log(task_id): + if not task.exists(redis, task_id): + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + content = task.get_log(redis, task_id) + if content is None: + flask.abort(flask.make_response(flask.jsonify(message="no logs for task %s" % task_id), 404)) + response = flask.make_response(content) + response.mimetype = 'text/plain' + return response + +@app.route("/file//", methods=["GET"]) +def get_file(task_id, filename): + if not task.exists(redis, task_id): + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + content = task.get_file(redis, task_id, filename) + if content is None: + flask.abort(flask.make_response( + flask.jsonify(message="cannot find file %s for task %s" % (filename, task_id)), 404)) + response = flask.make_response(content) + return response + +@app.route("/file//", methods=["POST"]) +def post_file(task_id, filename): + if not task.exists(redis, task_id): + flask.abort(flask.make_response(flask.jsonify(message="task %s unknown" % task_id), 404)) + content = flask.request.get_data() + task.set_file(redis, task_id, content, filename) + return flask.jsonify(200) diff --git a/server/nmtwizard/__init__.py b/server/nmtwizard/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/nmtwizard/common.py b/server/nmtwizard/common.py new file mode 100644 index 00000000..d0181af6 --- /dev/null +++ b/server/nmtwizard/common.py @@ -0,0 +1,317 @@ +import time +import json +import re +import logging +import six + +import paramiko + +logger = logging.getLogger(__name__) + +def add_log_handler(fh): + logger.addHandler(fh) + +def run_command(client, cmd, stdin_content=None, sudo=False): + if sudo: + cmd = "sudo " + cmd + logger.debug("RUN %s", cmd) + stdin, stdout, stderr = client.exec_command(cmd) + if stdin_content is not None: + stdin.write(stdin_content) + stdin.flush() + exit_status = stdout.channel.recv_exit_status() + return exit_status, stdout, stderr + +def run_and_check_command(client, cmd, stdin_content=None, sudo=False): + exit_status, _, _ = run_command( + client, cmd, stdin_content=stdin_content, sudo=sudo) + return exit_status == 0 + +def program_exists(client, program): + return run_and_check_command(client, "command -v %s" % program) + +def has_gpu_support(client): + return run_and_check_command(client, "nvidia-smi") + +def ssh_connect_with_retry(client, + hostname, + username, + key_path, + delay=0, + retry=0, + login_cmd=None): + """Wrap the SSH connect method with a delay and retry mechanism. This is + useful when connecting to an instance that was freshly started. + """ + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + while True: + if delay > 0: + time.sleep(delay) + try: + client.load_system_host_keys() + logger.info("Connecting to %s via SSH...", hostname) + client.connect( + hostname, + username=username, + key_filename=key_path, + look_for_keys=False) + if login_cmd is not None: + if not run_and_check_command(client, login_cmd): + raise RuntimeError("failed to run login command") + return client + except Exception as e: + retry -= 1 + if retry < 0: + raise e + else: + logger.warning("Failed to connect to %s via SSH (%s), retrying...", + hostname, str(e)) + +def fuse_s3_bucket(client, corpus): + if not program_exists(client, "s3fs"): + raise EnvironmentError("s3fs is not installed") + if not run_and_check_command( + client, + "sed 's/# *user_allow_other/user_allow_other/' -i /etc/fuse.conf", + sudo=True): + raise RuntimeError("failed to configure s3fs") + status, _, stderr = run_command(client, "mkdir -p %s && chmod -R 775 %s" % ( + corpus["mount"], corpus["mount"])) + if status != 0: + return RuntimeError('failed to created mount directory: %s' % stderr.read()) + status, _, stderr = run_command(client, "echo %s:%s > s3_passwd && chmod 600 s3_passwd" % ( + corpus["credentials"]["AWS_ACCESS_KEY_ID"], + corpus["credentials"]["AWS_SECRET_ACCESS_KEY"])) + if status != 0: + raise RuntimeError('failed to store S3 credentials: %s' % stderr.read()) + status, _, stderr = run_command( + client, "s3fs %s %s -o allow_other -o passwd_file=s3_passwd" % ( + corpus["bucket"], corpus["mount"])) + if status != 0: + raise RuntimeError('failed to fuse S3 bucket: %s' % stderr.read()) + +def check_environment(client, gpu_id, log_dir, docker_registries): + """Check that the environment contains all the tools necessary to launch a task + """ + for registry in six.itervalues(docker_registries): + if registry['type'] == 'aws' and not program_exists(client, 'aws'): + raise EnvironmentError("missing aws client") + + # check log_dir + if not run_and_check_command(client, "test -d '%s'" % log_dir): + raise EnvironmentError("incorrect log directory: %s" % log_dir) + + if gpu_id == 0: + if not program_exists(client, "docker"): + raise EnvironmentError("docker not available") + return '' + else: + if not program_exists(client, "nvidia-docker"): + raise EnvironmentError("nvidia-docker not available") + + exit_status, stdout, stderr = run_command( + client, 'nvidia-smi -q -i %d -d UTILIZATION,MEMORY' % (gpu_id - 1)) + if exit_status != 0: + raise EnvironmentError("nvidia-smi exited with status %d: %s" % ( + exit_status, stderr.read())) + + out = stdout.read() + gpu = '?' + mem = '?' + m = re.search(b'Gpu *: (.*) *\n', out) + if m: + gpu = m.group(1).decode('utf-8') + m = re.search(b'Free *: (.*) *\n', out) + if m: + mem = m.group(1).decode('utf-8') + + return 'gpu usage: %s, free mem: %s' % (gpu, mem) + +def cmd_connect_private_registry(docker_registry): + if docker_registry['type'] == "aws": + return ('$(AWS_ACCESS_KEY_ID=%s AWS_SECRET_ACCESS_KEY=%s ' + 'aws ecr get-login --no-include-email --region %s)') % ( + docker_registry['credentials']['AWS_ACCESS_KEY_ID'], + docker_registry['credentials']['AWS_SECRET_ACCESS_KEY'], + docker_registry['region']) + username = docker_registry['credentials']['username'] + password = docker_registry['credentials']['password'] + return ('docker login --username %s --password %s') % (username, password) + +def cmd_docker_pull(image_ref, docker_path=None): + path = "" + if docker_path is not None: + path = docker_path + "/" + return '%sdocker pull %s' % (path, image_ref) + +def _protect_arg(arg): + return "'" + re.sub(r"(')", r"\\\1", arg.strip()) + "'" + +def cmd_docker_run(gpu_id, docker_options, task_id, + image_ref, callback_url, callback_interval, + storages, docker_command, log_dir=None): + if docker_options.get('dev') == 1: + return "sleep 35" + else: + docker_cmd = 'docker' if gpu_id == 0 else 'nvidia-docker' + + # launch the task + cmd = '%s run -i --rm' % docker_cmd + if 'mount' in docker_options: + for k in docker_options['mount']: + cmd += ' -v %s' % k + if 'envvar' in docker_options: + for k in docker_options['envvar']: + cmd += ' -e %s=%s' % (k, docker_options['envvar'][k]) + + # mount TMP_DIR used to store potential transfered files + cmd += ' -e TMP_DIR=/root/tmp/%s' % task_id + + cmd += ' %s' % image_ref + + if storages is not None and storages != {}: + v = json.dumps(storages) + v = v.replace("", task_id) + v = v.replace("", callback_url) + cmd += ' -s \'%s\'' % v + + cmd += ' -g %s' % gpu_id + cmd += ' -t %s' % task_id + if callback_url is not None and callback_url != '': + cmd += ' -b \'%s\'' % callback_url + if callback_interval is not None: + cmd += ' -bi %d' % callback_interval + + cmd += ' -i %s' % image_ref + + for arg in docker_command: + if arg.startswith('${TMP_DIR}'): + arg = '/root/tmp/%s%s' % (task_id, arg[10:]) + cmd += ' ' + _protect_arg(arg) + + if log_dir is not None: + cmd += ' > %s/\"%s.log\" 2>&1' % (log_dir, task_id) + + return cmd + +def launch_task(task_id, + client, + gpu_id, + log_dir, + docker_options, + the_docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_for_immediate_failure=2, + storages=None, + callback_url=None, + callback_interval=None): + """Launch a task: + * `task_id`: assigned id for the task and used for logging + * `client`: ssh client + * `gpu_id`: if > 0, the id (eventually ids) of the GPU to use on the host + * `log_dir`: host docker log + * `docker_options`: environment and mounting points + * `docker_image`: image name of the docker + * `docker_tag`: tag - by default latest + * `docker_command`: the actual command to launch + * `wait_for_immediate_failure`: time to wait and monitor launched process + * `storages`: dictionary of storage to use in the docker execution + * `callback_url`: server to callback for beat of activity + * `callback_interval`: time between 2 beats + """ + image_ref = "" + if docker_options.get('dev') != 1: + docker_registry = docker_options['registries'][the_docker_registry] + + registry_uri = docker_registry['uri'] + + # connect to a registry + if docker_registry['type'] != 'dockerhub': + exit_status, stdout, stderr = run_command( + client, + cmd_connect_private_registry(docker_registry) + ) + if exit_status != 0: + raise RuntimeError("cannot connect to private registry: %s" % stderr.read()) + + # pull the docker image + registry_urip = '' if registry_uri == '' else registry_uri + '/' + image_ref = '%s%s:%s' % (registry_urip, docker_image, docker_tag) + logger.debug("pulling docker image: %s - %s", docker_registry['type'], docker_image) + docker_cmd = cmd_docker_pull(image_ref, docker_path=docker_options.get('path')) + exit_status, stdout, stderr = run_command(client, docker_cmd) + if exit_status != 0: + raise RuntimeError("error pulling the image %s: %s" % (image_ref, stderr.read())) + + if len(docker_files): + # we have files to synchronize locally + assert 'mount' in docker_options, "mount point should be defined for passing files" + assert callback_url is not None, "callback_url needed for passing files" + mount_tmpdir = None + for m in docker_options['mount']: + if m.endswith('/root/tmp'): + mount_tmpdir = m[:-10] + break + assert mount_tmpdir is not None, "mount points need to include /root/tmp for passing files" + cmd_mkdir = "mkdir -p %s/%s" % (mount_tmpdir, task_id) + exit_status, stdout, stderr = run_command(client, cmd_mkdir) + if exit_status != 0: + raise RuntimeError("error build task tmp dir: %s, %s" % (cmd_mkdir, stderr.read())) + for f in docker_files: + logger.info("retrieve file %s -> %s/%s", f, mount_tmpdir, task_id) + cmd_get_files = 'curl "%s/file/%s/%s" > %s/%s/%s' % ( + callback_url, + task_id, + f, + mount_tmpdir, + task_id, + f) + exit_status, stdout, stderr = run_command(client, cmd_get_files) + if exit_status != 0: + raise RuntimeError("error retrieving files: %s, %s" % (cmd_get_files, stderr.read())) + + cmd = 'nohup ' + cmd_docker_run(gpu_id, docker_options, task_id, + image_ref, callback_url, callback_interval, + storages, docker_command, log_dir) + log_file = "%s/%s.log" % (log_dir, task_id) + if callback_url is not None: + cmd = '(%s ; status=$?' % cmd + if log_dir is not None and log_dir != '': + cmd = '%s ; curl -X POST "%s/file/%s/log" --data-binary "@%s"' % ( + cmd, + callback_url, + task_id, + log_file) + cmd = ('%s ; if [[ $status = 0 ]]; then curl -X GET "%s/terminate/%s?phase=completed";' + + ' else curl -X GET "%s/terminate/%s?phase=error"; fi )') % ( + cmd, callback_url, task_id, callback_url, task_id) + + # get the process group id + cmd += ' & ps -o pgid -p $!' + + exit_status, stdout, stderr = run_command(client, cmd) + if exit_status != 0: + raise RuntimeError("%s run failed: %s" % (cmd, stderr.read())) + + # read ps header + outpgid = stdout.readline() + # read pgid + outpgid = stdout.readline() + m = re.search(r'(\d+)', outpgid) + if not m: + raise RuntimeError("cannot get PGID") + pgid = int(m.group(1)) + logger.info("Process launched with pgid %d.", pgid) + + # check what is happening 1s later - just to check immediate failure + if wait_for_immediate_failure > 0: + logger.info("Wait for %d seconds and check process status.", wait_for_immediate_failure) + time.sleep(wait_for_immediate_failure) + if not run_and_check_command(client, 'kill -0 -%d' % pgid): + _, stdout, stderr = run_command(client, 'cat %s' % log_file) + raise RuntimeError("process exited early: %s" % stdout.read()) + + return {"model": task_id, "pgid": pgid} diff --git a/server/nmtwizard/config.py b/server/nmtwizard/config.py new file mode 100644 index 00000000..9ca60cb8 --- /dev/null +++ b/server/nmtwizard/config.py @@ -0,0 +1,86 @@ +import os +import json +import logging +import importlib +import six + +logger = logging.getLogger(__name__) + +_BASE_CONFIG_NAME = "default.json" + + +def add_log_handler(fh): + logger.addHandler(fh) + +def merge_config(a, b, name): + assert type(a) == type(b), "default and %s config file are not compatible" % name + if isinstance(a, dict): + for k in six.iterkeys(b): + if k not in a: + a[k] = b[k] + elif isinstance(a[k], dict): + merge_config(a[k], b[k], name) + +def load_service(config_path, base_config=None): + """Loads a service configuration. + + Args: + config_path: Path the service configuration to load. + base_config: The shared configuration to include in this service. + + Returns: + name: The service name + service: The service manager. + """ + with open(config_path) as config_file: + config = json.load(config_file) + if base_config is not None: + merge_config(config, base_config, config_path) + name = config["name"] + if config.get("disabled") == 1: + return name, None + if "module" not in config or "docker" not in config or "description" not in config: + raise ValueError("invalid service definition in %s" % config_path) + service = importlib.import_module(config["module"]).init(config) + return name, service + +def load_services(directory): + """Loads configured services. + + Each service is configured by a JSON file and a optional shared configuration + named "default.json". + + Args: + directory: The directory to load services from. + + Returns: + A map of service name to service module. + """ + if not os.path.isdir(directory): + raise ValueError("invalid service directory %s" % os.path.abspath(directory)) + + base_config = {} + base_config_path = os.path.join(directory, _BASE_CONFIG_NAME) + if os.path.exists(base_config_path): + logger.info("Reading base configuration %s", base_config_path) + with open(base_config_path) as base_config_file: + base_config = json.load(base_config_file) + + logger.info("Loading services from %s", directory) + services = {} + for filename in os.listdir(directory): + config_path = os.path.join(directory, filename) + if (not os.path.isfile(config_path) + or not filename.endswith(".json") + or filename == _BASE_CONFIG_NAME): + continue + logger.info("Loading service configuration %s", config_path) + name, service = load_service(config_path, base_config=base_config) + if service is None: + logger.info("Skipping disabled service %s", name) + continue + if name in services: + raise RuntimeError("%s duplicates service %s definition" % (filename, name)) + services[name] = service + logger.info("Loaded service %s (total capacity: %s)", name, service.total_capacity) + return services diff --git a/server/nmtwizard/redis_database.py b/server/nmtwizard/redis_database.py new file mode 100644 index 00000000..f8859c62 --- /dev/null +++ b/server/nmtwizard/redis_database.py @@ -0,0 +1,65 @@ +import uuid +import time +import logging +import redis + +logger = logging.getLogger(__name__) + + +class RedisDatabase(redis.Redis): + """Extension to redis.Redis.""" + + def __init__(self, host, port, db, password): + """Creates a new database instance.""" + super(RedisDatabase, self).__init__( + host=host, + port=port, + db=db, + password=password, + decode_responses=True) + + def acquire_lock(self, name, acquire_timeout=10, expire_time=60): + return RedisLock(self, name, acquire_timeout=acquire_timeout, expire_time=expire_time) + + +class RedisLock(object): + + def __init__(self, redis, name, acquire_timeout=10, expire_time=60): + self._redis = redis + self._name = name + self._acquire_timeout = acquire_timeout + self._expire_time = expire_time + self._identifier = None + + def __enter__(self): + """Adds a lock for a specific name and expires the lock after some delay.""" + logger.debug('Acquire lock for %s', self._name) + self._identifier = str(uuid.uuid4()) + end = time.time() + self._acquire_timeout + lock = 'lock:%s' % self._name + while time.time() < end: + if self._redis.setnx(lock, self._identifier): + self._redis.expire(lock, self._expire_time) + return self + time.sleep(.01) + raise RuntimeWarning("failed to acquire lock on %s" % self._name) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Releases a lock given some identifier and makes sure it is the one we set + (could have been destroyed in the meantime). + """ + logger.debug('Release lock for %s', self._name) + pipe = self._redis.pipeline(True) + lock = 'lock:%s' % self._name + while True: + try: + pipe.watch(lock) + if pipe.get(lock) == self._identifier: + pipe.multi() + pipe.delete(lock) + pipe.execute() + pipe.unwatch() + break + except redis.exceptions.WatchError: + pass + return False diff --git a/server/nmtwizard/service.py b/server/nmtwizard/service.py new file mode 100644 index 00000000..2799a4b1 --- /dev/null +++ b/server/nmtwizard/service.py @@ -0,0 +1,105 @@ +"""Base class for services: objects that can start, monitor, and terminate +Docker-based tasks. +""" + +import logging +import abc +import six + + +@six.add_metaclass(abc.ABCMeta) +class Service(object): + """Base class for services.""" + + def __init__(self, config): + self._config = config + + @property + def name(self): + """Returns the name of the service.""" + return self._config['name'] + + @property + def display_name(self): + """Returns the detailed name of the service.""" + return self._config['description'] + + @property + def is_notifying_activity(self): + return self._config.get('callback_url') + + @property + def total_capacity(self): + """Total capacity of the service (i.e. the total number of tasks that + can run at the same time). + """ + return sum(six.itervalues(self.list_resources())) + + @abc.abstractmethod + def list_resources(self): + """Lists resources covered by the service. + + Returns: + A dictionary of resource name to their maximum capacity (-1 for unbounded). + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_resource_from_options(self, options): + """Returns the selected resource. + + Args: + options: The options provided by the user. + + Returns: + The name of the selected resource. + + See Also: + describe(). + """ + raise NotImplementedError() + + def describe(self): + """Describe the service options. + + Returns: + A (possibly empty) dictionary following the JSON form structure. + """ + return {} + + @abc.abstractmethod + def check(self, options): + """Checks if a task can be launched on the service. + + Args: + options: The user options to use for the launch. + + Returns: + A (possibly empty) string with details on the target service and + resource. + """ + raise NotImplementedError() + + @abc.abstractmethod + def launch(self, + task_id, + options, + resource, + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch): + """Launches a new task.""" + raise NotImplementedError() + + @abc.abstractmethod + def status(self, params): + """Returns the status of a task as string.""" + raise NotImplementedError() + + @abc.abstractmethod + def terminate(self, params): + """Terminates a (possibly) running task.""" + raise NotImplementedError() diff --git a/server/nmtwizard/task.py b/server/nmtwizard/task.py new file mode 100644 index 00000000..0cb3a73d --- /dev/null +++ b/server/nmtwizard/task.py @@ -0,0 +1,146 @@ +import time +import json + +def set_status(redis, keyt, status): + """Sets the status and save the time of change.""" + redis.hset(keyt, "status", status) + redis.hset(keyt, status + "_time", int(time.time())) + +def exists(redis, task_id): + """Checks if a task exist.""" + return redis.exists("task:" + task_id) + +def create(redis, task_id, resource, service, content, files): + """Creates a new task and enables it.""" + keyt = "task:" + task_id + redis.hset(keyt, "resource", resource) + redis.hset(keyt, "service", service) + redis.hset(keyt, "content", json.dumps(content)) + set_status(redis, keyt, "queued") + for k in files: + redis.hset("files:" + task_id, k, files[k]) + enable(redis, task_id) + queue(redis, task_id) + +def terminate(redis, task_id, phase): + """Requests task termination (assume it is locked).""" + if phase is None: + phase = "aborted" + keyt = "task:" + task_id + if redis.hget(keyt, "status") in ("terminating", "stopped"): + return + redis.hset(keyt, "message", phase) + set_status(redis, keyt, "terminating") + queue(redis, task_id) + +def queue(redis, task_id, delay=0): + """Queues the task in the work queue with a delay.""" + if delay == 0: + redis.lpush('work', task_id) + redis.delete('queue:'+task_id) + else: + redis.set('queue:'+task_id, delay) + redis.expire('queue:'+task_id, int(delay)) + +def unqueue(redis): + """Pop a task from the work queue.""" + return redis.rpop('work') + +def enable(redis, task_id): + """Marks a task as enabled.""" + redis.sadd("active", task_id) + +def disable(redis, task_id): + """Marks a task as disabled.""" + redis.srem("active", task_id) + +def list_active(redis): + """Returns all active tasks (i.e. non stopped).""" + return redis.smembers("active") + +def file_list(redis, task_id): + """Returns the list of files attached to a task""" + keyf = "files:" + task_id + return redis.hkeys(keyf) + +def info(redis, task_id, fields): + """Gets information on a task.""" + keyt = "task:" + task_id + field = None + if not isinstance(fields, list): + field = fields + fields = [field] + if not fields: + # only if we want all information - add a lock on the resource + with redis.acquire_lock(keyt): + fields = redis.hkeys(keyt) + fields.append("ttl") + r=info(redis, task_id, fields) + r['files'] = file_list(redis, task_id) + return r + r = {} + for f in fields: + if f != "ttl": + r[f] = redis.hget(keyt, f) + else: + r[f] = redis.ttl("beat:" + task_id) + if field: + return r[field] + r["current_time"] = int(time.time()) + return r + +def delete(redis, task_id): + """Delete a given task.""" + keyt = "task:" + task_id + status = redis.hget(keyt, "status") + if status is None: + return (False, "task does not exists") + if status != "stopped": + return (False, "status is not stopped") + with redis.acquire_lock(keyt): + redis.delete(keyt) + redis.delete("queue:" + task_id) + redis.delete("files:" + task_id) + return True + +# TODO: create iterator returning directly task_id +def scan_iter(redis, pattern): + return redis.scan_iter('task:' + pattern) + +def id(task_key): + return task_key[5:] + +def beat(redis, task_id, duration, container_id): + """Sends an update event to the task and add an expiration time + (set duration to 0 to disable expiration). The task must be running. + """ + keyt = "task:" + task_id + with redis.acquire_lock(keyt): + # a beat can only be sent in running mode except if in between, the task stopped + # or in development mode, no need to raise an alert + if redis.hget(keyt, "status") != "running": + return + if duration is not None: + if duration == 0: + redis.delete("beat:" + task_id) + else: + redis.set("beat:" + task_id, duration) + redis.expire("beat:" + task_id, duration) + queue = redis.get("queue:" + task_id) + # renew ttl of queue + if queue is not None: + redis.expire("queue:" + task_id, int(queue)) + redis.hset(keyt, "updated_time", int(time.time())) + if container_id is not None: + redis.hset(keyt, "container_id", container_id) + +def set_file(redis, task_id, content, filename): + keyf = "files:" + task_id + redis.hset(keyf, filename, content) + +def get_file(redis, task_id, filename): + keyf = "files:" + task_id + return redis.hget(keyf, filename) + +def get_log(redis, task_id): + return get_file(redis, task_id, "log") diff --git a/server/nmtwizard/worker.py b/server/nmtwizard/worker.py new file mode 100644 index 00000000..d7a26c79 --- /dev/null +++ b/server/nmtwizard/worker.py @@ -0,0 +1,166 @@ +import time +import json +import logging +import six + +from nmtwizard import task + + +class Worker(object): + + def __init__(self, redis, services, index=0): + self._redis = redis + self._services = services + self._logger = logging.getLogger('worker%d' % index) + + def run(self): + self._logger.info('Starting worker') + + # Subscribe to beat expiration. + pubsub = self._redis.pubsub() + pubsub.psubscribe('__keyspace@0__:beat:*') + pubsub.psubscribe('__keyspace@0__:queue:*') + + while True: + message = pubsub.get_message() + if message: + channel = message['channel'] + data = message['data'] + if data == 'expired': + if channel.startswith('__keyspace@0__:beat:'): + task_id = channel[20:] + self._logger.info('%s: task expired', task_id) + with self._redis.acquire_lock(task_id): + task.terminate(self._redis, task_id, phase='expired') + elif channel.startswith('__keyspace@0__:queue:'): + task_id = channel[21:] + task.queue(self._redis, task_id) + else: + task_id = task.unqueue(self._redis) + if task_id is not None: + try: + self._advance_task(task_id) + except RuntimeWarning: + self._logger.warning( + '%s: failed to acquire a lock, retrying', task_id) + task.queue(self._redis, task_id) + except Exception as e: + self._logger.error('%s: %s', task_id, str(e)) + with self._redis.acquire_lock(task_id): + task.terminate(self._redis, task_id, phase="launch_error") + time.sleep(0.1) + + def _advance_task(self, task_id): + """Tries to advance the task to the next status. If it can, re-queue it immediately + to process the next stage. Otherwise, re-queue it after some delay to try again. + """ + keyt = 'task:%s' % task_id + with self._redis.acquire_lock(keyt, acquire_timeout=1, expire_time=600): + status = self._redis.hget(keyt, 'status') + if status == 'stopped': + return + + service_name = self._redis.hget(keyt, 'service') + if service_name not in self._services: + raise ValueError('unknown service %s' % service_name) + service = self._services[service_name] + + self._logger.info('%s: trying to advance from status %s', task_id, status) + + if status == 'queued': + resource = self._redis.hget(keyt, 'resource') + resource = self._allocate_resource(task_id, resource, service) + if resource is not None: + self._logger.info('%s: resource %s reserved', task_id, resource) + self._redis.hset(keyt, 'resource', resource) + task.set_status(self._redis, keyt, 'allocated') + task.queue(self._redis, task_id) + else: + self._logger.warning('%s: no resources available, waiting', task_id) + self._wait_for_resource(service, task_id) + + elif status == 'allocated': + content = json.loads(self._redis.hget(keyt, 'content')) + resource = self._redis.hget(keyt, 'resource') + self._logger.info('%s: launching on %s', task_id, service.name) + data = service.launch( + task_id, + content['options'], + resource, + content['docker']['registry'], + content['docker']['image'], + content['docker']['tag'], + content['docker']['command'], + task.file_list(self._redis, task_id), + content['wait_after_launch']) + self._logger.info('%s: task started on %s', task_id, service.name) + self._redis.hset(keyt, 'job', json.dumps(data)) + task.set_status(self._redis, keyt, 'running') + # For services that do not notify their activity, we should + # poll the task status more regularly. + task.queue(self._redis, task_id, delay=service.is_notifying_activity and 120 or 30) + + elif status == 'running': + data = json.loads(self._redis.hget(keyt, 'job')) + status = service.status(data) + if status == 'dead': + self._logger.info('%s: task no longer running on %s, request termination', + task_id, service.name) + task.terminate(self._redis, task_id, phase='exited') + else: + task.queue(self._redis, task_id, delay=service.is_notifying_activity and 120 or 30) + + elif status == 'terminating': + data = self._redis.hget(keyt, 'job') + if data is not None: + data = json.loads(data) + self._logger.info('%s: terminating task', task_id) + try: + service.terminate(data) + self._logger.info('%s: terminated', task_id) + except Exception: + self._logger.warning('%s: failed to terminate', task_id) + resource = self._redis.hget(keyt, 'resource') + self._release_resource(service, resource, task_id) + task.set_status(self._redis, keyt, 'stopped') + task.disable(self._redis, task_id) + + def _allocate_resource(self, task_id, resource, service): + """Allocates a resource for task_id and returns the name of the resource + (or None if none where allocated). + """ + resources = service.list_resources() + if resource == 'auto': + for name, capacity in six.iteritems(resources): + if self._reserve_resource(service, name, capacity, task_id): + return name + elif resource not in resources: + raise ValueError('resource %s does not exist for service %s' % (resource, service.name)) + elif self._reserve_resource(service, resource, resources[resource], task_id): + return resource + return None + + def _reserve_resource(self, service, resource, capacity, task_id): + """Reserves the resource for task_id, if possible. The resource is locked + while we try to reserve it. + """ + keyr = 'resource:%s:%s' % (service.name, resource) + with self._redis.acquire_lock(keyr): + current_usage = self._redis.llen(keyr) + if current_usage < capacity: + self._redis.rpush(keyr, task_id) + return True + else: + return False + + def _release_resource(self, service, resource, task_id): + keyr = 'resource:%s:%s' % (service.name, resource) + with self._redis.acquire_lock(keyr): + self._redis.lrem(keyr, task_id) + # Pop a task waiting for a resource on this service and queue it for a retry. + next_task = self._redis.rpop('queued:%s' % service.name) + if next_task is not None: + task.queue(self._redis, next_task) + + def _wait_for_resource(self, service, task_id): + self._redis.lpush('queued:%s' % service.name, task_id) diff --git a/server/services/__init__.py b/server/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/services/ec2.py b/server/services/ec2.py new file mode 100644 index 00000000..9054e568 --- /dev/null +++ b/server/services/ec2.py @@ -0,0 +1,149 @@ +import os +import logging +import boto3 +import paramiko + +from botocore.exceptions import ClientError +from nmtwizard import common +from nmtwizard.service import Service + +logger = logging.getLogger(__name__) + + +def _run_instance(client, launch_template_name, dry_run=False): + return client.run_instances( + MaxCount=1, + MinCount=1, + DryRun=dry_run, + LaunchTemplate={ + "LaunchTemplateName": launch_template_name}) + + +class EC2Service(Service): + + def __init__(self, config): + super(EC2Service, self).__init__(config) + self._session = boto3.Session( + aws_access_key_id=config["awsAccessKeyId"], + aws_secret_access_key=config["awsSecretAccessKey"], + region_name=config["awsRegion"]) + self._launch_template_names = self._get_launch_template_names() + + def _get_launch_template_names(self): + ec2_client = self._session.client("ec2") + response = ec2_client.describe_launch_templates() + if not response or not response["LaunchTemplates"]: + raise ValueError('no EC2 launch templates are available to use') + return [template["LaunchTemplateName"] for template in response["LaunchTemplates"]] + + def list_resources(self): + return { + name:self._config['maxInstancePerTemplate'] + for name in self._launch_template_names} + + def get_resource_from_options(self, options): + return options["launchTemplateName"] + + def describe(self): + return { + "launchTemplateName": { + "title": "EC2 Launch Template", + "type": "string", + "description": "The name of the EC2 launch template to use", + "enum": self._launch_template_names}} + + def check(self, options): + if "launchTemplateName" not in options: + raise ValueError("missing launchTemplateName option") + try: + ec2_client = self._session.client("ec2") + _ = _run_instance(ec2_client, options["launchTemplateName"], dry_run=True) + except ClientError as e: + if e.response["Error"]["Code"] == "DryRunOperation": + pass + elif e.response["Error"]["Code"] == "UnauthorizedOperation": + raise RuntimeError("not authorized to run instances") + else: + raise e + return "" + + def launch(self, + task_id, + options, + resource, + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch): + ec2_client = self._session.client("ec2") + response = _run_instance(ec2_client, resource) + if response is None: + raise RuntimeError("empty response from boto3.run_instances") + if not response["Instances"]: + raise RuntimeError("no instances were created") + instance_id = response["Instances"][0]["InstanceId"] + ec2 = self._session.resource("ec2") + instance = ec2.Instance(instance_id) + instance.wait_until_running() + logger.info("Instance %s is running.", instance.id) + + key_path = os.path.join( + self._config["privateKeysDirectory"], "%s.pem" % instance.key_pair.name) + client = paramiko.SSHClient() + try: + common.ssh_connect_with_retry( + client, + instance.public_dns_name, + self._config["amiUsername"], + key_path, + delay=self._config["sshConnectionDelay"], + retry=self._config["maxSshConnectionRetry"]) + common.fuse_s3_bucket(client, self._config["corpus"]) + gpu_id = 1 if common.has_gpu_support(client) else 0 + task = common.launch_task( + task_id, + client, + gpu_id, + self._config["logDir"], + self._config["docker"], + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch, + self._config.get('storages'), + self._config.get('callback_url'), + self._config.get('callback_interval')) + except Exception as e: + if self._config.get("terminateOnError", True): + instance.terminate() + client.close() + raise e + finally: + client.close() + task["instance_id"] = instance.id + return task + + def status(self, params): + instance_id = params["instance_id"] if isinstance(params, dict) else params + ec2_client = self._session.client("ec2") + status = ec2_client.describe_instance_status( + InstanceIds=[instance_id], IncludeAllInstances=True) + + # TODO: actually check if the task is running, not only the instance. + + return status["InstanceStatuses"][0]["InstanceState"]["Name"] + + def terminate(self, params): + instance_id = params["instance_id"] if isinstance(params, dict) else params + ec2 = self._session.resource("ec2") + instance = ec2.Instance(instance_id) + instance.terminate() + logger.info("Instance %s is terminated.", instance.id) + + +def init(config): + return EC2Service(config) diff --git a/server/services/ssh.py b/server/services/ssh.py new file mode 100644 index 00000000..c9bf7392 --- /dev/null +++ b/server/services/ssh.py @@ -0,0 +1,181 @@ +import logging +import paramiko + +from nmtwizard import common +from nmtwizard.service import Service + +logger = logging.getLogger(__name__) + +def _get_params(config, options): + params = {} + if 'server' not in options: + server_pool = config['variables']['server_pool'] + if len(server_pool) > 1 or len(server_pool[0]['gpus']) > 1: + raise ValueError('server option is required to select a server and a resource') + service = config['variables']['server_pool'][0]['host'] + resource = str(config['variables']['server_pool'][0]['gpus'][0]) + options['server'] = service + ':' + resource + fields = options['server'].split(':') + if len(fields) != 2: + raise ValueError( + "invalid server option '%s', should be 'server:gpu'" % options['server']) + + params['server'] = fields[0] + params['gpu'] = int(fields[1]) + + servers = {server['host']:server for server in config['variables']['server_pool']} + + if params['server'] not in servers: + raise ValueError('server %s not in server_pool list' % params['server']) + + server_cfg = servers[params['server']] + + if params['gpu'] not in server_cfg['gpus']: + raise ValueError("GPU %d not in server gpus list" % params['gpu']) + if 'login' not in server_cfg and 'login' not in options: + raise ValueError('login not found in server configuration or user options') + if 'log_dir' not in server_cfg: + raise ValueError('missing log_dir in the configuration') + + params['login'] = server_cfg.get('login', options.get('login')) + params['log_dir'] = server_cfg['log_dir'] + params['login_cmd'] = server_cfg.get('login_cmd') + + return params + + +class SSHService(Service): + + def __init__(self, config): + super(SSHService, self).__init__(config) + self._resources = self._list_all_gpus() + + def _list_all_gpus(self): + gpus = [] + for server in self._config['variables']['server_pool']: + for gpu in server['gpus']: + gpus.append('%s:%d' % (server['host'], gpu)) + return gpus + + def list_resources(self): + return {gpu:1 for gpu in self._resources} + + def get_resource_from_options(self, options): + if len(self._resources) == 1: + return self._resources[0] + else: + return options["server"] + + def describe(self): + has_login = False + for server in self._config['variables']['server_pool']: + if 'login' in server: + has_login = True + break + desc = {} + if len(self._resources) > 1: + desc['server'] = { + "title": "server", + "type": "string", + "description": "server:gpu", + "enum": self._resources + ["auto"], + "default": "auto" + } + if not has_login: + desc['login'] = { + "type": "string", + "title": "login", + "description": "login to use to access the server" + } + return desc + + def check(self, options): + params = _get_params(self._config, options) + client = paramiko.client.SSHClient() + common.ssh_connect_with_retry( + client, + params['server'], + params['login'], + self._config['privateKey'], + login_cmd=params['login_cmd']) + try: + details = common.check_environment( + client, + params['gpu'], + params['log_dir'], + self._config['docker']['registries']) + finally: + client.close() + return details + + def launch(self, + task_id, + options, + resource, + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch): + if len(self._resources) > 1: + options['server'] = resource + + params = _get_params(self._config, options) + client = paramiko.client.SSHClient() + common.ssh_connect_with_retry( + client, + params['server'], + params['login'], + self._config['privateKey'], + login_cmd=params['login_cmd']) + try: + task = common.launch_task( + task_id, + client, + params['gpu'], + params['log_dir'], + self._config['docker'], + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch, + self._config.get('storages'), + self._config.get('callback_url'), + self._config.get('callback_interval')) + finally: + client.close() + params['model'] = task['model'] + params['pgid'] = task['pgid'] + return params + + def status(self, params): + client = paramiko.client.SSHClient() + client.load_system_host_keys() + client.connect(params['server'], username=params['login']) + _, stdout, _ = client.exec_command('kill -0 -%d ; echo $?' % params['pgid']) + outstatus = stdout.readline() + client.close() + if outstatus.strip() != '0': + return "dead" + return "running" + + def terminate(self, params): + client = paramiko.client.SSHClient() + client.load_system_host_keys() + client.connect(params['server'], username=params['login']) + _, stdout, stderr = client.exec_command('kill -0 -%d ; echo $?' % params['pgid']) + outstatus = stdout.readline() + if outstatus.strip() != '0': + client.close() + return + _, stdout, stderr = client.exec_command('kill -9 -%d' % params['pgid']) + outstatus = stdout.readline() + stderr = stderr.readline() + client.close() + + +def init(config): + return SSHService(config) diff --git a/server/services/torque.py b/server/services/torque.py new file mode 100644 index 00000000..caf5105c --- /dev/null +++ b/server/services/torque.py @@ -0,0 +1,220 @@ +import logging +import os +import re +import paramiko + +from nmtwizard import common +from nmtwizard.service import Service + +logger = logging.getLogger(__name__) + +def _get_params(config, options): + server_cfg = config['variables'] + + if 'master_node' not in server_cfg: + raise ValueError('missing master_node in configuration') + if 'log_dir' not in server_cfg: + raise ValueError('missing log_dir in configuration') + if 'torque_install_path' not in server_cfg: + raise ValueError('missing torque_install_path in configuration') + if 'login' not in server_cfg and 'login' not in options: + raise ValueError('missing login in one of configuration and user options') + if 'mem' not in options: + raise ValueError('missing mem in user options') + if 'priority' not in options: + raise ValueError('missing priority in user options') + + params = {} + params['master_node'] = server_cfg['master_node'] + params['login'] = server_cfg.get('login', options.get('login')) + params['mem'] = options['mem'] + params['priority'] = options['priority'] + params['log_dir'] = server_cfg['log_dir'] + params['torque_install_path'] = server_cfg['torque_install_path'] + return params + +class TorqueService(Service): + + def list_resources(self): + return {'torque': self._config['maxInstance']} + + def get_resource_from_options(self, options): + return "torque" + + def describe(self): + desc = {} + if 'login' not in self._config['variables']: + desc['login'] = { + "type": "string", + "title": "login", + "description": "login to use to access the server" + } + desc['mem'] = { + "type": "integer", + "default": 4, + "title": "required memory (Gb)", + "minimum": 1 + + } + desc['priority'] = { + "type": "integer", + "default": 0, + "title": "Priority of the job", + "minimum": -1024, + "maximum": 1023 + } + return desc + + def check(self, options): + params = _get_params(self._config, options) + client = paramiko.client.SSHClient() + + common.ssh_connect_with_retry( + client, + params['master_node'], + params['login'], + self._config['privateKey']) + + # check log_dir + if not common.run_and_check_command(client, "test -d '%s'" % params['log_dir']): + client.close() + raise ValueError("incorrect log directory: %s" % params['log_dir']) + + status, stdout, _ = common.run_command( + client, os.path.join(params['torque_install_path'], "qstat")) + + client.close() + if status != 0: + raise RuntimeError('qstat exited with code %s' % status) + return "%s jobs(s) in the queue" % (len(stdout.read().split('\n')) - 2) + + def launch(self, + task_id, + options, + resource, + docker_registry, + docker_image, + docker_tag, + docker_command, + docker_files, + wait_after_launch): + params = _get_params(self._config, options) + client = paramiko.client.SSHClient() + + common.ssh_connect_with_retry( + client, + params['master_node'], + params['login'], + self._config['privateKey']) + + cmd = "cat <<-'EOF'\n" + cmd += "#!/bin/bash\n" + cmd += "#PBS -l nodes=1:ppn=2:gpus=1,mem=%sG,walltime=10000:00:00\n" % params['mem'] + cmd += "#PBS -p %d\n" % params['priority'] + cmd += "#PBS -N infTraining\n" + cmd += "#PBS -o %s/%s.log -j oe\n" % (params['log_dir'], task_id) + + cmd += "guessdevice(){\n" + cmd += " if [ -e \"${PBS_GPUFILE}\" ]\n" + cmd += " then\n" + cmd += " GPUS=`cat ${PBS_GPUFILE} | perl -pe 's/[^-]+-gpu//g' |" + cmd += " perl -pe 's/\s+/ /g' | perl -pe 's/,$//g'`\n" + cmd += " GPUS=`echo \"${GPUS}+1\" | bc `\n" + cmd += " echo $GPUS;\n" + cmd += " else\n" + cmd += " echo \"error: No available GPU\"\n" + cmd += " fi\n" + cmd += "}\n" + + cmd += "DEVICE=$(guessdevice)\n" + cmd += "echo \"RUN ON GPU ${DEVICE}\"\n" + registry = self._config['docker']['registries'][docker_registry] + registry_uri = registry['uri'] + registry_urip = '' if registry_uri == '' else registry_uri + '/' + image_ref = '%s%s:%s' % (registry_urip, docker_image, docker_tag) + + if registry['type'] != 'dockerhub': + cmd_connect = common.cmd_connect_private_registry(registry) + cmd += "echo '=> " + cmd_connect + "'\n" + cmd += cmd_connect + '\n' + + cmd_docker_pull = common.cmd_docker_pull(image_ref) + cmd += "echo '=> " + cmd_docker_pull + "'\n" + cmd += cmd_docker_pull + '\n' + docker_cmd = "echo | " + common.cmd_docker_run( + "$DEVICE", + self._config['docker'], + task_id, + image_ref, + self._config['storage'], + self._config['callback_url'], + self._config['callback_interval'], + docker_command) + + cmd += "echo \"=> " + docker_cmd.replace("\"", "\"") + "\"\n" + cmd += docker_cmd + '\n' + + if self._config['callback_url']: + callback_cmd = '' + if params['log_dir'] is not None and params['log_dir'] != '': + callback_cmd = 'curl -X POST "%s/log/%s" --data-binary "@%s/%s.log" ; ' % ( + self._config['callback_url'], + task_id, + params['log_dir'], + task_id) + + callback_cmd += 'curl "%s/terminate/%s?phase=completed"' % ( + self._config['callback_url'], task_id) + cmd += "echo \"=> " + callback_cmd.replace("\"", "\\\"") + "\"\n" + cmd += callback_cmd + '\n' + + cmd += "EOF\n" + + qsub_cmd = "echo \"$(%s)\" | %s" % ( + cmd, os.path.join(params['torque_install_path'], "qsub -V")) + + exit_status, stdout, stderr = common.run_command(client, qsub_cmd) + if exit_status != 0: + client.close() + raise RuntimeError('run exited with code %d: %s' % (exit_status, stderr.read())) + + client.close() + params['model'] = task_id + params['qsub_id'] = stdout.read().strip() + return params + + def status(self, params): + logger.info("Check status of process with qsub id %s.", params['qsub_id']) + + client = paramiko.client.SSHClient() + client.load_system_host_keys() + client.connect(params['master_node'], username=params['login']) + _, stdout, _ = client.exec_command( + '%s -f %s' % (os.path.join(params['torque_install_path'], "qstat"), params['qsub_id'])) + outstatus = stdout.read() + client.close() + + m = re.search(r'job_state = (.)\n', outstatus) + + if m is None or m.group(1) == "C": + return "dead" + + status = m.group(1) + m = re.search(r'exec_gpus = (.*?)\n', outstatus) + host = '?' + + if m is not None: + host = m.group(1) + + return "%s (%s)" % (status, host) + + def terminate(self, params): + client = paramiko.client.SSHClient() + client.load_system_host_keys() + client.connect(params['master_node'], username=params['login']) + client.exec_command( + '%s %s' % (os.path.join(params['torque_install_path'], "qdel"), params['qsub_id'])) + client.close() + +def init(config): + return TorqueService(config) diff --git a/server/settings.ini b/server/settings.ini new file mode 100644 index 00000000..eb25e9bc --- /dev/null +++ b/server/settings.ini @@ -0,0 +1,13 @@ +[DEFAULT] +# config_dir with service configuration +config_dir = ./config +# logging level +log_level = INFO +# refresh rate +refresh = 60 + +[Production] +redis_host = localhost +redis_port = 6379 +redis_db = 0 +#redis_password=xxx diff --git a/server/worker.py b/server/worker.py new file mode 100644 index 00000000..e107ee7c --- /dev/null +++ b/server/worker.py @@ -0,0 +1,36 @@ +import logging +import os +import sys + +from six.moves import configparser + +from nmtwizard import config, task +from nmtwizard.redis_database import RedisDatabase +from nmtwizard.worker import Worker + +cfg = configparser.ConfigParser() +cfg.read('settings.ini') +MODE = os.getenv('LAUNCHER_MODE', 'Production') + +logging.basicConfig(stream=sys.stdout, level=cfg.get(MODE, 'log_level')) +logger = logging.getLogger() + +redis_password = None +if cfg.has_option(MODE, 'redis_password'): + redis_password = cfg.get(MODE, 'redis_password') + +redis = RedisDatabase(cfg.get(MODE, 'redis_host'), + cfg.getint(MODE, 'redis_port'), + cfg.get(MODE, 'redis_db'), + redis_password) + + +services = config.load_services(cfg.get(MODE, 'config_dir')) + +# On startup, add all active tasks in the work queue. +for task_id in task.list_active(redis): + task.queue(redis, task_id) + +# TODO: start multiple workers here? +worker = Worker(redis, services) +worker.run()