Skip to content

Commit

Permalink
add neural_coder for neural_coder INC integration in INC 1.13 rls (in…
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikaiyao authored Jul 11, 2022
1 parent f5669f5 commit b3a409f
Show file tree
Hide file tree
Showing 37 changed files with 5,264 additions and 0 deletions.
78 changes: 78 additions & 0 deletions neural_coder/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
Neural Coder
===========================
## What do we offer?

Neural Coder is a novel deployment toolkit for one-click acceleration on Deep Learning scripts via performing automated code insertions of CUDA to CPU platform conversions and Deep Learning optimization APIs. Subsequently, Neural Coder can perform automated benchmark on all applicable optimization sets acquired from the automated enabling, and evaluate for the best out-of-box performance.

Neural Coder leverages static program analysis techniques and heuristic optimization rules to simplify the usage of various Deep Learning optimization APIs for increasing computation efficiency of AI models and improving user experience for general AI customers. We demonstrate great improvement of developer productivity and aim to facilitate enhanced Deep Learning acceleration adoption via this toolkit.

Neural Coder helps you code Deep Learning optimizations automatically into your scripts. For example, to apply
- Automatic Mixed Precision (torch.cpu.amp.autocast)
- JIT Script computation graph transformation (torch.jit.script)
- Channels Last memory format transformation (torch.channels_last)

simultaneously on below PyTorch evaluation code, we generate the optimized code in one-click by detecting the correct position to insert the correct API code lines:
```diff
import torch
import torchvision.models as models
my_model = models.resnet50(pretrained=True)
+ import torch
+ with torch.no_grad():
+ my_model = my_model.to(memory_format=torch.channels_last)
+ import torch
+ with torch.no_grad():
+ my_model.eval()
+ my_model = torch.jit.script(my_model)
+ my_model = torch.jit.freeze(my_model)
my_model.eval()
batch_size = 112
input = torch.rand(batch_size, 3, 224, 224)
with torch.no_grad():
+ import torch
+ with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
my_model(input)
```

## Getting Started!
We currently provide 3 user-facing APIs: enable, bench and superbench.
#### Enable
Users can use ```enable()``` to enable specific features into DL scripts:
```
from neural_coder import enable
enable(code="examples/vision/resnet50.py",
features=["pytorch_jit_script", "pytorch_channels_last"])
```
To run benchmark directly on the optimization together with the enabling:
```
from neural_coder import enable
enable(code="examples/vision/resnet50.py",
features=["pytorch_jit_script", "pytorch_channels_last"],
run_bench=True,
mode="throughput")
```
#### Bench
To run benchmark on your code with an existing patch:
```
from neural_coder import bench
bench(code="examples/vision/resnet50.py",
patch_path="${your_patch_path}",
mode="throughput")
```
#### SuperBench
To sweep on optimization sets with a fixed benchmark configuration:
```
from neural_coder import superbench
superbench(code="examples/vision/resnet50.py",
sweep_objective="feature",
mode="throughput")
```
To sweep on benchmark configurations for a fixed optimization set:
```
from neural_coder import superbench
superbench(code="examples/vision/resnet50.py",
sweep_objective="bench_config",
bench_feature=["pytorch_jit_script","pytorch_channels_last"])
```

## Contact
Please contact us at [[email protected]](mailto:[email protected]) for any Neural Coder related question.
17 changes: 17 additions & 0 deletions neural_coder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .interface import enable
from .interface import bench
from .interface import superbench
15 changes: 15 additions & 0 deletions neural_coder/coders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


15 changes: 15 additions & 0 deletions neural_coder/coders/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


136 changes: 136 additions & 0 deletions neural_coder/coders/pytorch/amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ... import globals
from ...utils.line_operation import get_line_indent_level, is_eval_func_model_name

import logging

logging.basicConfig(level=globals.logging_level,
format='%(asctime)s %(levelname)s %(message)s',
datefmt='%a, %d %b %Y %H:%M:%S +0000')
logger = logging.getLogger(__name__)


class PTAMP(object):
def __init__(self, mode):
self.mode = mode

# collect file transformation info and register (store) in globals
# (i.e. which file to add which lines at which location)
def register_transformation(self):
for file_path in globals.list_code_path:
code = open(file_path, 'r').read()
lines = code.split('\n')
line_idx = 0
for i in range(len(lines)):
line = lines[i]
for model_name in globals.list_model_name:
if is_eval_func_model_name(model_name, line) and "# Neural Coder appended" not in line:
indent_level = get_line_indent_level(line)

# 1. indenting
# indenting can have multiple location, so is a list of numbers
trans_indenting_location = []
trans_indenting_level = []

if ")" in line: # e.g. model(xxx)
trans_indenting_location.append(line_idx)
trans_indenting_level.append(1)
else: # e.g. model(xxx,
# xxx,
# xxx
# )
trans_indenting_location.append(line_idx)
trans_indenting_level.append(1)
do_search = True
i_search = 1
while do_search:
trans_indenting_location.append(
line_idx + i_search)
trans_indenting_level.append(1)
following_line = lines[line_idx + i_search]
if ")" in following_line:
do_search = False
i_search += 1

# 2. insert
trans_insert_location = line_idx # insert only has 1 location, so is a number

lines_to_insert = ""
if self.mode == "cpu":
lines_to_insert += " " * indent_level + "import torch" + "\n"
lines_to_insert += " " * indent_level + \
"with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):"
if self.mode == "cuda":
lines_to_insert += " " * indent_level + "import torch" + "\n"
lines_to_insert += " " * indent_level + \
"with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):"

# 1. indenting: transform "model(input)" to " model(input)"
if file_path not in globals.list_trans_indenting_modified_file:
globals.list_trans_indenting_modified_file.append(
file_path)
globals.list_trans_indenting_location_idxs.append(
trans_indenting_location)
globals.list_trans_indenting_level.append(
trans_indenting_level)
else:
idx = globals.list_trans_indenting_modified_file.index(
file_path)
for i in trans_indenting_location:
globals.list_trans_indenting_location_idxs[idx].append(
i)
for i in trans_indenting_level:
globals.list_trans_indenting_level[idx].append(
i)

# 2. insert: add "with autocast()" line
if file_path not in globals.list_trans_insert_modified_file:
globals.list_trans_insert_modified_file.append(
file_path)
globals.list_trans_insert_location_idxs.append(
[trans_insert_location])
globals.list_trans_insert_number_insert_lines.append(
[lines_to_insert.count("\n") + 1])
globals.list_trans_insert_lines_to_insert.append(
[lines_to_insert])
else:
idx = globals.list_trans_insert_modified_file.index(
file_path)
globals.list_trans_insert_location_idxs[idx].append(
trans_insert_location)
globals.list_trans_insert_number_insert_lines[idx].append(
lines_to_insert.count("\n") + 1)
globals.list_trans_insert_lines_to_insert[idx].append(
lines_to_insert)

line_idx += 1

logger.debug(
f"globals.list_trans_indenting_modified_file: {globals.list_trans_indenting_modified_file}")
logger.debug(
f"globals.list_trans_indenting_location_idxs: {globals.list_trans_indenting_location_idxs}")
logger.debug(
f"globals.list_trans_indenting_level: {globals.list_trans_indenting_level}")

logger.debug(
f"globals.list_trans_insert_modified_file: {globals.list_trans_insert_modified_file}")
logger.debug(
f"globals.list_trans_insert_location_idxs: {globals.list_trans_insert_location_idxs}")
logger.debug(
f"globals.list_trans_insert_number_insert_lines: {globals.list_trans_insert_number_insert_lines}")
logger.debug(
f"globals.list_trans_insert_lines_to_insert: {globals.list_trans_insert_lines_to_insert}")
87 changes: 87 additions & 0 deletions neural_coder/coders/pytorch/batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ... import globals


class BatchSizeCoder(object):
def __init__(self, file) -> None:
self.file = file
self.result = []

def transform(self):
lines = self.file.split('\n')
for line in lines:
if self.not_modify(line):
new_line = self.modify(line)
self.result.append(new_line)
else:
if line == '' and self.result[-1] == '':
continue
self.result.append(line)
for index, line in enumerate(self.result):
if index != len(self.result)-1:
self.result[index] += '\n'
return ''.join(self.result)

def not_modify(self, s):
if 'batch_size' in s and '=' in s:
return True
return False

def modify(self, s):
idx = s.find('batch_size')
s_right = s[idx:]
if ' = ' in s_right:
index = s.find(' = ')
s_left = s[:index]
if 'batch_size' in s_left:
if ',' in s_left:
index1 = s_left.find(',')
index2 = s_left.find('batch_size')
if index1 > index2:
slice1 = s_left[:index1]
else:
s_left1 = s_left[:index2]
s_right = s_left[index2:]
index3 = s_left1.rfind(',')
if ',' in s_right:
index4 = s_right.find(',') + len(s_left1)
slice1 = s_left[index3+2:index4]
else:
slice1 = s_left[index3+2:index]
s1 = slice1 + ' = ' + globals.target_batch_size
s = s[:] + '\n' + s1
else:
s_right = s[index+3:]
s_right = s_right.replace(
s_right, globals.target_batch_size)
s = s_left + ' = ' + s_right
elif 'batch_size=' in s:
idx = s.find('batch_size=')
s_right = s[idx:]
idx2 = s_right.find('batch_size')
if ',' in s_right:
index2 = s_right.find(',')
old = s_right[idx2:index2]
s = s.replace(old, "batch_size=" + globals.target_batch_size)
elif ')' in s_right:
index2 = s_right.find(')')
old = s_right[idx2:index2]
s = s.replace(old, "batch_size=" + globals.target_batch_size)
else:
old = s_right[idx2:]
s = s.replace(old, "batch_size=" + globals.target_batch_size)
return s
Loading

0 comments on commit b3a409f

Please sign in to comment.