Integrates get_genotype in paddleslim (PaddlePaddle#228)
baiyfbupt authored Apr 21, 2020
1 parent cdb98e2 commit 388211f
Showing 4 changed files with 82 additions and 114 deletions.
56 changes: 0 additions & 56 deletions demo/darts/genotypes.py
@@ -21,62 +21,6 @@
'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5'
]

NASNet = Genotype(
normal=[
('sep_conv_5x5', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 0),
('sep_conv_3x3', 0),
('avg_pool_3x3', 1),
('skip_connect', 0),
('avg_pool_3x3', 0),
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('skip_connect', 1),
],
normal_concat=[2, 3, 4, 5, 6],
reduce=[
('sep_conv_5x5', 1),
('sep_conv_7x7', 0),
('max_pool_3x3', 1),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('sep_conv_5x5', 0),
('skip_connect', 3),
('avg_pool_3x3', 2),
('sep_conv_3x3', 2),
('max_pool_3x3', 1),
],
reduce_concat=[4, 5, 6], )

AmoebaNet = Genotype(
normal=[
('avg_pool_3x3', 0),
('max_pool_3x3', 1),
('sep_conv_3x3', 0),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('avg_pool_3x3', 3),
('sep_conv_3x3', 1),
('skip_connect', 1),
('skip_connect', 0),
('avg_pool_3x3', 1),
],
normal_concat=[4, 5, 6],
reduce=[
('avg_pool_3x3', 0),
('sep_conv_3x3', 1),
('max_pool_3x3', 0),
('sep_conv_7x7', 2),
('sep_conv_7x7', 0),
('avg_pool_3x3', 1),
('max_pool_3x3', 0),
('max_pool_3x3', 1),
('conv_7x1_1x7', 0),
('sep_conv_3x3', 5),
],
reduce_concat=[3, 4, 6])

DARTS_V1 = Genotype(
normal=[('sep_conv_5x5', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 2),
('sep_conv_5x5', 0), ('sep_conv_5x5', 0), ('dil_conv_3x3', 3),
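
A Genotype (the namedtuple defined in paddleslim/nas/darts/get_genotype.py below) bundles four fields: normal and reduce are lists of (primitive, input_node_index) pairs, two per intermediate node, and normal_concat / reduce_concat list the intermediate nodes whose outputs are concatenated into the cell output. A minimal illustrative instance, with made-up op choices rather than a searched result:

from collections import namedtuple

# Same namedtuple as in paddleslim/nas/darts/get_genotype.py
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Hypothetical 2-step cell, purely for illustration: each pair is
# (primitive name, index of the input node the op reads from).
toy = Genotype(
    normal=[('sep_conv_3x3', 0), ('skip_connect', 1),
            ('dil_conv_3x3', 0), ('max_pool_3x3', 2)],
    normal_concat=range(2, 4),
    reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1),
            ('avg_pool_3x3', 1), ('skip_connect', 2)],
    reduce_concat=range(2, 4))

print(toy.normal_concat)  # range(2, 4): nodes 2 and 3 are concatenated
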
59 changes: 2 additions & 57 deletions demo/darts/model_search.py
@@ -22,7 +22,6 @@
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from genotypes import PRIMITIVES
from genotypes import Genotype
from operations import *


@@ -147,6 +146,7 @@ def __init__(self,
self._layers = layers
self._steps = steps
self._multiplier = multiplier
self._primitives = PRIMITIVES
self._method = method

c_cur = stem_multiplier * c_in
@@ -238,7 +238,7 @@ def new(self):

def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = len(PRIMITIVES)
num_ops = len(self._primitives)
self.alphas_normal = fluid.layers.create_parameter(
shape=[k, num_ops],
dtype="float32",
@@ -268,58 +268,3 @@ def _initialize_alphas(self):

def arch_parameters(self):
return self._arch_parameters

def genotype(self):
def _parse(weights, weights2=None):
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
if self._method == "PC-DARTS":
W2 = weights2[start:end].copy()
for j in range(n):
W[j, :] = W[j, :] * W2[j]
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if k != PRIMITIVES.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((PRIMITIVES[k_best], j))
start = end
n += 1
return gene

weightsr2 = None
weightsn2 = None
if self._method == "PC-DARTS":
n = 3
start = 2
weightsr2 = fluid.layers.softmax(self.betas_reduce[0:2])
weightsn2 = fluid.layers.softmax(self.betas_normal[0:2])
for i in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_reduce[start:end])
tn2 = fluid.layers.softmax(self.betas_normal[start:end])
start = end
n += 1
weightsr2 = fluid.layers.concat([weightsr2, tw2])
weightsn2 = fluid.layers.concat([weightsn2, tn2])
weightsr2 = weightsr2.numpy()
weightsn2 = weightsn2.numpy()

gene_normal = _parse(
fluid.layers.softmax(self.alphas_normal).numpy(), weightsn2)
gene_reduce = _parse(
fluid.layers.softmax(self.alphas_reduce).numpy(), weightsr2)

concat = range(2 + self._steps - self._multiplier, self._steps + 2)
genotype = Genotype(
normal=gene_normal,
normal_concat=concat,
reduce=gene_reduce,
reduce_concat=concat)
return genotype
78 changes: 78 additions & 0 deletions paddleslim/nas/darts/get_genotype.py
@@ -0,0 +1,78 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')


def get_genotype(model):
def _parse(weights, weights2=None):
gene = []
n = 2
start = 0
for i in range(model._steps):
end = start + n
W = weights[start:end].copy()
if model._method == "PC-DARTS":
W2 = weights2[start:end].copy()
for j in range(n):
W[j, :] = W[j, :] * W2[j]
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != model._primitives.index('none')))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if k != model._primitives.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((model._primitives[k_best], j))
start = end
n += 1
return gene

weightsr2 = None
weightsn2 = None
if model._method == "PC-DARTS":
n = 3
start = 2
weightsr2 = fluid.layers.softmax(model.betas_reduce[0:2])
weightsn2 = fluid.layers.softmax(model.betas_normal[0:2])
for i in range(model._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(model.betas_reduce[start:end])
tn2 = fluid.layers.softmax(model.betas_normal[start:end])
start = end
n += 1
weightsr2 = fluid.layers.concat([weightsr2, tw2])
weightsn2 = fluid.layers.concat([weightsn2, tn2])
weightsr2 = weightsr2.numpy()
weightsn2 = weightsn2.numpy()

gene_normal = _parse(
fluid.layers.softmax(model.alphas_normal).numpy(), weightsn2)
gene_reduce = _parse(
fluid.layers.softmax(model.alphas_reduce).numpy(), weightsr2)

concat = range(2 + model._steps - model._multiplier, model._steps + 2)
genotype = Genotype(
normal=gene_normal,
normal_concat=concat,
reduce=gene_reduce,
reduce_concat=concat)
return genotype
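
To see what _parse computes in the non-PC-DARTS case (weights2 left as None), here is a self-contained sketch of the same selection rule on toy data: for each intermediate node it keeps the two incoming edges whose strongest non-'none' op has the highest softmax weight, then records that op for each kept edge. The primitive list and random weights below are assumptions for illustration only; the real function reads model._primitives and the trained architecture parameters.

import numpy as np

primitives = ['none', 'max_pool_3x3', 'sep_conv_3x3', 'dil_conv_3x3']  # toy set
steps = 2                                     # toy cell with 2 intermediate nodes
num_edges = sum(2 + i for i in range(steps))  # node 0 has 2 inputs, node 1 has 3
rng = np.random.default_rng(0)
weights = rng.random((num_edges, len(primitives)))  # stand-in for softmax(alphas)

gene, n, start = [], 2, 0
none_idx = primitives.index('none')
for i in range(steps):
    end = start + n
    W = weights[start:end]
    # rank candidate input edges by their best non-'none' op and keep the top 2
    edges = sorted(range(i + 2),
                   key=lambda x: -max(W[x][k] for k in range(len(W[x]))
                                      if k != none_idx))[:2]
    for j in edges:
        # pick the strongest op on this edge, again ignoring 'none'
        k_best = max((k for k in range(len(W[j])) if k != none_idx),
                     key=lambda k: W[j][k])
        gene.append((primitives[k_best], j))
    start, n = end, n + 1

print(gene)  # four (primitive, input_index) pairs, two per intermediate node
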
3 changes: 2 additions & 1 deletion paddleslim/nas/darts/train_search.py
@@ -24,6 +24,7 @@
from paddle.fluid.dygraph.base import to_variable
from ...common import AvgrageMeter, get_logger
from .architect import Architect
from .get_genotype import get_genotype
logger = get_logger(__name__, level=logging.INFO)


@@ -201,7 +202,7 @@ def train(self):
for epoch in range(self.num_epochs):
logger.info('Epoch {}, lr {:.6f}'.format(
epoch, optimizer.current_step_lr()))
genotype = self.model.genotype()
genotype = get_genotype(self.model)
logger.info('genotype = %s', genotype)

train_top1 = self.train_one_epoch(train_loader, valid_loader,
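
The trainer above logs the genotype at the start of each epoch. A hedged sketch of how one might additionally persist it for later use; the helper name, file path, and JSON layout are assumptions, not part of this commit:

import json

def save_genotype(genotype, path="genotype.json"):
    # namedtuple -> plain dict; range/tuple fields become JSON lists
    record = {
        "normal": genotype.normal,
        "normal_concat": list(genotype.normal_concat),
        "reduce": genotype.reduce,
        "reduce_concat": list(genotype.reduce_concat),
    }
    with open(path, "w") as f:
        json.dump(record, f, indent=2)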
