diff --git a/cmd/suggestion/nasrl/Dockerfile b/cmd/suggestion/nasrl/Dockerfile
index b5fee2f7d8f..87b2f26a2a1 100644
--- a/cmd/suggestion/nasrl/Dockerfile
+++ b/cmd/suggestion/nasrl/Dockerfile
@@ -1,7 +1,8 @@
-FROM python:3
+FROM python:3.6
ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nasrl
+RUN pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl
RUN pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/api/python
diff --git a/cmd/suggestion/nasrl/requirements.txt b/cmd/suggestion/nasrl/requirements.txt
index 8d2c9d4bda7..155e383b5f7 100644
--- a/cmd/suggestion/nasrl/requirements.txt
+++ b/cmd/suggestion/nasrl/requirements.txt
@@ -1,9 +1,3 @@
grpcio
-duecredit
-cloudpickle==0.5.6
-numpy>=1.13.3
-scikit-learn>=0.19.0
-scipy>=0.19.1
-forestci
protobuf
googleapis-common-protos
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py
new file mode 100755
index 00000000000..a4add0f7a5e
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py
@@ -0,0 +1,220 @@
+import tensorflow as tf
+from pkg.suggestion.NAS_Reinforcement_Learning.LSTM import stack_lstm
+from pkg.suggestion.NAS_Reinforcement_Learning.Trainer import get_train_ops
+
+
+class Controller(object):
+ def __init__(self,
+ num_layers=12,
+ num_operations=16,
+ lstm_size=64,
+ lstm_num_layers=1,
+ lstm_keep_prob=1.0,
+ tanh_constant=1.5,
+ temperature=None,
+ lr_init=1e-3,
+ lr_dec_start=0,
+ lr_dec_every=1000,
+ lr_dec_rate=0.9,
+ l2_reg=0,
+ entropy_weight=1e-4,
+ clip_mode=None,
+ grad_bound=None,
+ bl_dec=0.999,
+ optim_algo="adam",
+ sync_replicas=False,
+ num_aggregate=20,
+ num_replicas=1,
+ skip_target=0.4,
+ skip_weight=0.8,
+ name="controller"):
+
+ print("-" * 80)
+ print("Building Controller")
+
+ self.num_layers = num_layers
+ self.num_operations = num_operations
+
+ self.lstm_size = lstm_size
+ self.lstm_num_layers = lstm_num_layers
+ self.lstm_keep_prob = lstm_keep_prob
+ self.tanh_constant = tanh_constant
+ self.temperature = temperature
+ self.lr_init = lr_init
+ self.lr_dec_start = lr_dec_start
+ self.lr_dec_every = lr_dec_every
+ self.lr_dec_rate = lr_dec_rate
+ self.l2_reg = l2_reg
+ self.entropy_weight = entropy_weight
+ self.clip_mode = clip_mode
+ self.grad_bound = grad_bound
+ self.bl_dec = bl_dec
+
+ self.skip_target = skip_target
+ self.skip_weight = skip_weight
+
+ self.optim_algo = optim_algo
+ self.sync_replicas = sync_replicas
+ self.num_aggregate = num_aggregate
+ self.num_replicas = num_replicas
+ self.name = name
+
+ self._create_params()
+ self._build_sampler()
+
+ def _create_params(self):
+ initializer = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
+ with tf.variable_scope(self.name, initializer=initializer):
+ with tf.variable_scope("lstm"):
+ self.w_lstm = []
+ for layer_id in range(self.lstm_num_layers):
+ with tf.variable_scope("layer_{}".format(layer_id)):
+ w = tf.get_variable("w", [2 * self.lstm_size, 4 * self.lstm_size])
+ self.w_lstm.append(w)
+
+ self.g_emb = tf.get_variable("g_emb", [1, self.lstm_size])
+ with tf.variable_scope("emb"):
+ self.w_emb = tf.get_variable("w", [self.num_operations, self.lstm_size])
+ with tf.variable_scope("softmax"):
+ self.w_soft = tf.get_variable("w", [self.lstm_size, self.num_operations])
+
+ with tf.variable_scope("attention"):
+ self.w_attn_1 = tf.get_variable("w_1", [self.lstm_size, self.lstm_size])
+ self.w_attn_2 = tf.get_variable("w_2", [self.lstm_size, self.lstm_size])
+ self.v_attn = tf.get_variable("v", [self.lstm_size, 1])
+
+ def _build_sampler(self):
+ """Build the sampler ops and the log_prob ops."""
+
+ print("-" * 80)
+ print("Building Controller Sampler")
+ anchors = []
+ anchors_w_1 = []
+
+ arc_seq = []
+ entropys = []
+ log_probs = []
+ skip_count = []
+ skip_penaltys = []
+
+ prev_c = [tf.zeros([1, self.lstm_size], tf.float32) for _ in range(self.lstm_num_layers)]
+ prev_h = [tf.zeros([1, self.lstm_size], tf.float32) for _ in range(self.lstm_num_layers)]
+ inputs = self.g_emb
+ skip_targets = tf.constant([1.0 - self.skip_target, self.skip_target], dtype=tf.float32)
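+        # For each layer: (1) sample an operation id from a softmax over the LSTM
+        # output, then (2) for layers after the first, sample skip connections to
+        # every previous layer; log-probs and entropies are collected for REINFORCE.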
+ for layer_id in range(self.num_layers):
+ next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
+ prev_c, prev_h = next_c, next_h
+ logit = tf.matmul(next_h[-1], self.w_soft)
+ if self.temperature is not None:
+ logit /= self.temperature
+ if self.tanh_constant is not None:
+ logit = self.tanh_constant * tf.tanh(logit)
+
+ operation_id = tf.multinomial(logit, 1)
+ operation_id = tf.to_int32(operation_id)
+ operation_id = tf.reshape(operation_id, [1])
+
+ arc_seq.append(operation_id)
+ log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logit, labels=operation_id)
+ log_probs.append(log_prob)
+ entropy = tf.stop_gradient(log_prob * tf.exp(-log_prob))
+ entropys.append(entropy)
+ inputs = tf.nn.embedding_lookup(self.w_emb, operation_id)
+
+ next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
+ prev_c, prev_h = next_c, next_h
+
+ if layer_id > 0:
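+                # Attention over the previous layers' hidden states (anchors) decides
+                # which earlier layers get a skip connection; a KL term towards
+                # `skip_targets` is collected as a penalty on the skip probability.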
+ query = tf.concat(anchors_w_1, axis=0)
+ query = tf.tanh(query + tf.matmul(next_h[-1], self.w_attn_2))
+ query = tf.matmul(query, self.v_attn)
+ logit = tf.concat([-query, query], axis=1)
+ if self.temperature is not None:
+ logit /= self.temperature
+ if self.tanh_constant is not None:
+ logit = self.tanh_constant * tf.tanh(logit)
+
+ skip = tf.multinomial(logit, 1)
+ skip = tf.to_int32(skip)
+ skip = tf.reshape(skip, [layer_id])
+ arc_seq.append(skip)
+
+ skip_prob = tf.sigmoid(logit)
+ kl = skip_prob * tf.log(skip_prob / skip_targets)
+ kl = tf.reduce_sum(kl)
+ skip_penaltys.append(kl)
+
+ log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logit, labels=skip)
+ log_probs.append(tf.reduce_sum(log_prob, keepdims=True))
+
+ entropy = tf.stop_gradient(
+ tf.reduce_sum(log_prob * tf.exp(-log_prob), keepdims=True))
+ entropys.append(entropy)
+
+ skip = tf.to_float(skip)
+ skip = tf.reshape(skip, [1, layer_id])
+ skip_count.append(tf.reduce_sum(skip))
+ inputs = tf.matmul(skip, tf.concat(anchors, axis=0))
+ inputs /= (1.0 + tf.reduce_sum(skip))
+ else:
+ inputs = self.g_emb
+
+ anchors.append(next_h[-1])
+ anchors_w_1.append(tf.matmul(next_h[-1], self.w_attn_1))
+
+ arc_seq = tf.concat(arc_seq, axis=0)
+ self.sample_arc = tf.reshape(arc_seq, [-1])
+
+ entropys = tf.stack(entropys)
+ self.sample_entropy = tf.reduce_sum(entropys)
+
+ log_probs = tf.stack(log_probs)
+ self.sample_log_prob = tf.reduce_sum(log_probs)
+
+ skip_count = tf.stack(skip_count)
+ self.skip_count = tf.reduce_sum(skip_count)
+
+ skip_penaltys = tf.stack(skip_penaltys)
+ self.skip_penaltys = tf.reduce_mean(skip_penaltys)
+
+ def build_trainer(self):
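+        # REINFORCE update: the reward (the reported validation metric, optionally
+        # plus an entropy bonus) is fed in at run time, and an exponential-moving-average
+        # baseline is subtracted to reduce the variance of the policy gradient.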
+ self.reward = tf.placeholder(tf.float32, shape=())
+
+ normalize = tf.to_float(self.num_layers * (self.num_layers - 1) / 2)
+ self.skip_rate = tf.to_float(self.skip_count) / normalize
+
+ if self.entropy_weight is not None:
+ self.reward += self.entropy_weight * self.sample_entropy
+
+ self.sample_log_prob = tf.reduce_sum(self.sample_log_prob)
+ self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
+ baseline_update = tf.assign_sub(self.baseline, (1 - self.bl_dec) * (self.baseline - self.reward))
+
+ with tf.control_dependencies([baseline_update]):
+ self.reward = tf.identity(self.reward)
+
+ self.loss = self.sample_log_prob * (self.reward - self.baseline)
+ if self.skip_weight is not None:
+ self.loss += self.skip_weight * self.skip_penaltys
+
+ self.train_step = tf.Variable(0, dtype=tf.int32, trainable=False, name=self.name + "_train_step")
+ tf_variables = [var for var in tf.trainable_variables() if var.name.startswith(self.name)]
+ print("-" * 80)
+
+ self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
+ self.loss,
+ tf_variables,
+ self.train_step,
+ clip_mode=self.clip_mode,
+ grad_bound=self.grad_bound,
+ l2_reg=self.l2_reg,
+ lr_init=self.lr_init,
+ lr_dec_start=self.lr_dec_start,
+ lr_dec_every=self.lr_dec_every,
+ lr_dec_rate=self.lr_dec_rate,
+ optim_algo=self.optim_algo,
+ sync_replicas=self.sync_replicas,
+ num_aggregate=self.num_aggregate,
+ num_replicas=self.num_replicas)
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py b/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py
new file mode 100755
index 00000000000..b3ad6008386
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/LSTM.py
@@ -0,0 +1,28 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+
+
+# TODO: remove this function and use tf.nn.rnn_cell.LSTMCell instead
+
+def lstm(x, prev_c, prev_h, w):
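+    # Standard LSTM cell computed from one fused weight matrix w:
+    # input (i), forget (f), output (o) gates and candidate state (g).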
+ ifog = tf.matmul(tf.concat([x, prev_h], axis=1), w)
+ i, f, o, g = tf.split(ifog, 4, axis=1)
+ i = tf.sigmoid(i)
+ f = tf.sigmoid(f)
+ o = tf.sigmoid(o)
+ g = tf.tanh(g)
+ next_c = i * g + f * prev_c
+ next_h = o * tf.tanh(next_c)
+ return next_c, next_h
+
+
+def stack_lstm(x, prev_c, prev_h, w):
+ next_c, next_h = [], []
+ for layer_id, (_c, _h, _w) in enumerate(zip(prev_c, prev_h, w)):
+ inputs = x if layer_id == 0 else next_h[-1]
+ curr_c, curr_h = lstm(inputs, _c, _h, _w)
+ next_c.append(curr_c)
+ next_h.append(curr_h)
+ return next_c, next_h
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py b/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py
new file mode 100644
index 00000000000..0d6fc5395cb
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/Operation.py
@@ -0,0 +1,79 @@
+import itertools
+import numpy as np
+from pkg.api.python import api_pb2
+
+
+class Operation(object):
+ def __init__(self, opt_id, opt_type, opt_params):
+ self.opt_id = opt_id
+ self.opt_type = opt_type
+ self.opt_params = opt_params
+
+ def get_dict(self):
+ opt_dict = dict()
+ opt_dict['opt_id'] = self.opt_id
+ opt_dict['opt_type'] = self.opt_type
+ opt_dict['opt_params'] = self.opt_params
+ return opt_dict
+
+ def print_op(self, logger):
+ logger.info("Operation ID: \n\t{}".format(self.opt_id))
+ logger.info("Operation Type: \n\t{}".format(self.opt_type))
+ logger.info("Operations Parameters:")
+ for ikey in self.opt_params:
+ logger.info("\t{}: {}".format(ikey, self.opt_params[ikey]))
+ logger.info("")
+
+
+class SearchSpace(object):
+ def __init__(self, operations):
+ self.operation_list = list(operations.operation)
+ self.search_space = list()
+ self._parse_operations()
+ print()
+ self.num_operations = len(self.search_space)
+
+ def _parse_operations(self):
+        # search_space is a list of Operation objects
+
+ operation_id = 0
+
+ for operation_dict in self.operation_list:
+ opt_type = operation_dict.operationType
+ opt_spec = list(operation_dict.parameter_configs.configs)
+            # avail_space is a dict with the format {"spec_name": [feasible values of spec]}
+ avail_space = dict()
+ num_spec = len(opt_spec)
+
+ for ispec in opt_spec:
+ spec_name = ispec.name
+ if ispec.parameter_type == api_pb2.CATEGORICAL:
+ avail_space[spec_name] = list(ispec.feasible.list)
+ elif ispec.parameter_type == api_pb2.INT:
+ spec_min = int(ispec.feasible.min)
+ spec_max = int(ispec.feasible.max)
+ spec_step = int(ispec.feasible.step)
+ avail_space[spec_name] = range(spec_min, spec_max+1, spec_step)
+ elif ispec.parameter_type == api_pb2.DOUBLE:
+ spec_min = float(ispec.feasible.min)
+ spec_max = float(ispec.feasible.max)
+ spec_step = float(ispec.feasible.step)
+ if spec_step == 0:
+ print("Error, NAS Reinforcement Learning algorithm cannot accept continuous search space!")
+ exit(999)
+ double_list = np.arange(spec_min, spec_max+spec_step, spec_step)
+                # np.arange can overshoot spec_max; numpy arrays do not support `del`
+                if double_list[-1] > spec_max:
+                    double_list = double_list[:-1]
+ avail_space[spec_name] = double_list
+
+ # generate all the combinations of possible operations
+ key_avail_space = list(avail_space.keys())
+ val_avail_space = list(avail_space.values())
+
+ for this_opt_vector in itertools.product(*val_avail_space):
+ opt_params = dict()
+ for i in range(num_spec):
+ opt_params[key_avail_space[i]] = this_opt_vector[i]
+ this_opt_class = Operation(operation_id, opt_type, opt_params)
+ self.search_space.append(this_opt_class)
+ operation_id += 1
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/README.md b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
new file mode 100644
index 00000000000..8cf44e40d2b
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
@@ -0,0 +1,128 @@
+# About the Neural Architecture Search with Reinforcement Learning Suggestion
+
+The algorithm follows the idea proposed in *Neural Architecture Search with Reinforcement Learning* by Zoph & Le (https://arxiv.org/abs/1611.01578), and the implementation is based on the GitHub repository of *Efficient Neural Architecture Search via Parameter Sharing* (https://github.com/melodyguan/enas). It uses a recurrent neural network with LSTM cells as the controller to generate neural architecture candidates, and the controller network is updated by policy gradients. However, it currently does not support parameter sharing.
+
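+The controller is trained with a REINFORCE-style policy gradient. Below is a minimal sketch of the update rule in plain Python (the names are illustrative; the actual TensorFlow implementation lives in `Controller.build_trainer()`):
+
+```
+def reinforce_loss(neg_log_prob, reward, baseline, bl_dec=0.999):
+    # Exponential-moving-average baseline reduces the variance of the gradient
+    baseline = bl_dec * baseline + (1 - bl_dec) * reward
+    # neg_log_prob is the sampling cross-entropy (-log p) of the architecture;
+    # minimizing this loss raises the probability of above-baseline architectures
+    loss = neg_log_prob * (reward - baseline)
+    return loss, baseline
+```
+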
+## Definition of a Neural Architecture
+
+Let n be the number of layers and m the number of possible operations.
+
+If n = 12 and m = 6, an architecture definition looks like this:
+
+```
+[2]
+[0 0]
+[1 1 0]
+[5 1 0 1]
+[1 1 1 0 1]
+[5 0 0 1 0 1]
+[1 1 1 0 0 1 0]
+[2 0 0 0 1 1 0 1]
+[0 0 0 1 1 1 1 1 0]
+[2 0 1 0 1 1 1 0 0 0]
+[3 1 1 1 1 1 1 0 0 1 1]
+[0 1 1 1 1 0 0 1 1 1 1 0]
+```
+
+There are n rows; the i-th row has i elements and describes the i-th layer. Note that layer 0 is the input and is not included in this definition.
+
+In each row:
+The first integer ranges from 0 to m-1 and indicates the operation used in this layer.
+The next (i-1) integers are each either 0 or 1. The k-th integer (k>=2) indicates whether the (k-2)-th layer has a skip connection to this layer. (There is always a connection from the (k-1)-th layer to the k-th layer.)
+
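+The controller emits this description as one flat integer sequence. A minimal sketch (the helper name `decode_arc` is illustrative; it mirrors the reshaping done in `nasrl_service.py`) of splitting that sequence back into the per-layer rows shown above:
+
+```
+def decode_arc(flat_arc, num_layers):
+    # flat_arc is the controller output, e.g. [2, 0, 0, 1, 1, 0, ...]
+    rows, start = [], 0
+    for layer in range(num_layers):
+        rows.append(flat_arc[start:start + layer + 1])
+        start += layer + 1
+    return rows
+```
+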
+## Output of `GetSuggestion()`
+The output of `GetSuggestion()` consists of two parts: `architecture` and `nn_config`.
+
+`architecture` is a JSON string defining a neural architecture in the format described above. One example is:
+```
+[[27], [29, 0], [22, 1, 0], [13, 0, 0, 0], [26, 1, 1, 0, 0], [30, 1, 0, 1, 0, 0], [11, 0, 1, 1, 0, 1, 1], [9, 1, 0, 0, 1, 0, 0, 0]]
+```
+
+`nn_config` is a JSON string describing the number of layers, the input size, the output size, and what each operation index stands for. An `nn_config` corresponding to the architecture above can be:
+```
+{
+ "num_layers": 8,
+ "input_size": [32, 32, 3],
+ "output_size": [10],
+ "embedding": {
+ "27": {
+ "opt_id": 27,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "7",
+ "num_filter": "96",
+ "stride": "2"
+ }
+ },
+ "29": {
+ "opt_id": 29,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "7",
+ "num_filter": "128",
+ "stride": "2"
+ }
+ },
+ "22": {
+ "opt_id": 22,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "7",
+ "num_filter": "48",
+ "stride": "1"
+ }
+ },
+ "13": {
+ "opt_id": 13,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "5",
+ "num_filter": "48",
+ "stride": "2"
+ }
+ },
+ "26": {
+ "opt_id": 26,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "7",
+ "num_filter": "96",
+ "stride": "1"
+ }
+ },
+ "30": {
+ "opt_id": 30,
+ "opt_type": "reduction",
+ "opt_params": {
+ "reduction_type": "max_pooling",
+ "pool_size": 2
+ }
+ },
+ "11": {
+ "opt_id": 11,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "5",
+ "num_filter": "32",
+ "stride": "2"
+ }
+ },
+ "9": {
+ "opt_id": 9,
+ "opt_type": "convolution",
+ "opt_params": {
+ "filter_size": "3",
+ "num_filter": "128",
+ "stride": "2"
+ }
+ }
+ }
+}
+```
+This neural architecture can be visualized as
+![a neural network architecture example](example.png)
+
+## To Do
+1. Add support for multiple studyjobs
+2. Add support for multiple trials
+3. Change the LSTM cell from the self-defined functions in LSTM.py to `tf.nn.rnn_cell.LSTMCell`
+4. Store the suggestion checkpoint in a PVC so that it survives nasrl service pod restarts
\ No newline at end of file
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py
new file mode 100644
index 00000000000..ae9f1d19777
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py
@@ -0,0 +1,73 @@
+def parseSuggestionParam(params_raw):
+ param_standard = {
+ "lstm_num_cells": ['value', int, [1, 'inf']],
+ "lstm_num_layers": ['value', int, [1, 'inf']],
+ "lstm_keep_prob": ['value', float, [0.0, 1.0]],
+ "optimizer": ['categorical', str, ["adam", "momentum", "sgd"]],
+ "init_learning_rate": ['value', float, [1e-6, 1.0]],
+ "lr_decay_start": ['value', int, [0, 'inf']],
+ "lr_decay_every": ['value', int, [1, 'inf']],
+ "lr_decay_rate": ['value', float, [0.0, 1.0]],
+ "skip-target": ['value', float, [0.0, 1.0]],
+ "skip-weight": ['value', float, [0.0, 'inf']],
+ "l2_reg": ['value', float, [0.0, 'inf']],
+ "entropy_weight": ['value', float, [0.0, 'inf']],
+ "baseline_decay": ['value', float, [0.0, 1.0]],
+ }
+
+ suggestion_params = {
+ "lstm_num_cells": 64,
+ "lstm_num_layers": 1,
+ "lstm_keep_prob": 1.0,
+ "optimizer": "adam",
+ "init_learning_rate": 1e-3,
+ "lr_decay_start": 0,
+ "lr_decay_every": 1000,
+ "lr_decay_rate": 0.9,
+ "skip-target": 0.4,
+ "skip-weight": 0.8,
+ "l2_reg": 0,
+ "entropy_weight": 1e-4,
+ "baseline_decay": 0.9999
+ }
+
+ def checktype(param_name, param_value, check_mode, supposed_type, supposed_range=None):
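+        # Cast the raw string value to `supposed_type` and check it against
+        # `supposed_range`; on any failure the default in `suggestion_params` is kept.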
+ correct = True
+
+ try:
+ converted_value = supposed_type(param_value)
+        except Exception:
+ correct = False
+ print("Parameter {} is of wrong type. Set back to default value {}"
+ .format(param_name, suggestion_params[param_name]))
+
+ if correct and check_mode == 'value':
+ if not ((supposed_range[0] == '-inf' or converted_value >= supposed_range[0]) and
+ (supposed_range[1] == 'inf' or converted_value <= supposed_range[1])):
+ correct = False
+ print("Parameter {} out of range. Set back to default value {}"
+ .format(param_name, suggestion_params[param_name]))
+ elif correct and check_mode == 'categorical':
+ if converted_value not in supposed_range:
+ correct = False
+ print("Parameter {} out of range. Set back to default value {}"
+ .format(param_name, suggestion_params[param_name]))
+
+ if correct:
+ suggestion_params[param_name] = converted_value
+
+
+ for param in params_raw:
+ # SuggestionCount is automatically added by controller and not used currently
+ if param.name == "SuggestionCount":
+ continue
+ if param.name in suggestion_params.keys():
+ checktype(param.name,
+ param.value,
+ param_standard[param.name][0], # mode
+ param_standard[param.name][1], # type
+ param_standard[param.name][2]) # range
+ else:
+ print("Unknown Parameter name: {}".format(param.name))
+
+ return suggestion_params
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py b/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py
new file mode 100644
index 00000000000..a8dde5115e5
--- /dev/null
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/Trainer.py
@@ -0,0 +1,143 @@
+import tensorflow as tf
+
+
+def get_train_ops(loss,
+ tf_variables,
+ train_step,
+ clip_mode=None,
+ grad_bound=None,
+ l2_reg=1e-4,
+ lr_warmup_val=None,
+ lr_warmup_steps=100,
+ lr_init=0.1,
+ lr_dec_start=0,
+ lr_dec_every=10000,
+ lr_dec_rate=0.1,
+ lr_dec_min=None,
+ lr_cosine=False,
+ lr_max=None,
+ lr_min=None,
+ lr_T_0=None,
+ lr_T_mul=None,
+ num_train_batches=None,
+ optim_algo=None,
+ sync_replicas=False,
+ num_aggregate=None,
+ num_replicas=None,
+ get_grad_norms=False,
+ moving_average=None):
+ """
+ Args:
+ clip_mode: "global", "norm", or None.
+ moving_average: store the moving average of parameters
+ """
+
+ if l2_reg > 0:
+ l2_losses = []
+ for var in tf_variables:
+ l2_losses.append(tf.reduce_sum(var ** 2))
+ l2_loss = tf.add_n(l2_losses)
+ loss += l2_reg * l2_loss
+
+ grads = tf.gradients(loss, tf_variables)
+ grad_norm = tf.global_norm(grads)
+
+ grad_norms = {}
+ for v, g in zip(tf_variables, grads):
+ if v is None or g is None:
+ continue
+ if isinstance(g, tf.IndexedSlices):
+ grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g.values ** 2))
+ else:
+ grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g ** 2))
+
+ if clip_mode is not None:
+ assert grad_bound is not None, "Need grad_bound to clip gradients."
+ if clip_mode == "global":
+ grads, _ = tf.clip_by_global_norm(grads, grad_bound)
+ elif clip_mode == "norm":
+ clipped = []
+ for g in grads:
+        if isinstance(g, tf.IndexedSlices):
+          c_g = tf.clip_by_norm(g.values, grad_bound)
+          # tf.IndexedSlices expects (values, indices); keep the original indices
+          c_g = tf.IndexedSlices(c_g, g.indices)
+        else:
+          c_g = tf.clip_by_norm(g, grad_bound)
+        # append the clipped gradient rather than the unclipped one
+        clipped.append(c_g)
+ grads = clipped
+ else:
+ raise NotImplementedError("Unknown clip_mode {}".format(clip_mode))
+
+ if lr_cosine:
+ assert lr_max is not None, "Need lr_max to use lr_cosine"
+ assert lr_min is not None, "Need lr_min to use lr_cosine"
+ assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
+ assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
+ assert num_train_batches is not None, ("Need num_train_batches to use"
+ " lr_cosine")
+
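+    # Cosine learning-rate schedule with warm restarts (SGDR-style): the period
+    # T_i is multiplied by lr_T_mul after each restart.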
+ curr_epoch = train_step // num_train_batches
+
+ last_reset = tf.Variable(0, dtype=tf.int32, trainable=False,
+ name="last_reset")
+ T_i = tf.Variable(lr_T_0, dtype=tf.int32, trainable=False, name="T_i")
+ T_curr = curr_epoch - last_reset
+
+ def _update():
+ update_last_reset = tf.assign(last_reset, curr_epoch, use_locking=True)
+ update_T_i = tf.assign(T_i, T_i * lr_T_mul, use_locking=True)
+ with tf.control_dependencies([update_last_reset, update_T_i]):
+ rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+ lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+ return lr
+
+ def _no_update():
+ rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
+ lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
+ return lr
+
+ learning_rate = tf.cond(
+ tf.greater_equal(T_curr, T_i), _update, _no_update)
+ else:
+ learning_rate = tf.train.exponential_decay(
+ lr_init, tf.maximum(train_step - lr_dec_start, 0), lr_dec_every,
+ lr_dec_rate, staircase=True)
+ if lr_dec_min is not None:
+ learning_rate = tf.maximum(learning_rate, lr_dec_min)
+
+ if lr_warmup_val is not None:
+ learning_rate = tf.cond(tf.less(train_step, lr_warmup_steps),
+ lambda: lr_warmup_val, lambda: learning_rate)
+
+ if optim_algo == "momentum":
+ opt = tf.train.MomentumOptimizer(
+ learning_rate, 0.9, use_locking=True, use_nesterov=True)
+ elif optim_algo == "sgd":
+ opt = tf.train.GradientDescentOptimizer(learning_rate, use_locking=True)
+ elif optim_algo == "adam":
+ opt = tf.train.AdamOptimizer(learning_rate, beta1=0.0, epsilon=1e-3,
+ use_locking=True)
+ else:
+ raise ValueError("Unknown optim_algo {}".format(optim_algo))
+
+ if sync_replicas:
+ assert num_aggregate is not None, "Need num_aggregate to sync."
+ assert num_replicas is not None, "Need num_replicas to sync."
+
+ opt = tf.train.SyncReplicasOptimizer(
+ opt,
+ replicas_to_aggregate=num_aggregate,
+ total_num_replicas=num_replicas,
+ use_locking=True)
+
+ if moving_average is not None:
+ opt = tf.contrib.opt.MovingAverageOptimizer(
+ opt, average_decay=moving_average)
+
+ train_op = opt.apply_gradients(
+ zip(grads, tf_variables), global_step=train_step)
+
+ if get_grad_norms:
+ return train_op, learning_rate, grad_norm, opt, grad_norms
+ else:
+ return train_op, learning_rate, grad_norm, opt
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/example.png b/pkg/suggestion/NAS_Reinforcement_Learning/example.png
new file mode 100644
index 00000000000..90ac2eb6190
Binary files /dev/null and b/pkg/suggestion/NAS_Reinforcement_Learning/example.png differ
diff --git a/pkg/suggestion/nasrl_service.py b/pkg/suggestion/nasrl_service.py
index f7e672a2276..0373bd122aa 100644
--- a/pkg/suggestion/nasrl_service.py
+++ b/pkg/suggestion/nasrl_service.py
@@ -1,22 +1,259 @@
-import random
-import string
-
+from pkg.suggestion.NAS_Reinforcement_Learning.Controller import Controller
+from pkg.suggestion.NAS_Reinforcement_Learning.Operation import SearchSpace
+from pkg.suggestion.NAS_Reinforcement_Learning.SuggestionParam import parseSuggestionParam
+import tensorflow as tf
import grpc
-import numpy as np
-
from pkg.api.python import api_pb2
from pkg.api.python import api_pb2_grpc
import logging
from logging import getLogger, StreamHandler, INFO, DEBUG
+import json
+import os
+import time
class NasrlService(api_pb2_grpc.SuggestionServicer):
def __init__(self, logger=None):
self.manager_addr = "vizier-core"
self.manager_port = 6789
- self.current_trial_id = ""
+ self.registered_studies = list()
+
+ self.ctrl_cache_file = ""
+ self.ctrl_step = 0
+ self.is_first_run = True
+
+ if not os.path.exists("ctrl_cache/"):
+ os.makedirs("ctrl_cache/")
+
+        if logger is None:
+ self.logger = getLogger(__name__)
+ FORMAT = '%(asctime)-15s StudyID %(studyid)s %(message)s'
+ logging.basicConfig(format=FORMAT)
+ handler = StreamHandler()
+ handler.setLevel(INFO)
+ self.logger.setLevel(INFO)
+ self.logger.addHandler(handler)
+ self.logger.propagate = False
+ else:
+ self.logger = logger
+
+ def setup_controller(self, request):
+ self.logger.info("-" * 80 + "\nSetting Up Suggestion for StudyJob {}\n".format(request.study_id) + "-" * 80)
+ self.tf_graph = tf.Graph()
+ self.ctrl_step = 0
+ self.ctrl_cache_file = "ctrl_cache/{}.ckpt".format(request.study_id)
+ self._get_suggestion_param(request.param_id, request.study_id)
+ self._get_search_space(request.study_id)
+
+ with self.tf_graph.as_default():
+ ctrl_param = self.suggestion_config
+ self.controllers = Controller(
+ num_layers=self.num_layers,
+ num_operations=self.num_operations,
+ lstm_size=ctrl_param['lstm_num_cells'],
+ lstm_num_layers=ctrl_param['lstm_num_layers'],
+ lstm_keep_prob=ctrl_param['lstm_keep_prob'],
+ lr_init=ctrl_param['init_learning_rate'],
+ lr_dec_start=ctrl_param['lr_decay_start'],
+ lr_dec_every=ctrl_param['lr_decay_every'],
+ lr_dec_rate=ctrl_param['lr_decay_rate'],
+ l2_reg=ctrl_param['l2_reg'],
+ entropy_weight=ctrl_param['entropy_weight'],
+ bl_dec=ctrl_param['baseline_decay'],
+ optim_algo=ctrl_param['optimizer'],
+ skip_target=ctrl_param['skip-target'],
+ skip_weight=ctrl_param['skip-weight'],
+ name="Ctrl_"+request.study_id)
+
+ self.controllers.build_trainer()
+
+ self.logger.info("Suggestion for StudyJob {} has been initialized.".format(request.study_id))
def GetSuggestions(self, request, context):
+ if request.study_id not in self.registered_studies:
+ self.setup_controller(request)
+ self.is_first_run = True
+ self.registered_studies.append(request.study_id)
+
+ self.logger.info("-" * 80 + "\nSuggestion Step {} for Study {}\n".format(self.ctrl_step, request.study_id) + "-" * 80)
+
+ with self.tf_graph.as_default():
+
+ saver = tf.train.Saver()
+ ctrl = self.controllers
+
+ controller_ops = {
+ "train_step": ctrl.train_step,
+ "loss": ctrl.loss,
+ "train_op": ctrl.train_op,
+ "lr": ctrl.lr,
+ "grad_norm": ctrl.grad_norm,
+ "optimizer": ctrl.optimizer,
+ "baseline": ctrl.baseline,
+ "entropy": ctrl.sample_entropy,
+ "sample_arc": ctrl.sample_arc,
+ "skip_rate": ctrl.skip_rate}
+
+ run_ops = [
+ controller_ops["loss"],
+ controller_ops["entropy"],
+ controller_ops["lr"],
+ controller_ops["grad_norm"],
+ controller_ops["baseline"],
+ controller_ops["skip_rate"],
+ controller_ops["train_op"]]
+
+ if self.is_first_run:
+ self.logger.info("First time running suggestion for {}. Random architecture will be given.".format(request.study_id))
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ arc = sess.run(controller_ops["sample_arc"])
+ # TODO: will use PVC to store the checkpoint to protect against unexpected suggestion pod restart
+ saver.save(sess, self.ctrl_cache_file)
+
+ self.is_first_run = False
+
+ else:
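+            # Restore the controller from the checkpoint, feed the metric of the
+            # previous architecture back as the reward, run one policy-gradient
+            # update, then sample the next architecture.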
+ with tf.Session() as sess:
+ saver.restore(sess, self.ctrl_cache_file)
+
+ valid_acc = ctrl.reward
+ result = self.GetEvaluationResult(request.study_id)
+
+                # The controller is designed to maximize the metric.
+                # If the user wants to minimize the metric, we take the negative of the result.
+ if self.opt_direction == api_pb2.MINIMIZE:
+ result = -result
+
+ loss, entropy, lr, gn, bl, skip, _ = sess.run(
+ fetches=run_ops,
+ feed_dict={valid_acc: result})
+ self.logger.info("Suggetion updated. LSTM Controller Loss: {}".format(loss))
+ arc = sess.run(controller_ops["sample_arc"])
+
+ saver.save(sess, self.ctrl_cache_file)
+
+ arc = arc.tolist()
+ organized_arc = [0 for _ in range(self.num_layers)]
+ record = 0
+ for l in range(self.num_layers):
+ organized_arc[l] = arc[record: record + l + 1]
+ record += l + 1
+
+ nn_config = dict()
+ nn_config['num_layers'] = self.num_layers
+ nn_config['input_size'] = self.input_size
+ nn_config['output_size'] = self.output_size
+ nn_config['embedding'] = dict()
+ for l in range(self.num_layers):
+ opt = organized_arc[l][0]
+ nn_config['embedding'][opt] = self.search_space[opt].get_dict()
+
+ organized_arc_json = json.dumps(organized_arc)
+ nn_config_json = json.dumps(nn_config)
+
+ organized_arc_str = str(organized_arc_json).replace('\"', '\'')
+ nn_config_str = str(nn_config_json).replace('\"', '\'')
+
+ self.logger.info("\nNew Neural Network Architecture (internal representation):")
+ self.logger.info(organized_arc_json)
+ self.logger.info("\nCorresponding Seach Space Description:")
+ self.logger.info(nn_config_str)
+ self.logger.info("")
+
trials = []
+ trials.append(api_pb2.Trial(
+ study_id=request.study_id,
+ parameter_set=[
+ api_pb2.Parameter(
+ name="architecture",
+ value=organized_arc_str,
+                parameter_type=api_pb2.CATEGORICAL),
+ api_pb2.Parameter(
+ name="nn_config",
+ value=nn_config_str,
+                parameter_type=api_pb2.CATEGORICAL)
+ ],
+ )
+ )
+ channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port)
+ with api_pb2.beta_create_Manager_stub(channel) as client:
+ for i, t in enumerate(trials):
+ ctrep = client.CreateTrial(api_pb2.CreateTrialRequest(trial=t), 10)
+ trials[i].trial_id = ctrep.trial_id
+ self.logger.info("Trial {} Created\n".format(ctrep.trial_id))
+ self.prev_trial_id = ctrep.trial_id
+
+ self.ctrl_step += 1
return api_pb2.GetSuggestionsReply(trials=trials)
+
+ def GetEvaluationResult(self, studyID):
+ worker_list = []
+ channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port)
+ with api_pb2.beta_create_Manager_stub(channel) as client:
+ gwfrep = client.GetWorkerFullInfo(api_pb2.GetWorkerFullInfoRequest(study_id=studyID, trial_id=self.prev_trial_id, only_latest_log=True), 10)
+ worker_list = gwfrep.worker_full_infos
+
+ for w in worker_list:
+ if w.Worker.status == api_pb2.COMPLETED:
+ for ml in w.metrics_logs:
+ if ml.name == self.objective_name:
+ self.logger.info("Evaluation result of previous candidate: {}".format(ml.values[-1].value))
+ return float(ml.values[-1].value)
+
+ # TODO: add support for multiple trials
+
+
+ def _get_search_space(self, studyID):
+
+        # This function needs to:
+ # 1) get the number of layers
+ # 2) get the I/O size
+ # 3) get the available operations
+ # 4) get the optimization direction (i.e. minimize or maximize)
+ # 5) get the objective name
+
+ channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port)
+ with api_pb2.beta_create_Manager_stub(channel) as client:
+ gsrep = client.GetStudy(api_pb2.GetStudyRequest(study_id=studyID), 10)
+
+ self.opt_direction = gsrep.study_config.optimization_type
+ self.objective_name = gsrep.study_config.objective_value_name
+
+ all_params = gsrep.study_config.nas_config
+ graph_config = all_params.graph_config
+ search_space_raw = all_params.operations
+
+ self.num_layers = int(graph_config.num_layers)
+ self.input_size = list(map(int, graph_config.input_size))
+ self.output_size = list(map(int, graph_config.output_size))
+ search_space_object = SearchSpace(search_space_raw)
+
+ self.logger.info("Search Space for Study {}:".format(studyID))
+
+ self.search_space = search_space_object.search_space
+ for opt in self.search_space:
+ opt.print_op(self.logger)
+
+ self.num_operations = search_space_object.num_operations
+ self.logger.info("There are {} operations in total.\n".format(self.num_operations))
+
+
+ def _get_suggestion_param(self, paramID, studyID):
+ channel = grpc.beta.implementations.insecure_channel(self.manager_addr, self.manager_port)
+ with api_pb2.beta_create_Manager_stub(channel) as client:
+ gsprep = client.GetSuggestionParameters(api_pb2.GetSuggestionParametersRequest(param_id=paramID), 10)
+
+ params_raw = gsprep.suggestion_parameters
+
+ suggestion_params = parseSuggestionParam(params_raw)
+
+ self.logger.info("Parameters of LSTM Controller for Study {}:".format(studyID))
+ for spec in suggestion_params:
+ if len(spec) > 13:
+ self.logger.info("{}: \t{}".format(spec, suggestion_params[spec]))
+ else:
+ self.logger.info("{}: \t\t{}".format(spec, suggestion_params[spec]))
+
+ self.suggestion_config = suggestion_params