
Commit

Refactor Nano Training to make easy upgrade (#4723)
* Refactor Nano Training to make easy upgrade

* add utils

* add graph module back

* add import back

* add copy back

* fix ipex multiprocessing

* fix trainer
yangw1234 authored May 30, 2022
1 parent ad0efc9 commit a4e334e
Showing 7 changed files with 78 additions and 64 deletions.
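For orientation, a minimal sketch of how the refactored interface is meant to be driven. The import path and keyword arguments below are inferred from the file layout and diffs in this commit, so treat them as assumptions; `MyModel` and `train_loader` stand in for a user's LightningModule and DataLoader.

from bigdl.nano.pytorch import Trainer  # path inferred from python/nano/src layout

trainer = Trainer(num_processes=2,              # number of worker processes
                  distributed_backend="spawn",  # "spawn", "subprocess" or "ray"
                  use_ipex=False,               # ipex is only wired up for torch < 1.10 here
                  enable_bf16=False)
# trainer.fit(MyModel(), train_loader)          # MyModel / train_loader: user-defined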
2 changes: 1 addition & 1 deletion .github/workflows/nano_inc_tests.yml
@@ -97,4 +97,4 @@ jobs:
source $CONDA/bin/deactivate
$CONDA/bin/conda remove -n inc-tf --all
env:
ANALYTICS_ZOO_ROOT: ${{ github.workspace }}
ANALYTICS_ZOO_ROOT: ${{ github.workspace }}
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ dependency-reduced-pom.xml
.worksheet
*.iml
.idea/
.vscode/

# macOS specific
.DS_Store
2 changes: 1 addition & 1 deletion python/nano/src/bigdl/nano/deps/ray/ray_api.py
@@ -26,5 +26,5 @@ def create_ray_envbase(world_size):


def distributed_ray(*args, **kwargs):
from ray_distributed import RayPlugin
from .ray_distributed import RayPlugin
return RayPlugin(*args, **kwargs)
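The one-character fix above makes the deferred import package-relative, so `ray_distributed` is resolved inside bigdl.nano.deps.ray rather than searched for on sys.path, and Ray stays an optional dependency. A generic sketch of the same lazy-factory pattern, with hypothetical module and class names:

# lazy_factory.py inside a package `mypkg` (names are hypothetical)
def make_plugin(*args, **kwargs):
    # Deferred, package-relative import: the heavy dependency is only pulled in
    # when this backend is actually requested, and the leading dot resolves
    # `heavy_backend` relative to `mypkg` instead of searching sys.path.
    from .heavy_backend import HeavyPlugin
    return HeavyPlugin(*args, **kwargs)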
30 changes: 15 additions & 15 deletions python/nano/src/bigdl/nano/pytorch/plugins/ddp_spawn.py
@@ -36,18 +36,23 @@
from typing import Any, List, Optional, Callable

import multiprocessing
from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.multiprocessing.spawn import _wrap, ProcessContext

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed

from bigdl.nano.common.cpu_schedule import schedule_workers
from bigdl.nano.deps.ipex.ipex_api import ipex_device
import logging

import warnings
log = logging.getLogger(__name__)


@@ -99,27 +104,22 @@ class DDPSpawnPlugin(pl.plugins.DDPSpawnPlugin):

def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[Callable] = None,
ddp_comm_wrapper: Optional[Callable] = None,
num_processes: int = 1,
cpu_for_each_process: Optional[List[List[int]]] = None,
**kwargs: Any,
use_ipex=False,
enable_bf16=False,
):
"""Create a DDPSpawnPlugin, adding a cpu_for_each_process parameter."""
device = ipex_device() if use_ipex and TORCH_VERSION_LESS_1_10 else 'cpu'
parallel_devices = [torch.device(device) for _ in range(num_processes)]
cluster_environment = LightningEnvironment()

super().__init__(parallel_devices,
num_nodes,
cluster_environment,
sync_batchnorm,
ddp_comm_state,
ddp_comm_hook,
ddp_comm_wrapper,
**kwargs)
cluster_environment=cluster_environment)
self.cpu_for_each_process = cpu_for_each_process
self.is_distributed = True
self.use_ipex = use_ipex
self.enable_bf16 = enable_bf16

@property
def mp_spawn_kwargs(self):
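The constructor above now derives the device list itself instead of receiving parallel_devices from the caller. A standalone restatement of that derivation (the helper name is hypothetical; ipex_device and TORCH_VERSION_LESS_1_10 are the imports shown in this hunk):

import torch
from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10
from bigdl.nano.deps.ipex.ipex_api import ipex_device

def derive_parallel_devices(num_processes, use_ipex):
    # ipex < 1.10 ships its own device; anything newer (or no ipex) trains on CPU
    device = ipex_device() if use_ipex and TORCH_VERSION_LESS_1_10 else "cpu"
    return [torch.device(device) for _ in range(num_processes)]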
63 changes: 26 additions & 37 deletions python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -21,20 +21,21 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.plugins.environments import LightningEnvironment
from torch import nn
from torch.fx.graph_module import GraphModule
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from torchmetrics.metric import Metric
from torch.optim.lr_scheduler import _LRScheduler
import yaml
from bigdl.nano.common import check_avx512
from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10
from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin
from bigdl.nano.pytorch.plugins.ddp_subprocess import DDPSubprocessPlugin

from bigdl.nano.deps.automl.hpo_api import create_hpo_searcher, check_hpo_status
from bigdl.nano.deps.ray.ray_api import distributed_ray
from bigdl.nano.deps.ipex.ipex_api import create_IPEXAccelerator, ipex_device
from bigdl.nano.deps.ipex.ipex_api import create_IPEXAccelerator
from bigdl.nano.deps.openvino.openvino_api import PytorchOpenVINOModel, load_openvino_model
from bigdl.nano.deps.onnxruntime.onnxruntime_api import bind_onnxrt_methods,\
PytorchONNXRuntimeModel, load_onnxruntime_model
@@ -92,56 +93,42 @@ def __init__(self, num_processes: int = 1,
else:
self.hposearcher = None

# Initialize trainer
if use_ipex and not check_avx512():
warning("Enable ipex in a cpu instruction set"
" without avx512 may cause some random error."
"Fall back to cpu device.")
use_ipex = False

accelerator = None
if num_processes == 1:
accelerator = None
if use_ipex:
accelerator = create_IPEXAccelerator(enable_bf16=enable_bf16)
if TORCH_VERSION_LESS_1_10:
accelerator = create_IPEXAccelerator(enable_bf16=enable_bf16)
else:
invalidInputError("We currently do not support ipex above 1.9.0")
super().__init__(accelerator=accelerator, *args, **kwargs)
else:
plugin = None
invalidInputError(distributed_backend in distributed_backends,
f"Distributed backends supported now are subprocess, spawn and ray,"
f"Distributed backends supported now are {distributed_backends},"
f" but get {distributed_backend}.")
if distributed_backend == "spawn":
if use_ipex:
device = ipex_device()
else:
device = "cpu"
plugin = DDPSpawnPlugin(parallel_devices=[
torch.device(device) for _ in range(num_processes)],
cpu_for_each_process=cpu_for_each_process,
cluster_environment=LightningEnvironment())
plugin = DDPSpawnPlugin(num_processes=num_processes,
cpu_for_each_process=cpu_for_each_process,
use_ipex=use_ipex,
enable_bf16=enable_bf16)
elif distributed_backend == "subprocess":
from bigdl.nano.pytorch.plugins.ddp_subprocess import DDPSubprocessPlugin
if use_ipex:
import intel_pytorch_extension as ipex
device = ipex.DEVICE
else:
device = "cpu"
plugin = DDPSubprocessPlugin(parallel_devices=[
torch.device(device) for _ in range(num_processes)],
cpu_for_each_process=cpu_for_each_process,
cluster_environment=LightningEnvironment())
plugin = DDPSubprocessPlugin(num_processes=num_processes,
cpu_for_each_process=cpu_for_each_process,
use_ipex=use_ipex,
enable_bf16=enable_bf16)
elif distributed_backend == "ray":
# Import RayPlugins may entangle with openmp even if it has not been used,
# which leads to an unacceptably low performance.
# So we import when we need.
plugin = distributed_ray(num_workers=num_processes, # type: ignore
use_ipex=use_ipex,
device=ipex_device())

accelerator = None
enable_bf16=enable_bf16)
if use_ipex:
accelerator = create_IPEXAccelerator(training_type_plugin=plugin, # type: ignore
enable_bf16=enable_bf16)

if TORCH_VERSION_LESS_1_10:
accelerator = create_IPEXAccelerator(training_type_plugin=plugin,
enable_bf16=enable_bf16)
else:
invalidInputError("We currently do not support ipex above 1.9.0")
super().__init__(accelerator=accelerator,
plugins=[plugin], *args, **kwargs)
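Condensed view of the backend dispatch that the hunk above settles on. The helper function is hypothetical and only restates the control flow; the plugin classes and distributed_ray are the imports shown at the top of this file.

from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin
from bigdl.nano.pytorch.plugins.ddp_subprocess import DDPSubprocessPlugin
from bigdl.nano.deps.ray.ray_api import distributed_ray

def pick_plugin(distributed_backend, num_processes, cpu_for_each_process,
                use_ipex, enable_bf16):
    if distributed_backend == "spawn":
        return DDPSpawnPlugin(num_processes=num_processes,
                              cpu_for_each_process=cpu_for_each_process,
                              use_ipex=use_ipex, enable_bf16=enable_bf16)
    if distributed_backend == "subprocess":
        return DDPSubprocessPlugin(num_processes=num_processes,
                                   cpu_for_each_process=cpu_for_each_process,
                                   use_ipex=use_ipex, enable_bf16=enable_bf16)
    # "ray" is imported lazily to avoid the OpenMP entanglement noted above
    return distributed_ray(num_workers=num_processes,
                           use_ipex=use_ipex, enable_bf16=enable_bf16)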

@@ -281,6 +268,8 @@ def quantize(self, model, # remove the type requirement for type checking
# check if dataloader is of legal format
check_pytorch_dataloaders(model, [calib_dataloader, val_dataloader])

model.eval()

if approach not in ['static', 'dynamic']:
invalidInputError(False,
"Approach should be 'static' or 'dynamic', "
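The added model.eval() call puts the network into inference mode before calibration, so dropout is disabled and batch-norm layers use their running statistics while activation ranges are collected. A generic illustration of the effect of eval mode on calibration (this is not the quantizer's code path; `model` and `calib_dataloader` are placeholders):

import torch

model.eval()                           # `model`: the torch.nn.Module to be quantized
with torch.no_grad():                  # calibration needs activations, not gradients
    for batch, _ in calib_dataloader:  # `calib_dataloader`: the calibration data
        model(batch)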
20 changes: 20 additions & 0 deletions python/nano/src/bigdl/nano/pytorch/utils.py
@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
from pytorch_lightning.utilities.imports import _compare_version

TORCH_VERSION_LESS_1_10 = _compare_version("torch", operator.lt, "1.10")
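The new flag centralizes the torch-version check that the Trainer and plugin changes above rely on. A short sketch of how it is consumed, mirroring the gating in Trainer.__init__ (the helper name is hypothetical; `use_ipex` and `enable_bf16` stand in for caller-supplied flags):

from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10
from bigdl.nano.deps.ipex.ipex_api import create_IPEXAccelerator

def maybe_ipex_accelerator(use_ipex, enable_bf16):
    # ipex acceleration is only wired up on torch < 1.10;
    # newer torch falls through to the default accelerator.
    if use_ipex and TORCH_VERSION_LESS_1_10:
        return create_IPEXAccelerator(enable_bf16=enable_bf16)
    return None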
24 changes: 14 additions & 10 deletions python/nano/test/pytorch/utils/_train_torch_lightning.py
@@ -90,7 +90,9 @@ def train_with_linear_top_layer(model_without_top, batch_size, num_workers, data

def train_torch_lightning(model, batch_size, num_workers, data_dir, accelerator=None,
use_orca_lite_trainer=False):
orig_parameters = deepcopy(list(model.named_parameters()))
orig_parameters = deepcopy(model.state_dict())
# list to store the right key of dict
orig_parameters_list = deepcopy(list(model.named_parameters()))

train_loader = create_data_loader(
data_dir, batch_size, num_workers, data_transform)
@@ -104,20 +106,22 @@ def train_torch_lightning(model, batch_size, num_workers, data_dir, accelerator=

trainer.fit(model, train_loader)

trained_parameters = list(model.named_parameters())
trained_parameters = model.state_dict()
trained_parameters_list = list(model.named_parameters())

# Check if the training and the freeze operation is successful
for i in range(len(orig_parameters)):
name1, para1 = orig_parameters[i]
name2, para2 = trained_parameters[i]
if name1 == "model.1.bias" or name1 == "model.1.weight" or \
name1 == "new_classifier.1.bias" or name1 == "new_classifier.1.weight":
for i in range(len(orig_parameters_list)):
name, para = orig_parameters_list[i]
para1 = orig_parameters[name]
para2 = trained_parameters[name]

if name == "model.1.bias" or name == "model.1.weight" or \
name == "new_classifier.1.bias" or name == "new_classifier.1.weight":
# Top layer is trained
if torch.all(torch.eq(para1, para2)):
raise Exception("Parameter " + name1 +
" remains the same after training.")
raise Exception("Parameter " + name + " remains the same after training.")
else:
# Frozen parameters should not change
if not torch.all(torch.eq(para1, para2)):
raise Exception(name1 + " freeze failed.")
raise Exception(name + " freeze failed.")
print("pass")
