DARTS Suggestion #1175

Merged
2 changes: 1 addition & 1 deletion .dockerignore
@@ -2,7 +2,7 @@
 .gitignore
 docs
 examples
-!examples/v1alpha3/nas/enas-cnn-cifar10
+!examples/v1alpha3/nas
 manifests
 pkg/ui/*/frontend/node_modules
 pkg/ui/*/frontend/build
26 changes: 26 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/Dockerfile
@@ -0,0 +1,26 @@
FROM python:3.6

RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
apt-get -y update && \
apt-get -y install gfortran libopenblas-dev liblapack-dev && \
pip install cython; \
fi

RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \
    if [ "$(uname -m)" = "ppc64le" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \
    elif [ "$(uname -m)" = "aarch64" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \
    else \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \
    fi && \
    chmod +x /bin/grpc_health_probe

ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nas/darts/v1alpha3
RUN pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/health/python

ENTRYPOINT ["python", "main.py"]

30 changes: 30 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/main.py
@@ -0,0 +1,30 @@
import grpc
from concurrent import futures
import time
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.apis.manager.health.python import health_pb2_grpc
from pkg.suggestion.v1alpha3.nas.darts.service import DartsService


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
DEFAULT_PORT = "0.0.0.0:6789"


def serve():
print("Darts Suggestion Service")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
service = DartsService()
api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
health_pb2_grpc.add_HealthServicer_to_server(service, server)
server.add_insecure_port(DEFAULT_PORT)
print("Listening...")
server.start()
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)


if __name__ == "__main__":
serve()
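
For reference, a minimal smoke-test client for the service above might look like the sketch below. It assumes the generated Katib health stubs expose HealthStub and HealthCheckRequest (mirroring the standard gRPC health protocol, consistent with the add_HealthServicer_to_server call in main.py); the empty request is an assumption of this sketch, not part of this PR.

import grpc

from pkg.apis.manager.health.python import health_pb2, health_pb2_grpc

# Connect to the suggestion service started by serve() above
channel = grpc.insecure_channel("0.0.0.0:6789")
stub = health_pb2_grpc.HealthStub(channel)
# An empty HealthCheckRequest checks overall server health
response = stub.Check(health_pb2.HealthCheckRequest())
print(response.status)  # expected: SERVING while DartsService is up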
3 changes: 3 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/requirements.txt
@@ -0,0 +1,3 @@
grpcio==1.23.0
protobuf==3.9.1
googleapis-common-protos==1.6.0
9 changes: 9 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile
@@ -0,0 +1,9 @@
ARG cuda_version=10.0
ARG cudnn_version=7
FROM pytorch/pytorch:1.0-cuda${cuda_version}-cudnn${cudnn_version}-runtime


ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/examples/v1alpha3/nas/darts-cnn-cifar10

ENTRYPOINT ["python3", "-u", "run_trial.py"]
113 changes: 113 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/architect.py
@@ -0,0 +1,113 @@
import torch
import copy


class Architect():
"""" Architect controls architecture of cell by computing gradients of alphas
"""

def __init__(self, model, w_momentum, w_weight_decay):
self.model = model
self.v_model = copy.deepcopy(model)
self.w_momentum = w_momentum
self.w_weight_decay = w_weight_decay

def virtual_step(self, train_x, train_y, xi, w_optim):
"""
Compute unrolled weight w' (virtual step)
Step process:
1) forward
2) calculate loss
3) compute gradient (by backprop)
        4) update the virtual weights with the gradient

Args:
xi: learning rate for virtual gradient step (same as weights lr)
w_optim: weights optimizer
"""

# Forward and calculate loss
# Loss for train with w. L_train(w)
loss = self.model.loss(train_x, train_y)
# Compute gradient
gradients = torch.autograd.grad(loss, self.model.getWeights())

        # Do virtual step (update the virtual weights)
        # Below operations do not need gradient tracking
        with torch.no_grad():
            # Optimizer state is keyed by the parameter object itself, so the
            # original network weights have to be iterated alongside the virtual copies.
            for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
                # w' = w - xi * (momentum term + gradient + weight decay term)
                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))

# Sync alphas
for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
va.copy_(a)

def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
""" Compute unrolled loss and backward its gradients
Args:
xi: learning rate for virtual gradient step (same as model lr)
w_optim: weights optimizer - for virtual step
"""
# Do virtual step (calc w')
self.virtual_step(train_x, train_y, xi, w_optim)

# Calculate unrolled loss
# Loss for validation with w'. L_valid(w')
loss = self.v_model.loss(valid_x, valid_y)

# Calculate gradient
v_alphas = tuple(self.v_model.getAlphas())
v_weights = tuple(self.v_model.getWeights())
v_grads = torch.autograd.grad(loss, v_alphas + v_weights)

dalpha = v_grads[:len(v_alphas)]
dws = v_grads[len(v_alphas):]

hessian = self.compute_hessian(dws, train_x, train_y)

# Update final gradient = dalpha - xi * hessian
with torch.no_grad():
for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
alpha.grad = da - xi * h

def compute_hessian(self, dws, train_x, train_y):
"""
dw = dw' { L_valid(w', alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha{ L_train(w+, alpha) } - dalpha{ L_train(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""

norm = torch.cat([dw.view(-1) for dw in dws]).norm()
eps = 0.01 / norm

# w+ = w + eps * dw
with torch.no_grad():
for p, dw in zip(self.model.getWeights(), dws):
p += eps * dw

loss = self.model.loss(train_x, train_y)
# dalpha { L_train(w+, alpha) }
dalpha_positive = torch.autograd.grad(loss, self.model.getAlphas())

# w- = w - eps * dw
with torch.no_grad():
for p, dw in zip(self.model.getWeights(), dws):
                # TODO (andreyvelich): Do we need this * 2.0 ?
                # Note: p currently holds w+ = w + eps * dw, so subtracting
                # 2 * eps * dw yields w- = w - eps * dw in a single step.
                p -= 2. * eps * dw

loss = self.model.loss(train_x, train_y)
# dalpha { L_train(w-, alpha) }
dalpha_negative = torch.autograd.grad(loss, self.model.getAlphas())

# recover w
with torch.no_grad():
for p, dw in zip(self.model.getWeights(), dws):
p += eps * dw

hessian = [(p-n) / (2. * eps) for p, n in zip(dalpha_positive, dalpha_negative)]
return hessian
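
Taken together, virtual_step, unrolled_backward, and compute_hessian implement the second-order (unrolled) bilevel update from the DARTS paper. A hypothetical training-loop excerpt showing how Architect is typically driven is sketched below; the optimizers, hyperparameters, and data loaders (w_optim, alpha_optim, train_loader, valid_loader) are illustrative assumptions, not part of this diff.

import torch

w_optim = torch.optim.SGD(model.getWeights(), lr=0.025,
                          momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(model.getAlphas(), lr=3e-4,
                               betas=(0.5, 0.999), weight_decay=1e-3)
architect = Architect(model, w_momentum=0.9, w_weight_decay=3e-4)

for (train_x, train_y), (valid_x, valid_y) in zip(train_loader, valid_loader):
    # Architecture step: unrolled gradient w.r.t. alphas on validation data;
    # xi matches the current weights learning rate, as the docstrings require.
    alpha_optim.zero_grad()
    architect.unrolled_backward(train_x, train_y, valid_x, valid_y,
                                xi=0.025, w_optim=w_optim)
    alpha_optim.step()

    # Weight step: ordinary SGD on the training loss
    w_optim.zero_grad()
    model.loss(train_x, train_y).backward()
    w_optim.step()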
172 changes: 172 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/model.py
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from operations import FactorizedReduce, StdConv, MixedOp


class Cell(nn.Module):
""" Cell for search
Each edge is mixed and continuous relaxed.
"""

def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
"""
Args:
num_nodes: Number of intermediate cell nodes
            c_prev_prev: channels_out[k-2]
            c_prev: channels_out[k-1]
            c_cur: channels_in[k] (current)
reduction_prev: flag for whether the previous cell is reduction cell or not
reduction_cur: flag for whether the current cell is reduction cell or not
"""

super(Cell, self).__init__()
self.reduction_cur = reduction_cur
self.num_nodes = num_nodes

        # If the previous cell is a reduction cell, the current input size does not
        # match the output size of cell[k-2], so output[k-2] has to be reduced by preprocessing
if reduction_prev:
self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
else:
self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)

# Generate dag from mixed operations
self.dag_ops = nn.ModuleList()

for i in range(self.num_nodes):
self.dag_ops.append(nn.ModuleList())
# Include 2 input nodes
for j in range(2+i):
                # Reduction with stride = 2 is applied only to the two input nodes
stride = 2 if reduction_cur and j < 2 else 1
op = MixedOp(c_cur, stride, search_space)
self.dag_ops[i].append(op)

def forward(self, s0, s1, w_dag):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)

states = [s0, s1]
for edges, w_list in zip(self.dag_ops, w_dag):
            state_cur = sum(edges[i](s, w) for i, (s, w) in enumerate(zip(states, w_list)))
states.append(state_cur)

state_out = torch.cat(states[2:], dim=1)
return state_out


class NetworkCNN(nn.Module):

def __init__(self, init_channels, input_channels, num_classes, num_layers, criterion, search_space):
super(NetworkCNN, self).__init__()

self.init_channels = init_channels
self.num_classes = num_classes
self.num_layers = num_layers
self.criterion = criterion

# TODO: Algorithm settings?
self.num_nodes = 4
self.stem_multiplier = 3

c_cur = self.stem_multiplier*self.init_channels

self.stem = nn.Sequential(
nn.Conv2d(input_channels, c_cur, 3, padding=1, bias=False),
nn.BatchNorm2d(c_cur)
)

        # In the first cell, the stem output is used for both s0 and s1
        # c_prev_prev and c_prev - output channel sizes
        # c_cur - initial channel size
c_prev_prev, c_prev, c_cur = c_cur, c_cur, self.init_channels

self.cells = nn.ModuleList()

reduction_prev = False
for i in range(self.num_layers):
            # At 1/3 and 2/3 of the total layers - reduction cell with doubled channels
            # Otherwise - normal cell
if i in [self.num_layers//3, 2*self.num_layers//3]:
c_cur *= 2
reduction_cur = True
else:
reduction_cur = False

cell = Cell(self.num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space)
reduction_prev = reduction_cur
self.cells.append(cell)

c_cur_out = c_cur * self.num_nodes
c_prev_prev, c_prev = c_prev, c_cur_out

self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, self.num_classes)

# Initialize alphas parameters
num_ops = len(search_space.primitives)

self.alpha_normal = nn.ParameterList()
self.alpha_reduce = nn.ParameterList()

for i in range(self.num_nodes):
self.alpha_normal.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))
self.alpha_reduce.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))

# Setup alphas list
self.alphas = []
for name, parameter in self.named_parameters():
if "alpha" in name:
self.alphas.append((name, parameter))

def forward(self, x):

weights_normal = [F.softmax(alpha, dim=-1) for alpha in self.alpha_normal]
weights_reduce = [F.softmax(alpha, dim=-1) for alpha in self.alpha_reduce]

s0 = s1 = self.stem(x)

for cell in self.cells:
weights = weights_reduce if cell.reduction_cur else weights_normal
s0, s1 = s1, cell(s0, s1, weights)

out = self.global_pooling(s1)

        # Flatten the output
out = out.view(out.size(0), -1)

logits = self.classifier(out)
return logits

def print_alphas(self):

print("\n>>> Alphas Normal <<<")
for alpha in self.alpha_normal:
print(F.softmax(alpha, dim=-1))

print("\n>>> Alpha Reduce <<<")
for alpha in self.alpha_reduce:
print(F.softmax(alpha, dim=-1))
print("\n")

def getWeights(self):
return self.parameters()

def getAlphas(self):
for _, parameter in self.alphas:
yield parameter

def loss(self, x, y):
logits = self.forward(x)
return self.criterion(logits, y)

def genotype(self, search_space):
gene_normal = search_space.parse(self.alpha_normal, k=2)
gene_reduce = search_space.parse(self.alpha_reduce, k=2)
# concat all intermediate nodes
concat = range(2, 2 + self.num_nodes)

return search_space.genotype(normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat)
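
As a rough shape sanity check, the search network can be instantiated as below. The sketch assumes a search_space object exposing the primitives, parse, and genotype members used above (provided by the example's search_space module elsewhere in this PR), and the hyperparameters are the usual DARTS CIFAR-10 defaults, not values taken from this diff.

import torch
import torch.nn as nn

# search_space is assumed to be constructed from the example's search_space
# module; it must expose .primitives for MixedOp and the alpha initialization.
criterion = nn.CrossEntropyLoss()
model = NetworkCNN(init_channels=16, input_channels=3, num_classes=10,
                   num_layers=8, criterion=criterion, search_space=search_space)

logits = model(torch.randn(2, 3, 32, 32))  # two CIFAR-10 sized RGB inputs
print(logits.shape)                        # torch.Size([2, 10])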