Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add support for @nni.training_update in codegen #1564

Merged
merged 2 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions tools/nni_annotation/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ast
import astor


# pylint: disable=unidiomatic-typecheck

def parse_annotation_mutable_layers(code, lineno, nas_mode):
Expand Down Expand Up @@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
assert type(value) is ast.Num or type(
value) is ast.List, 'Value of optional_input_size should be a number or list'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest switching to a new line after ast.List,

optional_input_size = value
fields['optional_input_size'] = True
elif k.id == 'layer_output':
Expand Down Expand Up @@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes.append(node)
return nodes


def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
Expand Down Expand Up @@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
Expand Down Expand Up @@ -311,7 +314,6 @@ def visit(self, node):

return self._visit_children(node)


def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
Expand All @@ -325,7 +327,7 @@ def _visit_string(self, node):
call_node.args.insert(0, ast.Str(s=self.nas_mode))
return expr

if string.startswith('@nni.report_intermediate_result') \
if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \
or string.startswith('@nni.get_next_parameter'):
return parse_annotation(string[1:]) # expand annotation string to code
Expand All @@ -341,7 +343,6 @@ def _visit_string(self, node):

raise AssertionError('Unexpected annotation function')


def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
Expand Down
3 changes: 1 addition & 2 deletions tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def generate_mutable_layer_search_space(self, args):
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
}


def visit_Call(self, node): # pylint: disable=invalid-name
self.generic_visit(node)

Expand Down Expand Up @@ -108,7 +107,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
else:
# arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
'Smart parameter\'s arguments must be number literals'
'Smart parameter\'s arguments must be number literals'
args = [ast.literal_eval(astor.to_source(arg)) for arg in node.args]

key = self.module_name + '/' + name + '/' + func
Expand Down
24 changes: 16 additions & 8 deletions tools/nni_annotation/specific_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
para_cfg = None
prefix_name = None


def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
Expand Down Expand Up @@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes.append(node)
return nodes


def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
Expand Down Expand Up @@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
Expand Down Expand Up @@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
if len(node1) != len(node2):
return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))

return node1 == node2


Expand Down Expand Up @@ -294,7 +296,6 @@ def visit(self, node):

return self._visit_children(node)


def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
Expand All @@ -303,19 +304,27 @@ def _visit_string(self, node):
return node # not an annotation, ignore it

if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
"Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))

if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))

if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))

if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))

if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
Expand All @@ -327,7 +336,6 @@ def _visit_string(self, node):

raise AssertionError('Unexpected annotation function')


def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
Expand Down