diff --git a/cmd/suggestion/nasrl/Dockerfile b/cmd/suggestion/nasrl/Dockerfile index b5fee2f7d8f..87b2f26a2a1 100644 --- a/cmd/suggestion/nasrl/Dockerfile +++ b/cmd/suggestion/nasrl/Dockerfile @@ -1,7 +1,8 @@ -FROM python:3 +FROM python:3.6 ADD . /usr/src/app/github.com/kubeflow/katib WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nasrl +RUN pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl RUN pip install --no-cache-dir -r requirements.txt ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/api/python diff --git a/cmd/suggestion/nasrl/requirements.txt b/cmd/suggestion/nasrl/requirements.txt index 8d2c9d4bda7..155e383b5f7 100644 --- a/cmd/suggestion/nasrl/requirements.txt +++ b/cmd/suggestion/nasrl/requirements.txt @@ -1,9 +1,3 @@ grpcio -duecredit -cloudpickle==0.5.6 -numpy>=1.13.3 -scikit-learn>=0.19.0 -scipy>=0.19.1 -forestci protobuf googleapis-common-protos diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py new file mode 100755 index 00000000000..a4add0f7a5e --- /dev/null +++ b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py @@ -0,0 +1,220 @@ +import tensorflow as tf +from pkg.suggestion.NAS_Reinforcement_Learning.LSTM import stack_lstm +from pkg.suggestion.NAS_Reinforcement_Learning.Trainer import get_train_ops + + +class Controller(object): + def __init__(self, + num_layers=12, + num_operations=16, + lstm_size=64, + lstm_num_layers=1, + lstm_keep_prob=1.0, + tanh_constant=1.5, + temperature=None, + lr_init=1e-3, + lr_dec_start=0, + lr_dec_every=1000, + lr_dec_rate=0.9, + l2_reg=0, + entropy_weight=1e-4, + clip_mode=None, + grad_bound=None, + bl_dec=0.999, + optim_algo="adam", + sync_replicas=False, + num_aggregate=20, + num_replicas=1, + skip_target=0.4, + skip_weight=0.8, + name="controller"): + + print("-" * 80) + print("Building Controller") + + self.num_layers = num_layers + self.num_operations = num_operations + + self.lstm_size = lstm_size + self.lstm_num_layers = lstm_num_layers + self.lstm_keep_prob = lstm_keep_prob + self.tanh_constant = tanh_constant + self.temperature = temperature + self.lr_init = lr_init + self.lr_dec_start = lr_dec_start + self.lr_dec_every = lr_dec_every + self.lr_dec_rate = lr_dec_rate + self.l2_reg = l2_reg + self.entropy_weight = entropy_weight + self.clip_mode = clip_mode + self.grad_bound = grad_bound + self.bl_dec = bl_dec + + self.skip_target = skip_target + self.skip_weight = skip_weight + + self.optim_algo = optim_algo + self.sync_replicas = sync_replicas + self.num_aggregate = num_aggregate + self.num_replicas = num_replicas + self.name = name + + self._create_params() + self._build_sampler() + + def _create_params(self): + initializer = tf.random_uniform_initializer(minval=-0.1, maxval=0.1) + with tf.variable_scope(self.name, initializer=initializer): + with tf.variable_scope("lstm"): + self.w_lstm = [] + for layer_id in range(self.lstm_num_layers): + with tf.variable_scope("layer_{}".format(layer_id)): + w = tf.get_variable("w", [2 * self.lstm_size, 4 * self.lstm_size]) + self.w_lstm.append(w) + + self.g_emb = tf.get_variable("g_emb", [1, self.lstm_size]) + with tf.variable_scope("emb"): + self.w_emb = tf.get_variable("w", [self.num_operations, self.lstm_size]) + with tf.variable_scope("softmax"): + self.w_soft = tf.get_variable("w", [self.lstm_size, self.num_operations]) + + with tf.variable_scope("attention"): + self.w_attn_1 = 
tf.get_variable("w_1", [self.lstm_size, self.lstm_size]) + self.w_attn_2 = tf.get_variable("w_2", [self.lstm_size, self.lstm_size]) + self.v_attn = tf.get_variable("v", [self.lstm_size, 1]) + + def _build_sampler(self): + """Build the sampler ops and the log_prob ops.""" + + print("-" * 80) + print("Building Controller Sampler") + anchors = [] + anchors_w_1 = [] + + arc_seq = [] + entropys = [] + log_probs = [] + skip_count = [] + skip_penaltys = [] + + prev_c = [tf.zeros([1, self.lstm_size], tf.float32) for _ in range(self.lstm_num_layers)] + prev_h = [tf.zeros([1, self.lstm_size], tf.float32) for _ in range(self.lstm_num_layers)] + inputs = self.g_emb + skip_targets = tf.constant([1.0 - self.skip_target, self.skip_target], dtype=tf.float32) + for layer_id in range(self.num_layers): + next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm) + prev_c, prev_h = next_c, next_h + logit = tf.matmul(next_h[-1], self.w_soft) + if self.temperature is not None: + logit /= self.temperature + if self.tanh_constant is not None: + logit = self.tanh_constant * tf.tanh(logit) + + operation_id = tf.multinomial(logit, 1) + operation_id = tf.to_int32(operation_id) + operation_id = tf.reshape(operation_id, [1]) + + arc_seq.append(operation_id) + log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logit, labels=operation_id) + log_probs.append(log_prob) + entropy = tf.stop_gradient(log_prob * tf.exp(-log_prob)) + entropys.append(entropy) + inputs = tf.nn.embedding_lookup(self.w_emb, operation_id) + + next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm) + prev_c, prev_h = next_c, next_h + + if layer_id > 0: + query = tf.concat(anchors_w_1, axis=0) + query = tf.tanh(query + tf.matmul(next_h[-1], self.w_attn_2)) + query = tf.matmul(query, self.v_attn) + logit = tf.concat([-query, query], axis=1) + if self.temperature is not None: + logit /= self.temperature + if self.tanh_constant is not None: + logit = self.tanh_constant * tf.tanh(logit) + + skip = tf.multinomial(logit, 1) + skip = tf.to_int32(skip) + skip = tf.reshape(skip, [layer_id]) + arc_seq.append(skip) + + skip_prob = tf.sigmoid(logit) + kl = skip_prob * tf.log(skip_prob / skip_targets) + kl = tf.reduce_sum(kl) + skip_penaltys.append(kl) + + log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logit, labels=skip) + log_probs.append(tf.reduce_sum(log_prob, keepdims=True)) + + entropy = tf.stop_gradient( + tf.reduce_sum(log_prob * tf.exp(-log_prob), keepdims=True)) + entropys.append(entropy) + + skip = tf.to_float(skip) + skip = tf.reshape(skip, [1, layer_id]) + skip_count.append(tf.reduce_sum(skip)) + inputs = tf.matmul(skip, tf.concat(anchors, axis=0)) + inputs /= (1.0 + tf.reduce_sum(skip)) + else: + inputs = self.g_emb + + anchors.append(next_h[-1]) + anchors_w_1.append(tf.matmul(next_h[-1], self.w_attn_1)) + + arc_seq = tf.concat(arc_seq, axis=0) + self.sample_arc = tf.reshape(arc_seq, [-1]) + + entropys = tf.stack(entropys) + self.sample_entropy = tf.reduce_sum(entropys) + + log_probs = tf.stack(log_probs) + self.sample_log_prob = tf.reduce_sum(log_probs) + + skip_count = tf.stack(skip_count) + self.skip_count = tf.reduce_sum(skip_count) + + skip_penaltys = tf.stack(skip_penaltys) + self.skip_penaltys = tf.reduce_mean(skip_penaltys) + + def build_trainer(self): + self.reward = tf.placeholder(tf.float32, shape=()) + + normalize = tf.to_float(self.num_layers * (self.num_layers - 1) / 2) + self.skip_rate = tf.to_float(self.skip_count) / normalize + + if self.entropy_weight is not None: + self.reward += 
self.entropy_weight * self.sample_entropy + + self.sample_log_prob = tf.reduce_sum(self.sample_log_prob) + self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False) + baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) * (self.baseline - self.reward)) + + with tf.control_dependencies([baseline_update]): + self.reward = tf.identity(self.reward) + + self.loss = self.sample_log_prob * (self.reward - self.baseline) + if self.skip_weight is not None: + self.loss += self.skip_weight * self.skip_penaltys + + self.train_step = tf.Variable(0, dtype=tf.int32, trainable=False, name=self.name + "_train_step") + tf_variables = [var for var in tf.trainable_variables() if var.name.startswith(self.name)] + print("-" * 80) + + self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops( + self.loss, + tf_variables, + self.train_step, + clip_mode=self.clip_mode, + grad_bound=self.grad_bound, + l2_reg=self.l2_reg, + lr_init=self.lr_init, + lr_dec_start=self.lr_dec_start, + lr_dec_every=self.lr_dec_every, + lr_dec_rate=self.lr_dec_rate, + optim_algo=self.optim_algo, + sync_replicas=self.sync_replicas, + num_aggregate=self.num_aggregate, + num_replicas=self.num_replicas) diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py b/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py new file mode 100755 index 00000000000..b3ad6008386 --- /dev/null +++ b/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf + + +# TODO: will remove this function and use tf.nn.LSTMCell instead + +def lstm(x, prev_c, prev_h, w): + ifog = tf.matmul(tf.concat([x, prev_h], axis=1), w) + i, f, o, g = tf.split(ifog, 4, axis=1) + i = tf.sigmoid(i) + f = tf.sigmoid(f) + o = tf.sigmoid(o) + g = tf.tanh(g) + next_c = i * g + f * prev_c + next_h = o * tf.tanh(next_c) + return next_c, next_h + + +def stack_lstm(x, prev_c, prev_h, w): + next_c, next_h = [], [] + for layer_id, (_c, _h, _w) in enumerate(zip(prev_c, prev_h, w)): + inputs = x if layer_id == 0 else next_h[-1] + curr_c, curr_h = lstm(inputs, _c, _h, _w) + next_c.append(curr_c) + next_h.append(curr_h) + return next_c, next_h diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py b/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py new file mode 100644 index 00000000000..0d6fc5395cb --- /dev/null +++ b/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py @@ -0,0 +1,79 @@ +import itertools +import numpy as np +from pkg.api.python import api_pb2 + + +class Operation(object): + def __init__(self, opt_id, opt_type, opt_params): + self.opt_id = opt_id + self.opt_type = opt_type + self.opt_params = opt_params + + def get_dict(self): + opt_dict = dict() + opt_dict['opt_id'] = self.opt_id + opt_dict['opt_type'] = self.opt_type + opt_dict['opt_params'] = self.opt_params + return opt_dict + + def print_op(self, logger): + logger.info("Operation ID: \n\t{}".format(self.opt_id)) + logger.info("Operation Type: \n\t{}".format(self.opt_type)) + logger.info("Operations Parameters:") + for ikey in self.opt_params: + logger.info("\t{}: {}".format(ikey, self.opt_params[ikey])) + logger.info("") + + +class SearchSpace(object): + def __init__(self, operations): + self.operation_list = list(operations.operation) + self.search_space = list() + self._parse_operations() + print() + self.num_operations = len(self.search_space) + + def _parse_operations(self): + # search_sapce is a list of Operation class + + 
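+        # Each OperationType in the StudyJob config is expanded into one Operation
+        # per combination of its parameter values (the Cartesian product taken with
+        # itertools.product below), and each combination gets a unique operation_id.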
        operation_id = 0
+
+        for operation_dict in self.operation_list:
+            opt_type = operation_dict.operationType
+            opt_spec = list(operation_dict.parameter_configs.configs)
+            # avail_space is a dict with the format {"spec_name": [feasible values of the spec]}
+            avail_space = dict()
+            num_spec = len(opt_spec)
+
+            for ispec in opt_spec:
+                spec_name = ispec.name
+                if ispec.parameter_type == api_pb2.CATEGORICAL:
+                    avail_space[spec_name] = list(ispec.feasible.list)
+                elif ispec.parameter_type == api_pb2.INT:
+                    spec_min = int(ispec.feasible.min)
+                    spec_max = int(ispec.feasible.max)
+                    spec_step = int(ispec.feasible.step)
+                    avail_space[spec_name] = range(spec_min, spec_max+1, spec_step)
+                elif ispec.parameter_type == api_pb2.DOUBLE:
+                    spec_min = float(ispec.feasible.min)
+                    spec_max = float(ispec.feasible.max)
+                    spec_step = float(ispec.feasible.step)
+                    if spec_step == 0:
+                        print("Error, NAS Reinforcement Learning algorithm cannot accept continuous search space!")
+                        exit(999)
+                    double_list = np.arange(spec_min, spec_max+spec_step, spec_step)
+                    # np.arange can overshoot spec_max; drop the extra value
+                    # (slice instead of `del`, which does not work on ndarrays)
+                    if double_list[-1] > spec_max:
+                        double_list = double_list[:-1]
+                    avail_space[spec_name] = double_list
+
+            # generate all the combinations of possible operations
+            key_avail_space = list(avail_space.keys())
+            val_avail_space = list(avail_space.values())
+
+            for this_opt_vector in itertools.product(*val_avail_space):
+                opt_params = dict()
+                for i in range(num_spec):
+                    opt_params[key_avail_space[i]] = this_opt_vector[i]
+                this_opt_class = Operation(operation_id, opt_type, opt_params)
+                self.search_space.append(this_opt_class)
+                operation_id += 1
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/README.md b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
new file mode 100644
index 00000000000..8cf44e40d2b
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
@@ -0,0 +1,128 @@
+# About the Neural Architecture Search with Reinforcement Learning Suggestion
+
+The algorithm follows the idea proposed in *Neural Architecture Search with Reinforcement Learning* by Zoph & Le (https://arxiv.org/abs/1611.01578), and the implementation is based on the GitHub repository of *Efficient Neural Architecture Search via Parameter Sharing* (https://github.com/melodyguan/enas). It uses a recurrent neural network with LSTM cells as the controller to generate neural architecture candidates, and this controller network is updated with policy gradients. However, it currently does not support parameter sharing.
+
+## Definition of a Neural Architecture
+
+Let n be the number of layers and m the number of possible operations.
+
+If n = 12 and m = 6, the definition of an architecture looks like:
+
+```
+[2]
+[0 0]
+[1 1 0]
+[5 1 0 1]
+[1 1 1 0 1]
+[5 0 0 1 0 1]
+[1 1 1 0 0 1 0]
+[2 0 0 0 1 1 0 1]
+[0 0 0 1 1 1 1 1 0]
+[2 0 1 0 1 1 1 0 0 0]
+[3 1 1 1 1 1 1 0 0 1 1]
+[0 1 1 1 1 0 0 1 1 1 1 0]
+```
+
+There are n rows; the ith row has i elements and describes the ith layer. Please note that layer 0 is the input and is not included in this definition.
+
+In each row:
+The first integer ranges from 0 to m-1 and indicates the operation in this layer.
+The next (i-1) integers are each 0 or 1. The kth (k>=2) integer indicates whether the (k-2)th layer has a skip connection to this layer. (There is always a connection from the (k-1)th layer to the kth layer.)
+
+## Output of `GetSuggestions()`
+The output of `GetSuggestions()` consists of two parts: `architecture` and `nn_config`.
+
+`architecture` is a JSON string defining a neural architecture in the row format described above.
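+
+Internally, the controller emits this architecture as a single flat integer sequence (the rows above concatenated in order), which the suggestion service reshapes before returning it. A minimal sketch of that reshaping, for illustration only (plain Python; the function name is hypothetical and not part of the service API):
+
+```python
+def to_rows(flat_arc, num_layers):
+    # Row i (0-based) holds one operation id followed by i skip-connection bits.
+    rows, pos = [], 0
+    for i in range(num_layers):
+        rows.append(flat_arc[pos: pos + i + 1])
+        pos += i + 1
+    return rows
+
+
+# e.g. to_rows([2, 0, 0, 1, 1, 0], 3) -> [[2], [0, 0], [1, 1, 0]]
+```
+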
+One example of such an `architecture` value is:
+```
+[[27], [29, 0], [22, 1, 0], [13, 0, 0, 0], [26, 1, 1, 0, 0], [30, 1, 0, 1, 0, 0], [11, 0, 1, 1, 0, 1, 1], [9, 1, 0, 0, 1, 0, 0, 0]]
+```
+
+`nn_config` is a JSON string describing the number of layers, the input and output sizes, and what each operation index stands for. An `nn_config` corresponding to the architecture above can be:
+```
+{
+    "num_layers": 8,
+    "input_size": [32, 32, 3],
+    "output_size": [10],
+    "embedding": {
+        "27": {
+            "opt_id": 27,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "7",
+                "num_filter": "96",
+                "stride": "2"
+            }
+        },
+        "29": {
+            "opt_id": 29,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "7",
+                "num_filter": "128",
+                "stride": "2"
+            }
+        },
+        "22": {
+            "opt_id": 22,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "7",
+                "num_filter": "48",
+                "stride": "1"
+            }
+        },
+        "13": {
+            "opt_id": 13,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "5",
+                "num_filter": "48",
+                "stride": "2"
+            }
+        },
+        "26": {
+            "opt_id": 26,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "7",
+                "num_filter": "96",
+                "stride": "1"
+            }
+        },
+        "30": {
+            "opt_id": 30,
+            "opt_type": "reduction",
+            "opt_params": {
+                "reduction_type": "max_pooling",
+                "pool_size": 2
+            }
+        },
+        "11": {
+            "opt_id": 11,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "5",
+                "num_filter": "32",
+                "stride": "2"
+            }
+        },
+        "9": {
+            "opt_id": 9,
+            "opt_type": "convolution",
+            "opt_params": {
+                "filter_size": "3",
+                "num_filter": "128",
+                "stride": "2"
+            }
+        }
+    }
+}
+```
+This neural architecture can be visualized as
+![a neural network architecture example](example.png)
+
+## To Do
+1. Add support for multiple studyjobs
+2. Add support for multiple trials
+3. Change the LSTM cell from the self-defined functions in LSTM.py to `tf.nn.rnn_cell.LSTMCell`
+4. 
Store the suggestion checkpoint to PVC in case of nasrl service pod restarts \ No newline at end of file diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py new file mode 100644 index 00000000000..ae9f1d19777 --- /dev/null +++ b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py @@ -0,0 +1,73 @@ +def parseSuggestionParam(params_raw): + param_standard = { + "lstm_num_cells": ['value', int, [1, 'inf']], + "lstm_num_layers": ['value', int, [1, 'inf']], + "lstm_keep_prob": ['value', float, [0.0, 1.0]], + "optimizer": ['categorical', str, ["adam", "momentum", "sgd"]], + "init_learning_rate": ['value', float, [1e-6, 1.0]], + "lr_decay_start": ['value', int, [0, 'inf']], + "lr_decay_every": ['value', int, [1, 'inf']], + "lr_decay_rate": ['value', float, [0.0, 1.0]], + "skip-target": ['value', float, [0.0, 1.0]], + "skip-weight": ['value', float, [0.0, 'inf']], + "l2_reg": ['value', float, [0.0, 'inf']], + "entropy_weight": ['value', float, [0.0, 'inf']], + "baseline_decay": ['value', float, [0.0, 1.0]], + } + + suggestion_params = { + "lstm_num_cells": 64, + "lstm_num_layers": 1, + "lstm_keep_prob": 1.0, + "optimizer": "adam", + "init_learning_rate": 1e-3, + "lr_decay_start": 0, + "lr_decay_every": 1000, + "lr_decay_rate": 0.9, + "skip-target": 0.4, + "skip-weight": 0.8, + "l2_reg": 0, + "entropy_weight": 1e-4, + "baseline_decay": 0.9999 + } + + def checktype(param_name, param_value, check_mode, supposed_type, supposed_range=None): + correct = True + + try: + converted_value = supposed_type(param_value) + except: + correct = False + print("Parameter {} is of wrong type. Set back to default value {}" + .format(param_name, suggestion_params[param_name])) + + if correct and check_mode == 'value': + if not ((supposed_range[0] == '-inf' or converted_value >= supposed_range[0]) and + (supposed_range[1] == 'inf' or converted_value <= supposed_range[1])): + correct = False + print("Parameter {} out of range. Set back to default value {}" + .format(param_name, suggestion_params[param_name])) + elif correct and check_mode == 'categorical': + if converted_value not in supposed_range: + correct = False + print("Parameter {} out of range. 
Set back to default value {}" + .format(param_name, suggestion_params[param_name])) + + if correct: + suggestion_params[param_name] = converted_value + + + for param in params_raw: + # SuggestionCount is automatically added by controller and not used currently + if param.name == "SuggestionCount": + continue + if param.name in suggestion_params.keys(): + checktype(param.name, + param.value, + param_standard[param.name][0], # mode + param_standard[param.name][1], # type + param_standard[param.name][2]) # range + else: + print("Unknown Parameter name: {}".format(param.name)) + + return suggestion_params diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py b/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py new file mode 100644 index 00000000000..a8dde5115e5 --- /dev/null +++ b/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py @@ -0,0 +1,143 @@ +import tensorflow as tf + + +def get_train_ops(loss, + tf_variables, + train_step, + clip_mode=None, + grad_bound=None, + l2_reg=1e-4, + lr_warmup_val=None, + lr_warmup_steps=100, + lr_init=0.1, + lr_dec_start=0, + lr_dec_every=10000, + lr_dec_rate=0.1, + lr_dec_min=None, + lr_cosine=False, + lr_max=None, + lr_min=None, + lr_T_0=None, + lr_T_mul=None, + num_train_batches=None, + optim_algo=None, + sync_replicas=False, + num_aggregate=None, + num_replicas=None, + get_grad_norms=False, + moving_average=None): + """ + Args: + clip_mode: "global", "norm", or None. + moving_average: store the moving average of parameters + """ + + if l2_reg > 0: + l2_losses = [] + for var in tf_variables: + l2_losses.append(tf.reduce_sum(var ** 2)) + l2_loss = tf.add_n(l2_losses) + loss += l2_reg * l2_loss + + grads = tf.gradients(loss, tf_variables) + grad_norm = tf.global_norm(grads) + + grad_norms = {} + for v, g in zip(tf_variables, grads): + if v is None or g is None: + continue + if isinstance(g, tf.IndexedSlices): + grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g.values ** 2)) + else: + grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g ** 2)) + + if clip_mode is not None: + assert grad_bound is not None, "Need grad_bound to clip gradients." 
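+    # "global" rescales the whole gradient list so that its global norm is at
+    # most grad_bound; "norm" clips each gradient tensor to grad_bound individually.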
+    if clip_mode == "global":
+      grads, _ = tf.clip_by_global_norm(grads, grad_bound)
+    elif clip_mode == "norm":
+      clipped = []
+      for g in grads:
+        if isinstance(g, tf.IndexedSlices):
+          c_g = tf.clip_by_norm(g.values, grad_bound)
+          c_g = tf.IndexedSlices(c_g, g.indices)
+        else:
+          c_g = tf.clip_by_norm(g, grad_bound)
+        # append the clipped gradient, not the original one
+        clipped.append(c_g)
+      grads = clipped
+    else:
+      raise NotImplementedError("Unknown clip_mode {}".format(clip_mode))
+
+  if lr_cosine:
+    assert lr_max is not None, "Need lr_max to use lr_cosine"
+    assert lr_min is not None, "Need lr_min to use lr_cosine"
+    assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
+    assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
+    assert num_train_batches is not None, ("Need num_train_batches to use"
+                                           " lr_cosine")
+
+    curr_epoch = train_step // num_train_batches
+
+    last_reset = tf.Variable(0, dtype=tf.int32, trainable=False,
+                             name="last_reset")
+    T_i = tf.Variable(lr_T_0, dtype=tf.int32, trainable=False, name="T_i")
+    T_curr = curr_epoch - last_reset
+
+    def _update():
+      update_last_reset = tf.assign(last_reset, curr_epoch, use_locking=True)
+      update_T_i = tf.assign(T_i, T_i * lr_T_mul, use_locking=True)
+      with tf.control_dependencies([update_last_reset, update_T_i]):
+        rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+        lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+      return lr
+
+    def _no_update():
+      rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+      lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+      return lr
+
+    learning_rate = tf.cond(
+      tf.greater_equal(T_curr, T_i), _update, _no_update)
+  else:
+    learning_rate = tf.train.exponential_decay(
+      lr_init, tf.maximum(train_step - lr_dec_start, 0), lr_dec_every,
+      lr_dec_rate, staircase=True)
+    if lr_dec_min is not None:
+      learning_rate = tf.maximum(learning_rate, lr_dec_min)
+
+  if lr_warmup_val is not None:
+    learning_rate = tf.cond(tf.less(train_step, lr_warmup_steps),
+                            lambda: lr_warmup_val, lambda: learning_rate)
+
+  if optim_algo == "momentum":
+    opt = tf.train.MomentumOptimizer(
+      learning_rate, 0.9, use_locking=True, use_nesterov=True)
+  elif optim_algo == "sgd":
+    opt = tf.train.GradientDescentOptimizer(learning_rate, use_locking=True)
+  elif optim_algo == "adam":
+    opt = tf.train.AdamOptimizer(learning_rate, beta1=0.0, epsilon=1e-3,
+                                 use_locking=True)
+  else:
+    raise ValueError("Unknown optim_algo {}".format(optim_algo))
+
+  if sync_replicas:
+    assert num_aggregate is not None, "Need num_aggregate to sync."
+    assert num_replicas is not None, "Need num_replicas to sync."
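+    # SyncReplicasOptimizer accumulates gradients from the replica workers and
+    # applies a single update once num_aggregate of them have been collected.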
+ + opt = tf.train.SyncReplicasOptimizer( + opt, + replicas_to_aggregate=num_aggregate, + total_num_replicas=num_replicas, + use_locking=True) + + if moving_average is not None: + opt = tf.contrib.opt.MovingAverageOptimizer( + opt, average_decay=moving_average) + + train_op = opt.apply_gradients( + zip(grads, tf_variables), global_step=train_step) + + if get_grad_norms: + return train_op, learning_rate, grad_norm, opt, grad_norms + else: + return train_op, learning_rate, grad_norm, opt diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/example.png b/pkg/suggestion/NAS_Reinforcement_Learning/example.png new file mode 100644 index 00000000000..90ac2eb6190 Binary files /dev/null and b/pkg/suggestion/NAS_Reinforcement_Learning/example.png differ diff --git a/pkg/suggestion/nasrl_service.py b/pkg/suggestion/nasrl_service.py index f7e672a2276..0373bd122aa 100644 --- a/pkg/suggestion/nasrl_service.py +++ b/pkg/suggestion/nasrl_service.py @@ -1,22 +1,259 @@ -import random -import string - +from pkg.suggestion.NAS_Reinforcement_Learning.Controller import Controller +from pkg.suggestion.NAS_Reinforcement_Learning.Operation import SearchSpace +from pkg.suggestion.NAS_Reinforcement_Learning.SuggestionParam import parseSuggestionParam +import tensorflow as tf import grpc -import numpy as np - from pkg.api.python import api_pb2 from pkg.api.python import api_pb2_grpc import logging from logging import getLogger, StreamHandler, INFO, DEBUG +import json +import os +import time class NasrlService(api_pb2_grpc.SuggestionServicer): def __init__(self, logger=None): self.manager_addr = "vizier-core" self.manager_port = 6789 - self.current_trial_id = "" + self.registered_studies = list() + + self.ctrl_cache_file = "" + self.ctrl_step = 0 + self.is_first_run = True + + if not os.path.exists("ctrl_cache/"): + os.makedirs("ctrl_cache/") + + if logger == None: + self.logger = getLogger(__name__) + FORMAT = '%(asctime)-15s StudyID %(studyid)s %(message)s' + logging.basicConfig(format=FORMAT) + handler = StreamHandler() + handler.setLevel(INFO) + self.logger.setLevel(INFO) + self.logger.addHandler(handler) + self.logger.propagate = False + else: + self.logger = logger + + def setup_controller(self, request): + self.logger.info("-" * 80 + "\nSetting Up Suggestion for StudyJob {}\n".format(request.study_id) + "-" * 80) + self.tf_graph = tf.Graph() + self.ctrl_step = 0 + self.ctrl_cache_file = "ctrl_cache/{}.ckpt".format(request.study_id) + self._get_suggestion_param(request.param_id, request.study_id) + self._get_search_space(request.study_id) + + with self.tf_graph.as_default(): + ctrl_param = self.suggestion_config + self.controllers = Controller( + num_layers=self.num_layers, + num_operations=self.num_operations, + lstm_size=ctrl_param['lstm_num_cells'], + lstm_num_layers=ctrl_param['lstm_num_layers'], + lstm_keep_prob=ctrl_param['lstm_keep_prob'], + lr_init=ctrl_param['init_learning_rate'], + lr_dec_start=ctrl_param['lr_decay_start'], + lr_dec_every=ctrl_param['lr_decay_every'], + lr_dec_rate=ctrl_param['lr_decay_rate'], + l2_reg=ctrl_param['l2_reg'], + entropy_weight=ctrl_param['entropy_weight'], + bl_dec=ctrl_param['baseline_decay'], + optim_algo=ctrl_param['optimizer'], + skip_target=ctrl_param['skip-target'], + skip_weight=ctrl_param['skip-weight'], + name="Ctrl_"+request.study_id) + + self.controllers.build_trainer() + + self.logger.info("Suggestion for StudyJob {} has been initialized.".format(request.study_id)) def GetSuggestions(self, request, context): + if request.study_id not in 
self.registered_studies:
+            self.setup_controller(request)
+            self.is_first_run = True
+            self.registered_studies.append(request.study_id)
+
+        self.logger.info("-" * 80 + "\nSuggestion Step {} for Study {}\n".format(self.ctrl_step, request.study_id) + "-" * 80)
+
+        with self.tf_graph.as_default():
+
+            saver = tf.train.Saver()
+            ctrl = self.controllers
+
+            controller_ops = {
+                "train_step": ctrl.train_step,
+                "loss": ctrl.loss,
+                "train_op": ctrl.train_op,
+                "lr": ctrl.lr,
+                "grad_norm": ctrl.grad_norm,
+                "optimizer": ctrl.optimizer,
+                "baseline": ctrl.baseline,
+                "entropy": ctrl.sample_entropy,
+                "sample_arc": ctrl.sample_arc,
+                "skip_rate": ctrl.skip_rate}
+
+            run_ops = [
+                controller_ops["loss"],
+                controller_ops["entropy"],
+                controller_ops["lr"],
+                controller_ops["grad_norm"],
+                controller_ops["baseline"],
+                controller_ops["skip_rate"],
+                controller_ops["train_op"]]
+
+            if self.is_first_run:
+                self.logger.info("First time running suggestion for {}. Random architecture will be given.".format(request.study_id))
+                with tf.Session() as sess:
+                    sess.run(tf.global_variables_initializer())
+                    arc = sess.run(controller_ops["sample_arc"])
+                    # TODO: will use PVC to store the checkpoint to protect against unexpected suggestion pod restart
+                    saver.save(sess, self.ctrl_cache_file)
+
+                self.is_first_run = False
+
+            else:
+                with tf.Session() as sess:
+                    saver.restore(sess, self.ctrl_cache_file)
+
+                    valid_acc = ctrl.reward
+                    result = self.GetEvaluationResult(request.study_id)
+
+                    # The controller is trained to maximize the metric.
+                    # If the user wants to minimize it, we negate the result instead.
+                    if self.opt_direction == api_pb2.MINIMIZE:
+                        result = -result
+
+                    loss, entropy, lr, gn, bl, skip, _ = sess.run(
+                        fetches=run_ops,
+                        feed_dict={valid_acc: result})
+                    self.logger.info("Suggestion updated. 
LSTM Controller Loss: {}".format(loss)) + arc = sess.run(controller_ops["sample_arc"]) + + saver.save(sess, self.ctrl_cache_file) + + arc = arc.tolist() + organized_arc = [0 for _ in range(self.num_layers)] + record = 0 + for l in range(self.num_layers): + organized_arc[l] = arc[record: record + l + 1] + record += l + 1 + + nn_config = dict() + nn_config['num_layers'] = self.num_layers + nn_config['input_size'] = self.input_size + nn_config['output_size'] = self.output_size + nn_config['embedding'] = dict() + for l in range(self.num_layers): + opt = organized_arc[l][0] + nn_config['embedding'][opt] = self.search_space[opt].get_dict() + + organized_arc_json = json.dumps(organized_arc) + nn_config_json = json.dumps(nn_config) + + organized_arc_str = str(organized_arc_json).replace('\"', '\'') + nn_config_str = str(nn_config_json).replace('\"', '\'') + + self.logger.info("\nNew Neural Network Architecture (internal representation):") + self.logger.info(organized_arc_json) + self.logger.info("\nCorresponding Seach Space Description:") + self.logger.info(nn_config_str) + self.logger.info("") + trials = [] + trials.append(api_pb2.Trial( + study_id=request.study_id, + parameter_set=[ + api_pb2.Parameter( + name="architecture", + value=organized_arc_str, + parameter_type= api_pb2.CATEGORICAL), + api_pb2.Parameter( + name="nn_config", + value=nn_config_str, + parameter_type= api_pb2.CATEGORICAL) + ], + ) + ) + channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) + with api_pb2.beta_create_Manager_stub(channel) as client: + for i, t in enumerate(trials): + ctrep = client.CreateTrial(api_pb2.CreateTrialRequest(trial=t), 10) + trials[i].trial_id = ctrep.trial_id + self.logger.info("Trial {} Created\n".format(ctrep.trial_id)) + self.prev_trial_id = ctrep.trial_id + + self.ctrl_step += 1 return api_pb2.GetSuggestionsReply(trials=trials) + + def GetEvaluationResult(self, studyID): + worker_list = [] + channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) + with api_pb2.beta_create_Manager_stub(channel) as client: + gwfrep = client.GetWorkerFullInfo(api_pb2.GetWorkerFullInfoRequest(study_id=studyID, trial_id=self.prev_trial_id, only_latest_log=True), 10) + worker_list = gwfrep.worker_full_infos + + for w in worker_list: + if w.Worker.status == api_pb2.COMPLETED: + for ml in w.metrics_logs: + if ml.name == self.objective_name: + self.logger.info("Evaluation result of previous candidate: {}".format(ml.values[-1].value)) + return float(ml.values[-1].value) + + # TODO: add support for multiple trials + + + def _get_search_space(self, studyID): + + # this function need to + # 1) get the number of layers + # 2) get the I/O size + # 3) get the available operations + # 4) get the optimization direction (i.e. 
minimize or maximize) + # 5) get the objective name + + channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) + with api_pb2.beta_create_Manager_stub(channel) as client: + gsrep = client.GetStudy(api_pb2.GetStudyRequest(study_id=studyID), 10) + + self.opt_direction = gsrep.study_config.optimization_type + self.objective_name = gsrep.study_config.objective_value_name + + all_params = gsrep.study_config.nas_config + graph_config = all_params.graph_config + search_space_raw = all_params.operations + + self.num_layers = int(graph_config.num_layers) + self.input_size = list(map(int, graph_config.input_size)) + self.output_size = list(map(int, graph_config.output_size)) + search_space_object = SearchSpace(search_space_raw) + + self.logger.info("Search Space for Study {}:".format(studyID)) + + self.search_space = search_space_object.search_space + for opt in self.search_space: + opt.print_op(self.logger) + + self.num_operations = search_space_object.num_operations + self.logger.info("There are {} operations in total.\n".format(self.num_operations)) + + + def _get_suggestion_param(self, paramID, studyID): + channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port) + with api_pb2.beta_create_Manager_stub(channel) as client: + gsprep = client.GetSuggestionParameters(api_pb2.GetSuggestionParametersRequest(param_id=paramID), 10) + + params_raw = gsprep.suggestion_parameters + + suggestion_params = parseSuggestionParam(params_raw) + + self.logger.info("Parameters of LSTM Controller for Study {}:".format(studyID)) + for spec in suggestion_params: + if len(spec) > 13: + self.logger.info("{}: \t{}".format(spec, suggestion_params[spec])) + else: + self.logger.info("{}: \t\t{}".format(spec, suggestion_params[spec])) + + self.suggestion_config = suggestion_params