diff --git a/examples/README.md b/examples/README.md
index a144627c5db..960279e7864 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -303,6 +303,32 @@ Intel® Neural Compressor validated examples with multiple compression technique
+## Pruning
+
+<table>
+  <thead>
+    <tr>
+      <th>Model</th>
+      <th>Domain</th>
+      <th>Approach</th>
+      <th>Examples</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>ResNet V2</td>
+      <td>Image Recognition</td>
+      <td>Structured (4x1, 2in4)</td>
+      <td>keras</td>
+    </tr>
+    <tr>
+      <td>ViT</td>
+      <td>Image Recognition</td>
+      <td>Structured (4x1, 2in4)</td>
+      <td>keras</td>
+    </tr>
+  </tbody>
+</table>
+
 ## Model Export
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md
new file mode 100644
index 00000000000..3e3f419ad33
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md
@@ -0,0 +1,38 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the Intel® Neural Compressor magnitude pruning feature on a ViT model.
+
+
+# Prerequisite
+
+## 1. Environment
+
+### Install Intel® Neural Compressor
+```shell
+pip install neural-compressor
+```
+### Install requirements
+```shell
+pip install -r requirements.txt
+```
+
+## 2. Prepare Model
+Run the script to train a baseline model and save it to the directory './ViT_Model'.
+```shell
+python prepare_model.py
+```
+
+# Run
+Run the command to prune the baseline model and save the pruned model to a given path.
+The CIFAR100 dataset will be automatically loaded.
+
+```shell
+python main.py --output_model=/path/to/output_model/
+```
+
+If you want to accelerate pruning with multi-node distributed training and evaluation, you only need to add two arguments and use horovod to run main.py. Run the command to get a pruned model with multi-node distributed training and evaluation.
+
+```shell
+horovodrun -np <num_of_processes> -H <hosts> python main.py --output_model=/path/to/output_model/ --train_distributed --evaluation_distributed
+```
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py
new file mode 100644
index 00000000000..d8607593969
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py
@@ -0,0 +1,147 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
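+#
+# This script applies Intel® Neural Compressor magnitude pruning to a Keras ViT
+# classifier trained on CIFAR100; see README.md in this directory for usage.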
+# +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +from neural_compressor.utils import logger +from neural_compressor.data import DataLoader +from neural_compressor.adaptor import FRAMEWORKS +from neural_compressor.conf.dotdict import DotDict +from neural_compressor.training import WeightPruningConfig +from neural_compressor.training import prepare_compression +from neural_compressor.utils import create_obj_from_config +from neural_compressor.conf.config import default_workspace + +flags = tf.compat.v1.flags +FLAGS = flags.FLAGS + +## Required parameters +flags.DEFINE_string( + 'output_model', None, 'The output pruned model.') + +flags.DEFINE_integer( + 'start_step', 0, 'The start step of pruning process.') + +flags.DEFINE_integer( + 'end_step', 9, 'The end step of pruning process.') + +flags.DEFINE_bool( + 'train_distributed', False, 'Whether to perform distributed training.') + +flags.DEFINE_bool( + 'evaluation_distributed', False, 'Whether to perform distributed evaluation.') + +# Prepare dataset +def prepare_dataset(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data() + y_train = tf.keras.utils.to_categorical(y_train, 100) + y_test = tf.keras.utils.to_categorical(y_test, 100) + logger.info(f"Training set: x_shape-{x_train.shape}, y_shape-{y_train.shape}") + logger.info(f"Test set: x_shape-{x_test.shape}, y_shape-{y_test.shape}") + return TrainDataset(x_train, y_train), EvalDataset(x_test, y_test) + +# Build TrainDataset and EvalDataset +class TrainDataset(object): + def __init__(self, x_train, y_train): + self.x_train = x_train + self.y_train = y_train + + def __len__(self): + return len(self.x_train) + + def __getitem__(self, idx): + return self.x_train[idx], self.y_train[idx] + +class EvalDataset(object): + def __init__(self, x_test, y_test): + self.x_test = x_test + self.y_test = y_test + + def __len__(self): + return len(self.x_test) + + def __getitem__(self, idx): + return self.x_test[idx], self.y_test[idx] + +def train(model, adaptor, compression_manager, train_dataloader): + train_cfg = { + 'epoch': 15, + 'start_epoch': 0, + 'execution_mode': 'eager', + 'criterion': {'CrossEntropyLoss': {'reduction': 'sum_over_batch_size', 'from_logits': True}}, + 'optimizer': {'AdamW': {'learning_rate': 1e-03, 'weight_decay': 1e-04}}, + } + train_cfg = DotDict(train_cfg) + train_func = create_obj_from_config.create_train_func('tensorflow', \ + train_dataloader, \ + adaptor, \ + train_cfg, \ + hooks=compression_manager.callbacks.callbacks_list[0].hooks, \ + callbacks=compression_manager.callbacks.callbacks_list[0]) + train_func(model) + +def evaluate(model, adaptor, eval_dataloader): + eval_cfg = {'accuracy': {'metric': {'topk': 1}, + 'iteration': -1, + 'multi_metrics': None} + } + eval_cfg = DotDict(eval_cfg) + eval_func = create_obj_from_config.create_eval_func('tensorflow', \ + eval_dataloader, \ + adaptor, \ + eval_cfg.accuracy.metric, \ + eval_cfg.accuracy.postprocess, \ + fp32_baseline = False) + return eval_func(model) + +if __name__ == '__main__': + training_set, test_set = prepare_dataset() + train_dataloader = DataLoader(dataset=training_set, batch_size=128, + framework='tensorflow', distributed=FLAGS.train_distributed) + eval_dataloader = DataLoader(dataset=test_set, batch_size=256, + framework='tensorflow', distributed=FLAGS.evaluation_distributed) + + framework_specific_info = { + 'device': 'cpu', 'random_seed': 9527, + 'workspace_path': default_workspace, + 'q_dataloader': None, 'format': 'default', + 'backend': 
'default', 'inputs': [], 'outputs': [] + } + adaptor = FRAMEWORKS['keras'](framework_specific_info) + + configs = WeightPruningConfig( + backend='itex', + pruning_type='magnitude', + target_sparsity=0.7, + start_step=FLAGS.start_step, + end_step=FLAGS.end_step, + pruning_op_types=['Conv', 'Dense'] + ) + compression_manager = prepare_compression(model='./ViT_Model', confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + train(model, adaptor, compression_manager, train_dataloader) + print("Pruned model score is ",evaluate(model, adaptor, eval_dataloader)) + + + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) + stats, sparsity = model.report_sparsity() + logger.info(stats) + logger.info(sparsity) \ No newline at end of file diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py new file mode 100644 index 00000000000..c89a2e00bf8 --- /dev/null +++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py @@ -0,0 +1,180 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +import tensorflow_addons as tfa + +num_classes = 100 +input_shape = (32, 32, 3) + +(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data() + +print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}") +print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}") + +learning_rate = 0.001 +weight_decay = 0.0001 +batch_size = 256 +num_epochs = 100 +image_size = 72 # We'll resize input images to this size +patch_size = 6 # Size of the patches to be extract from the input images +num_patches = (image_size // patch_size) ** 2 +projection_dim = 64 +num_heads = 4 +transformer_units = [ + projection_dim * 2, + projection_dim, +] # Size of the transformer layers +transformer_layers = 8 +mlp_head_units = [2048, 1024] # Size of the dense layers of the final classifier + +data_augmentation = keras.Sequential( + [ + layers.Normalization(), + layers.Resizing(image_size, image_size), + layers.RandomFlip("horizontal"), + layers.RandomRotation(factor=0.02), + layers.RandomZoom( + height_factor=0.2, width_factor=0.2 + ), + ], + name="data_augmentation", +) +# Compute the mean and the variance of the training data for normalization. 
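+# adapt() must run before training so the Normalization layer (layers[0]) learns the
+# training set's mean and variance, which it then uses to standardize inputs inside the model.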
+data_augmentation.layers[0].adapt(x_train) + + +def mlp(x, hidden_units, dropout_rate): + for units in hidden_units: + x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dropout(dropout_rate)(x) + return x + +class Patches(layers.Layer): + def __init__(self, patch_size): + super().__init__() + self.patch_size = patch_size + + def call(self, images): + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + +class PatchEncoder(layers.Layer): + def __init__(self, num_patches, projection_dim): + super().__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + + +def create_vit_classifier(): + inputs = layers.Input(shape=input_shape) + # Augment data. + augmented = data_augmentation(inputs) + # Create patches. + patches = Patches(patch_size)(augmented) + # Encode patches. + encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + # Create multiple layers of the Transformer block. + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=projection_dim, dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) + # Skip connection 2. + encoded_patches = layers.Add()([x3, x2]) + + # Create a [batch_size, projection_dim] tensor. + representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + representation = layers.Flatten()(representation) + representation = layers.Dropout(0.5)(representation) + # Add MLP. + features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5) + # Classify outputs. + logits = layers.Dense(num_classes)(features) + # Create the Keras model. 
+    model = keras.Model(inputs=inputs, outputs=logits)
+    return model
+
+def run_experiment(model):
+    optimizer = tfa.optimizers.AdamW(
+        learning_rate=learning_rate, weight_decay=weight_decay
+    )
+
+    model.compile(
+        optimizer=optimizer,
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=[
+            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+        ],
+    )
+
+    checkpoint_filepath = "/tmp/checkpoint"
+    checkpoint_callback = keras.callbacks.ModelCheckpoint(
+        checkpoint_filepath,
+        monitor="val_accuracy",
+        save_best_only=True,
+        save_weights_only=True,
+    )
+
+    history = model.fit(
+        x=x_train,
+        y=y_train,
+        batch_size=batch_size,
+        epochs=num_epochs,
+        validation_split=0.1,
+        callbacks=[checkpoint_callback],
+    )
+
+    model.load_weights(checkpoint_filepath)
+    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+    model.save("./ViT_Model")  # Save the trained baseline model for pruning.
+    return history
+
+vit_classifier = create_vit_classifier()
+history = run_experiment(vit_classifier)
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt
new file mode 100644
index 00000000000..e904ddad36e
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt
@@ -0,0 +1,5 @@
+tensorflow
+keras
+tensorflow-estimator
+tensorflow-addons
+horovod
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md
new file mode 100644
index 00000000000..a99c087ef23
--- /dev/null
+++ b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md
@@ -0,0 +1,35 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the Intel® Neural Compressor magnitude pruning feature on a ResNet V2 model.
+
+
+# Prerequisite
+
+## 1. Environment
+
+### Install Intel® Neural Compressor
+```shell
+pip install neural-compressor
+```
+### Install TensorFlow
+```shell
+pip install tensorflow
+```
+
+# Run
+Run the command to train a baseline model, which will be saved to './baseline_model'. The model is then pruned and saved to a given path.
+The CIFAR10 dataset will be automatically loaded.
+```shell
+python main.py --output_model=/path/to/output_model/ --prune
+```
+If you want to accelerate pruning with multi-node distributed training and evaluation, you only need to add two arguments and use horovod to run main.py.
+Use horovod to run main.py to get a pruned model with multi-node distributed training and evaluation.
+```shell
+horovodrun -np <num_of_processes> -H <hosts> python main.py --output_model=/path/to/output_model/ --train_distributed --evaluation_distributed --prune
+```
+
+Run the command to measure the pruned model's performance.
+```shell
+python main.py --input_model=/path/to/input_model/ --benchmark
+```
diff --git a/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py
new file mode 100644
index 00000000000..450ba18aca7
--- /dev/null
+++ b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py
@@ -0,0 +1,453 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import print_function
+import yaml
+import numpy as np
+import tensorflow
+
+from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
+from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
+from tensorflow.keras.callbacks import LearningRateScheduler
+from tensorflow.keras.callbacks import ReduceLROnPlateau
+from tensorflow.keras.regularizers import l2
+from tensorflow.keras.models import Model
+from tensorflow.keras.datasets import cifar10
+
+
+from neural_compressor.utils import logger
+from neural_compressor.data import DataLoader
+from neural_compressor.adaptor import FRAMEWORKS
+from neural_compressor.conf.dotdict import DotDict
+from neural_compressor.training import WeightPruningConfig
+from neural_compressor.training import prepare_compression
+from neural_compressor.utils import create_obj_from_config
+from neural_compressor.conf.config import default_workspace
+
+flags = tensorflow.compat.v1.flags
+FLAGS = flags.FLAGS
+
+## Required parameters
+flags.DEFINE_bool(
+    'prune', False, 'Whether to prune the baseline model.')
+
+flags.DEFINE_bool(
+    'benchmark', False, 'Whether to benchmark the pruned model.')
+
+flags.DEFINE_string(
+    'input_model', None, 'Run inference with specified model.')
+
+flags.DEFINE_string(
+    'output_model', None, 'The output pruned model.')
+
+flags.DEFINE_integer(
+    'start_step', 0, 'The start step of pruning process.')
+
+flags.DEFINE_integer(
+    'end_step', 7, 'The end step of pruning process.')
+
+flags.DEFINE_integer(
+    'iters', 100, 'The number of iterations for evaluating performance.')
+
+flags.DEFINE_bool(
+    'train_distributed', False, 'Whether to perform distributed training.')
+
+flags.DEFINE_bool(
+    'evaluation_distributed', False, 'Whether to perform distributed evaluation.')
+
+
+def lr_schedule(epoch):
+    """Learning Rate Schedule
+
+    Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
+    Called automatically every epoch as part of callbacks during training.
+ + # Arguments + epoch (int): The number of epochs + + # Returns + lr (float32): learning rate + """ + lr = 1e-3 + if epoch > 180: + lr *= 0.5e-3 + elif epoch > 160: + lr *= 1e-3 + elif epoch > 120: + lr *= 1e-2 + elif epoch > 80: + lr *= 1e-1 + print('Learning rate: ', lr) + return lr + + +def resnet_layer(inputs, + num_filters=16, + kernel_size=3, + strides=1, + activation='relu', + batch_normalization=True, + conv_first=True): + """2D Convolution-Batch Normalization-Activation stack builder + + # Arguments + inputs (tensor): input tensor from input image or previous layer + num_filters (int): Conv2D number of filters + kernel_size (int): Conv2D square kernel dimensions + strides (int): Conv2D square stride dimensions + activation (string): activation name + batch_normalization (bool): whether to include batch normalization + conv_first (bool): conv-bn-activation (True) or + bn-activation-conv (False) + + # Returns + x (tensor): tensor as input to the next layer + """ + conv = Conv2D(num_filters, + kernel_size=kernel_size, + strides=strides, + padding='same', + use_bias=True, + kernel_initializer='he_normal', + kernel_regularizer=l2(1e-4)) + + x = inputs + if conv_first: + x = conv(x) + if batch_normalization: + x = BatchNormalization()(x) + if activation is not None: + x = Activation(activation)(x) + else: + if batch_normalization: + x = BatchNormalization()(x) + if activation is not None: + x = Activation(activation)(x) + x = conv(x) + return x + +def resnet_v2(input_shape, depth, num_classes=10): + """ResNet Version 2 Model builder [b] + + Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as + bottleneck layer + First shortcut connection per layer is 1 x 1 Conv2D. + Second and onwards shortcut connection is identity. + At the beginning of each stage, the feature map size is halved (downsampled) + by a convolutional layer with strides=2, while the number of filter maps is + doubled. Within each stage, the layers have the same number filters and the + same filter map sizes. + Features maps sizes: + conv1 : 32x32, 16 + stage 0: 32x32, 64 + stage 1: 16x16, 128 + stage 2: 8x8, 256 + + # Arguments + input_shape (tensor): shape of input image tensor + depth (int): number of core convolutional layers + num_classes (int): number of classes (CIFAR10 has 10) + + # Returns + model (Model): Keras model instance + """ + if (depth - 2) % 9 != 0: + raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])') + # Start model definition. 
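+    # With depth = 9n + 2 (validated above), each of the 3 stages stacks n bottleneck residual blocks.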
+ num_filters_in = 16 + num_res_blocks = int((depth - 2) / 9) + + inputs = Input(shape=input_shape) + # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths + x = resnet_layer(inputs=inputs, + num_filters=num_filters_in, + conv_first=True) + + # Instantiate the stack of residual units + for stage in range(3): + for res_block in range(num_res_blocks): + activation = 'relu' + batch_normalization = True + strides = 1 + if stage == 0: + num_filters_out = num_filters_in * 4 + if res_block == 0: # first layer and first stage + activation = None + batch_normalization = False + else: + num_filters_out = num_filters_in * 2 + if res_block == 0: # first layer but not first stage + strides = 2 # downsample + + # bottleneck residual unit + y = resnet_layer(inputs=x, + num_filters=num_filters_in, + kernel_size=1, + strides=strides, + activation=activation, + batch_normalization=batch_normalization, + conv_first=False) + y = resnet_layer(inputs=y, + num_filters=num_filters_in, + conv_first=False) + + y = resnet_layer(inputs=y, + num_filters=num_filters_out, + kernel_size=1, + conv_first=False) + if res_block == 0: + # linear projection residual shortcut connection to match + # changed dims + x = resnet_layer(inputs=x, + num_filters=num_filters_out, + kernel_size=1, + strides=strides, + activation=None, + batch_normalization=False) + x = tensorflow.keras.layers.add([x, y]) + + num_filters_in = num_filters_out + + # Add classifier on top. + # v2 has BN-ReLU before Pooling + x = BatchNormalization()(x) + x = Activation('relu')(x) + x = AveragePooling2D(pool_size=8)(x) + y = Flatten()(x) + outputs = Dense(num_classes, + activation='softmax', + kernel_initializer='he_normal')(y) + + # Instantiate model. + model = Model(inputs=inputs, outputs=outputs) + return model + +# Training parameters +batch_size = 32 # orig paper trained all networks with batch_size=128 +epochs = 2 +num_classes = 10 + +# Subtracting pixel mean improves accuracy +subtract_pixel_mean = True + +# Model parameter +# ---------------------------------------------------------------------------- +# | | 200-epoch | Orig Paper| 200-epoch | Orig Paper| sec/epoch +# Model | n | ResNet v1 | ResNet v1 | ResNet v2 | ResNet v2 | GTX1080Ti +# |v1(v2)| %Accuracy | %Accuracy | %Accuracy | %Accuracy | v1 (v2) +# ---------------------------------------------------------------------------- +# ResNet20 | 3 (2)| 92.16 | 91.25 | ----- | ----- | 35 (---) +# ResNet32 | 5(NA)| 92.46 | 92.49 | NA | NA | 50 ( NA) +# ResNet44 | 7(NA)| 92.50 | 92.83 | NA | NA | 70 ( NA) +# ResNet56 | 9 (6)| 92.71 | 93.03 | 93.01 | NA | 90 (100) +# ResNet110 |18(12)| 92.65 | 93.39+-.16| 93.15 | 93.63 | 165(180) +# ResNet164 |27(18)| ----- | 94.07 | ----- | 94.54 | ---(---) +# ResNet1001| (111)| ----- | 92.39 | ----- | 95.08+-.14| ---(---) +# --------------------------------------------------------------------------- +n = 3 +depth = n * 9 + 2 + +def generate_pretrained_model(): + # Load the CIFAR10 data. + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + + # Input image dimensions. + input_shape = x_train.shape[1:] + # Normalize data. + x_train = x_train.astype('float32') / 255 + x_test = x_test.astype('float32') / 255 + + x_train_mean = np.mean(x_train, axis=0) + x_train -= x_train_mean + x_test -= x_train_mean + + print('x_train shape:', x_train.shape) + print(x_train.shape[0], 'train samples') + print(x_test.shape[0], 'test samples') + print('y_train shape:', y_train.shape) + + # Convert class vectors to binary class matrices. 
+ y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes) + y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes) + + model = resnet_v2(input_shape=input_shape, depth=depth) + + model.compile(loss='categorical_crossentropy', + optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.01), + metrics=['accuracy']) + model.summary() + + lr_scheduler = LearningRateScheduler(lr_schedule) + + lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1), + cooldown=0, + patience=5, + min_lr=0.5e-6) + + callbacks = [lr_reducer, lr_scheduler] + + # Run training, with or without data augmentation. + model.fit(x_train, y_train, + batch_size=batch_size, + epochs=epochs, + validation_data=(x_test, y_test), + shuffle=True, + callbacks=callbacks) + + + # Score trained model. + scores = model.evaluate(x_test, y_test, verbose=1) + print('Test loss:', scores[0]) + print('Test accuracy:', scores[1]) + model.save("baseline_model") + +class EvalDataset(object): + def __init__(self, batch_size=100): + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + + x_train = x_train.astype('float32') / 255 + x_test = x_test.astype('float32') / 255 + + # If subtract pixel mean is enabled + x_train_mean = np.mean(x_train, axis=0) + x_train -= x_train_mean + x_test -= x_train_mean + + print('x_train shape:', x_train.shape) + print(x_train.shape[0], 'train samples') + print(x_test.shape[0], 'test samples') + print('y_train shape:', y_train.shape) + + # Convert class vectors to binary class matrices. + y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes) + y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes) + self.test_images = x_test + self.test_labels = y_test + + def __len__(self): + return len(self.test_images) + + def __getitem__(self, idx): + return self.test_images[idx], self.test_labels[idx] + +class TrainDataset(object): + def __init__(self, batch_size=100): + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + + x_train = x_train.astype('float32') / 255 + x_test = x_test.astype('float32') / 255 + + # If subtract pixel mean is enabled + x_train_mean = np.mean(x_train, axis=0) + x_train -= x_train_mean + x_test -= x_train_mean + + print('x_train shape:', x_train.shape) + print(x_train.shape[0], 'train samples') + print(x_test.shape[0], 'test samples') + print('y_train shape:', y_train.shape) + + # Convert class vectors to binary class matrices. 
+ y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes) + y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes) + self.test_images = x_test + self.test_labels = y_test + self.train_images = x_train + self.train_labels = y_train + + def __len__(self): + return len(self.train_images) + + def __getitem__(self, idx): + return self.train_images[idx], self.train_labels[idx] + +def train(model, adaptor, compression_manager, train_dataloader): + train_cfg = { + 'epoch': 8, + 'start_epoch': 0, + 'execution_mode': 'eager', + 'criterion': {'CrossEntropyLoss': {'reduction': 'sum_over_batch_size'}}, + 'optimizer': {'SGD': {'learning_rate': 1e-03, 'momentum': 0.9, 'nesterov': True}}, + } + train_cfg = DotDict(train_cfg) + train_func = create_obj_from_config.create_train_func( + 'tensorflow', \ + train_dataloader, \ + adaptor, \ + train_cfg, \ + hooks=compression_manager.callbacks.callbacks_list[0].hooks, \ + callbacks=compression_manager.callbacks.callbacks_list[0]) + train_func(model) + + +def evaluate(model, adaptor, eval_dataloader): + eval_cfg = {'accuracy': {'metric': {'topk': 1}, + 'iteration': -1, + 'multi_metrics': None} + } + eval_cfg = DotDict(eval_cfg) + eval_func = create_obj_from_config.create_eval_func('tensorflow', \ + eval_dataloader, \ + adaptor, \ + eval_cfg.accuracy.metric, \ + eval_cfg.accuracy.postprocess, \ + fp32_baseline = False) + return eval_func(model) + +if __name__ == '__main__': + train_dataloader = DataLoader(dataset=TrainDataset(), batch_size=32, + framework='tensorflow', distributed=FLAGS.train_distributed) + eval_dataloader = DataLoader(dataset=EvalDataset(), batch_size=32, + framework='tensorflow', distributed=FLAGS.evaluation_distributed) + + if FLAGS.prune: + generate_pretrained_model() + framework_specific_info = { + 'device': 'cpu', 'random_seed': 9527, + 'workspace_path': default_workspace, + 'q_dataloader': None, 'format': 'default', + 'backend': 'default', 'inputs': [], 'outputs': [] + } + adaptor = FRAMEWORKS['keras'](framework_specific_info) + + configs = WeightPruningConfig( + backend='itex', + pruning_type='magnitude', + pattern='3x1', + target_sparsity=0.25, + start_step=FLAGS.start_step, + end_step=FLAGS.end_step, + pruning_op_types=['Conv', 'Dense'] + ) + compression_manager = prepare_compression(model='./baseline_model', confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + train(model, adaptor, compression_manager, train_dataloader) + print("Pruned model score is ",evaluate(model, adaptor, eval_dataloader)) + + compression_manager.callbacks.on_train_end() + compression_manager.save(FLAGS.output_model) + stats, sparsity = model.report_sparsity() + logger.info(stats) + logger.info(sparsity) + + if FLAGS.benchmark: + from neural_compressor.benchmark import fit + from neural_compressor.config import BenchmarkConfig + conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=1, iteration=FLAGS.iters) + fit(FLAGS.input_model, conf, b_dataloader=eval_dataloader) \ No newline at end of file diff --git a/neural_compressor/adaptor/keras.py b/neural_compressor/adaptor/keras.py index d106936c3f8..8924de8bb2e 100644 --- a/neural_compressor/adaptor/keras.py +++ b/neural_compressor/adaptor/keras.py @@ -748,6 +748,129 @@ def convert(self, model, source, destinatin): ''' pass + def _pre_hook_for_hvd(self, dataloader=None): + """Pre hook for Horovod.""" + import horovod.tensorflow as hvd + self.hvd = hvd + self.hvd.init() + + @dump_elapsed_time(customized_msg="Model training") + def 
train(self, model, dataloader, optimizer_tuple, + criterion_tuple, hooks, postprocess, **kwargs): + """Model training API. + + Args: + model ([Graph, GraphDef or Path String]): The model could be the graph, + graph_def object, the frozen pb or ckpt/savedmodel folder path. + dataloader (generator): generate the data and labels. + optimizer_tuple (tuple): optimizers for model training. + criterion_tuple (tuple): criterions for model training. + hooks (callback): on_epoch_begin hook on_epoch_end hook. + postprocess (object): process the result from the model. + + Returns: + None. + """ + # check model is savedmodel or not + import tensorflow as tf + from neural_compressor.model.tensorflow_model import get_model_type + tf.random.set_seed(1) + self.model_type = get_model_type(model._model) + optimizer = optimizer_tuple[0](**optimizer_tuple[1]) + criterion = criterion_tuple[0](**criterion_tuple[1]) + start_epochs = kwargs['kwargs'].get('start_epoch', None) + end_epochs = kwargs['kwargs'].get('end_epoch', None) + epochs = kwargs['kwargs'].get('epoch', None) + iters = kwargs['kwargs'].get('iteration', None) + callbacks = kwargs['kwargs'].get('callbacks', None) + execution_mode = kwargs['kwargs'].get('execution_mode', None) + distributed = getattr(dataloader, 'distributed', False) + + if isinstance(model._model, tf.keras.Model): + input_model = model._model + else: + input_model = tf.keras.models.load_model(model._model) + # hooks = callbacks['tf_pruning'](model, input_model, hooks) + hooks['on_train_begin']() # on_train_begin hook + train_loss_results = [] + if distributed: + try: + len_dataloader = len(dataloader) + except: + logger.info("The length of the distributed training dataloader is unknown." + "When the iteration of training dataloader in each process is " + "inconsistent, an error may occur.") + else: + list_len_dataloader = self.hvd.allgather_object(len_dataloader) + if self.hvd.rank() == 0: + for i in range(len(list_len_dataloader)-1): + if list_len_dataloader[i] != list_len_dataloader[i+1]: + raise AttributeError("The traning dataloader's iteration is" + "different between processes, please reset dataloader's batch_size.") + + def training_step(x, y, first_batch): + with tf.GradientTape() as tape: + tape.watch(input_model.trainable_variables) + y_ = input_model(x, training=True) + loss_value = criterion(y, y_) + + tape = self.hvd.DistributedGradientTape(tape) if distributed else tape + # Get gradient + grads = tape.gradient(loss_value, input_model.trainable_variables) # pylint: disable=no-member + # Optimize the model + optimizer.apply_gradients(zip(grads, input_model.trainable_variables)) # pylint: disable=no-member + if distributed and first_batch: + self.hvd.broadcast_variables(input_model.variables, root_rank=0) + self.hvd.broadcast_variables(optimizer.variables(), root_rank=0) + return loss_value + + training_step = training_step if execution_mode=='eager' else tf.function(training_step) + if start_epochs is not None and end_epochs is not None: + epochs = end_epochs - start_epochs + + for epoch in range(epochs): + cnt = 0 + epoch_loss_avg = tf.keras.metrics.Mean() + # Training loop + for iter, data in enumerate(dataloader): + x, y = postprocess(data) if postprocess is not None else data + hooks['on_step_begin'](iter) # on_step_begin hook + cnt += 1 + loss_value = training_step(x, y, iter==0) + # Track progress + epoch_loss_avg.update_state(loss_value) # Add current batch loss + hooks['on_before_optimizer_step']() + hooks['on_after_optimizer_step']() + if iters is not None and 
cnt >= iters: + break + model._sess = None + # End epoch + train_loss_results.append(epoch_loss_avg.result()) + if distributed: + logger.info("Epoch-{:03d} training on rank {!s} have been done." \ + .format(epoch+1, self.hvd.allgather_object(self.hvd.rank()))) + logger.info("Epoch {:03d}: Loss: {:.3f}".format(epoch+1, epoch_loss_avg.result())) + + hooks['on_train_end']() # on_train_end hook + model._sess = None + + if distributed: + if self.hvd.rank() == 0: + # Update the input model with pruned weights manually due to keras API limitation. + if isinstance(model._model, tf.keras.Model): + model._model = input_model + else: + input_model.save(model._model) + rank_list = self.hvd.allgather_object(self.hvd.rank()) + logger.info(f"rank 0 has saved the pruned model to '{model._model}'," + f"all ranks {rank_list} ready.") + else: + if isinstance(model._model, tf.keras.Model): + model._model = input_model + else: + input_model.save(model._model) + + class KerasQuery(QueryBackendCapability): def __init__(self, local_config_file=None): super().__init__() diff --git a/neural_compressor/compression/callbacks.py b/neural_compressor/compression/callbacks.py index 5c6b44fa0c4..94398b1c97f 100644 --- a/neural_compressor/compression/callbacks.py +++ b/neural_compressor/compression/callbacks.py @@ -27,13 +27,12 @@ from ..model import BaseModel, Model from ..model.model import MODELS from .pruner.utils import process_config, parse_to_prune, get_sparsity_ratio +from .pruner.utils import parse_to_prune_tf, get_sparsity_ratio_tf from .pruner.pruners import get_pruner, PRUNERS -# model auto slim related -from .pruner.model_slim.pattern_analyzer import SelfMHASearcher - LazyImport('torch.nn') torch = LazyImport('torch') +tf = LazyImport('tensorflow') class BaseCallbacks(object): """This is base class of Neural Compressor Callbacks. 
@@ -225,8 +224,10 @@ def on_train_end(self): """Be called after the end of training.""" for on_train_end_hook in self.hooks_dict['on_train_end']: on_train_end_hook() - if isinstance(self.model.model, torch.nn.Module): + if self.conf.framework == 'pytorch' and isinstance(self.model.model, torch.nn.Module): get_sparsity_ratio(self.pruners, self.model) + elif self.conf.framework == 'keras' and isinstance(self.model.model, tf.keras.Model): + get_sparsity_ratio_tf(self.pruners, self.model) def __repr__(self): """Return the class's string representation.""" @@ -241,7 +242,9 @@ def generate_hooks(self): def _generate_pruners(self): """Obtain Pruner objects.""" - if isinstance(self.model.model, torch.nn.Module): + if self.conf.framework == 'pytorch' and isinstance(self.model.model, torch.nn.Module): + # model auto slim related + from .pruner.model_slim.pattern_analyzer import SelfMHASearcher for info in self.pruners_info: if 'mha' in info['pattern']: # head pruning @@ -262,6 +265,19 @@ def _generate_pruners(self): info['modules'] = [key for key in modules.keys()] info['len_of_modules'] = len(info['modules']) logger.info(info) + elif self.conf.framework == 'keras' and isinstance(self.model.model, tf.keras.Model): + from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() + for info in self.pruners_info: + # original pruning types, e.g NxM or N:M + modules = parse_to_prune_tf(info, self.model.model) + if modules == {}: + logger.warning("one pruner hooks no layers, please have a check") + + self.pruners.append(get_pruner(info, modules, 'keras')) + info['modules'] = [key for key in modules.keys()] + info['len_of_modules'] = len(info['modules']) + logger.info(info) else: assert False, 'now only support {}'.format(PRUNERS.keys()) diff --git a/neural_compressor/compression/pruner/criteria.py b/neural_compressor/compression/pruner/criteria.py index 4b69fe8bd38..25cfcee0693 100644 --- a/neural_compressor/compression/pruner/criteria.py +++ b/neural_compressor/compression/pruner/criteria.py @@ -15,7 +15,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .utils import torch + +import numpy as np +from ...utils.utility import LazyImport +torch = LazyImport('torch') +tf = LazyImport('tensorflow') CRITERIA = {} @@ -31,12 +35,12 @@ def register(criterion): return register -def get_criterion(config, modules): +def get_criterion(config, modules, framework='pytorch'): """Get registered criterion class.""" name = config["criterion_type"] if name not in CRITERIA.keys(): assert False, f"criteria does not support {name}, currently only support {CRITERIA.keys()}" - return CRITERIA[name](modules, config) + return CRITERIA[name](modules, config, framework) class PruningCriterion: @@ -50,11 +54,12 @@ class PruningCriterion: scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. """ - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a pruning criterion.""" self.scores = {} self.modules = modules self.config = config + self.framework=framework def on_step_begin(self): """Calculate and store the pruning scores of pruning modules at the beginning of a step.""" @@ -84,17 +89,21 @@ class MagnitudeCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. 
""" - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a magnitude pruning criterion.""" - super(MagnitudeCriterion, self).__init__(modules, config) + super(MagnitudeCriterion, self).__init__(modules, config, framework) def on_step_begin(self): """Calculate and store the pruning scores based on a magnitude criterion.""" - with torch.no_grad(): + if self.framework == 'pytorch': + with torch.no_grad(): + for key in self.modules.keys(): + p = self.modules[key].weight.data + self.scores[key] = torch.abs(p) + elif self.framework == 'keras': for key in self.modules.keys(): - p = self.modules[key].weight.data - self.scores[key] = torch.abs(p) - + p = self.modules[key].get_weights()[0] + self.scores[key] = np.abs(p) @register_criterion('gradient') class GradientCriterion(PruningCriterion): @@ -111,13 +120,14 @@ class GradientCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. """ - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a gradient pruning criterion.""" - super(GradientCriterion, self).__init__(modules, config) + super(GradientCriterion, self).__init__(modules, config, framework) assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" def on_before_optimizer_step(self): """Calculate and store the pruning scores based on gradient criterion.""" + assert self.framework != 'keras', "This pruning criterion is not supported by Keras now." with torch.no_grad(): for key in self.modules.keys(): p = self.modules[key].weight @@ -141,18 +151,20 @@ class SnipCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. """ - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a snip pruning criterion.""" - super(SnipCriterion, self).__init__(modules, config) + super(SnipCriterion, self).__init__(modules, config, framework) assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" def on_before_optimizer_step(self): """Calculate and store the pruning scores based on snip criterion.""" ##self.mask_weights() + assert self.framework != 'keras', "This pruning criterion is not supported by Keras now." with torch.no_grad(): for key in self.modules.keys(): p = self.modules[key].weight self.scores[key] = torch.abs(p * p.grad) + @register_criterion('snip_momentum') @@ -172,9 +184,10 @@ class SnipMomentumCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. """ - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a snip_momentum pruning criterion.""" - super(SnipMomentumCriterion, self).__init__(modules, config) + super(SnipMomentumCriterion, self).__init__(modules, config, framework) + assert self.framework != 'keras', "This pruning criterion is not supported by Keras now." assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" for key in modules.keys(): p = modules[key].weight @@ -209,9 +222,10 @@ class SnipMomentumBlockCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. 
""" - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a block_mask pruning criterion.""" - super(SnipMomentumBlockCriterion, self).__init__(modules, config) + super(SnipMomentumBlockCriterion, self).__init__(modules, config, framework) + assert self.framework != 'keras', "This pruning criterion is not supported by Keras now." assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" for key in self.modules.keys(): if not hasattr(self.modules[key], 'block_mask'): @@ -248,9 +262,10 @@ class RetrainFreeCriterion(PruningCriterion): scores: A dict {"module_name": Tensor} that stores the scores of pruning modules. """ - def __init__(self, modules, config): + def __init__(self, modules, config, framework='pytorch'): """Initiliaze a block_mask pruning criterion.""" - super(RetrainFreeCriterion, self).__init__(modules, config) + super(RetrainFreeCriterion, self).__init__(modules, config, framework) + assert self.framework != 'keras', "This pruning criterion is not supported by Keras now." assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion" self.collected_grads = {} for key in self.modules.keys(): diff --git a/neural_compressor/compression/pruner/model_slim/auto_slim.py b/neural_compressor/compression/pruner/model_slim/auto_slim.py index f85487250f6..0e4b6ccd09f 100644 --- a/neural_compressor/compression/pruner/model_slim/auto_slim.py +++ b/neural_compressor/compression/pruner/model_slim/auto_slim.py @@ -16,8 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # model slim related -from .pattern_analyzer import Linear2LinearSearcher, RecipeSearcher, SelfMHASearcher -from .weight_slim import LinearCompressionIterator, MHACompression + from ..utils import logger def model_slim(model, dataloader=None, round_multiplier=32): @@ -39,6 +38,8 @@ def model_slim_ffn2(model, dataloader = None, round_multiplier=32): model: a sprase model. round_multiplier(int): the channel number after slimming should be multiple of this number. """ + from .pattern_analyzer import Linear2LinearSearcher + from .weight_slim import LinearCompressionIterator logger.warning(f"You are using model slim methods, some weight channels will be removed permanently.") pa_obj = Linear2LinearSearcher(model, dataloader) layers = pa_obj.search() @@ -53,6 +54,8 @@ def model_slim_mha(model, dataloader = None): Args: model: a sprase model. 
""" + from .weight_slim import MHACompression + from .pattern_analyzer import SelfMHASearcher logger.warning(f"You are using model slim methods, some attention heads will be removed permanently.") pa_obj = SelfMHASearcher(model, dataloader) layers, _ = pa_obj.search(split_qkv_ffn = False) @@ -75,6 +78,7 @@ def parse_auto_slim_config(model, dataloader = None, ffn2_sparsity = .0, mha_spa def generate_ffn2_pruning_config(model, dataloader, ffn2_sparsity, **kwargs): """Get consecutive linear layers pruning configs.""" + from .pattern_analyzer import Linear2LinearSearcher searcher = Linear2LinearSearcher(model, dataloader) layers = searcher.search() # extract the second linear layer @@ -106,6 +110,7 @@ def generate_mha_pruning_config(model, dataloader, mha_sparsity, **kwargs): return mha_pruning_config # method 2: apply experimental mha pruning + # from .pattern_analyzer import SelfMHASearcher # searcher = SelfMHASearcher(model, dataloader) # qkv_pattern, ffn_pattern = searcher.get_head_pattern() # qkv_layers, ffn_layers = searcher.search() diff --git a/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py b/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py index 91b673fab09..7df005fbbc6 100644 --- a/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py +++ b/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py @@ -16,8 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import torch, logger +from ..utils import logger import re +from ....utils.utility import LazyImport +torch = LazyImport('torch') +tf = LazyImport('tensorflow') JIT_SUPPORT_OPS = ['linear', 'dropout', 'gelu', 'silu', 'relu', 'mul', 'add'] @@ -792,3 +795,42 @@ def search(self, return_name=True): last_lc = all_lc_modules[-1] if last_lc == all_modules[-1]: return last_lc else: return None + +class ClassifierHeadSearcherTF(object): + """Static graph searcher for multi-head attention modules. + + Use the static graph to detect final classifier head in a module, there is no need for user to define layer name. + Automatically search multi-head attention modules which can be optimized. + + Args: + model (tf.keras.Model): The Keras model for searching. + + Attributes: + model: The Keras model for searching. + device: The model's current device type. + static_graph: The static graph of original model. + flatten_static_graph: A list of string with the model's static graph inference details. 
+ """ + + def __init__(self, model): + """Initialize.""" + assert isinstance(model, tf.keras.Model) + super(ClassifierHeadSearcherTF, self).__init__() + self.model = model + self.pruning_ops = ["Dense", "Conv2d"] + self.excluded_ops = ["Dropout"] # to be extended + + def search(self, return_name=True): + all_modules = [] + all_lc_modules = [] + for layer in self.model.layers: + if layer.__class__.__name__ not in self.excluded_ops: + all_modules.append(layer.name) + if layer.__class__.__name__ in self.pruning_ops: + all_lc_modules.append(layer.name) + else: + continue + last_lc = all_lc_modules[-1] + if last_lc == all_modules[-1]: + return last_lc + return None \ No newline at end of file diff --git a/neural_compressor/compression/pruner/patterns.py b/neural_compressor/compression/pruner/patterns.py index 3b68f36b695..96e17a4f99a 100644 --- a/neural_compressor/compression/pruner/patterns.py +++ b/neural_compressor/compression/pruner/patterns.py @@ -16,8 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import torch, logger +import numpy as np +from .utils import logger from collections import namedtuple +from ...utils.utility import LazyImport +torch = LazyImport('torch') +tf = LazyImport('tensorflow') PATTERNS = {} @@ -42,7 +46,7 @@ def register(pattern): return register -def get_pattern(config, modules): +def get_pattern(config, modules, framework='pytorch'): """Get registered pattern class. Get a Pattern object from PATTERNS. @@ -60,7 +64,7 @@ def get_pattern(config, modules): name = config.pattern name = name.split('_')[-1] if "x" in name: - return PATTERNS["NxM"](config, modules) + return PATTERNS["NxM"](config, modules, framework) if ":" in name: return PATTERNS["N:M"](config, modules) if "mha" in name: @@ -318,7 +322,7 @@ class BasePattern: target_sparsity: A float representing the sparsity ratio of the modules after pruning. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize the basic pruning unit of a pattern.""" self.pattern = config.pattern self.is_global = config.pruning_scope == "global" @@ -326,12 +330,14 @@ def __init__(self, config, modules): self.invalid_layers = [] self.modules = modules self.config = config + self.framework = framework self.max_sparsity_ratio_per_op = self.config['max_sparsity_ratio_per_op'] self.min_sparsity_ratio_per_op = self.config['min_sparsity_ratio_per_op'] self.target_sparsity_ratio = self.config['target_sparsity'] self.block = bool('block' in self.config['pruning_type'] or 'free' in self.config['pruning_type']) # Not using deterministic_algorithms for all examples - torch.use_deterministic_algorithms(False) + if self.framework == 'pytorch': + torch.use_deterministic_algorithms(False) def reduce_tensor(self, data, dim): """Reduce the data along the given dimension. 
@@ -345,11 +351,20 @@ def reduce_tensor(self, data, dim): """ name = self.config['criterion_reduce_type'] if name == "mean": - return torch.mean(data, dim=dim) + if self.framework == 'pytorch': + return torch.mean(data, dim=dim) + elif self.framework == 'keras': + return tf.math.reduce_mean(data, dim) elif name == "sum": - return torch.sum(data, dim=dim) + if self.framework == 'pytorch': + return torch.sum(data, dim=dim) + elif self.framework == 'keras': + return tf.math.reduce_sum(data, dim) elif name == "max": - return torch.max(data, dim=dim)[0] + if self.framework == 'pytorch': + return torch.max(data, dim=dim)[0] + elif self.framework == 'keras': + return tf.math.reduce_max(data, dim) else: assert False, "currently only support mean, sum and max reduce type" @@ -406,15 +421,27 @@ def get_single_mask_per_target_ratio(self, score, exact_sparsity_ratio): Returns: A Tensor with the identical size as score. a new mask. """ - flattern_score = torch.flatten(score) - k = int(exact_sparsity_ratio * flattern_score.numel()) - threshold, _ = torch.kthvalue(flattern_score, k) - if not k < 1: - zero = torch.tensor([0.]).to(score.device) - one = torch.tensor([1.]).to(score.device) - mask = torch.where(score <= threshold, zero, one) - else: - mask = torch.ones(score.shape, device=score.device) + if self.framework == 'pytorch': + flattern_score = torch.flatten(score) + k = int(exact_sparsity_ratio * flattern_score.numel()) + threshold, _ = torch.kthvalue(flattern_score, k) + if not k < 1: + zero = torch.tensor([0.]).to(score.device) + one = torch.tensor([1.]).to(score.device) + mask = torch.where(score <= threshold, zero, one) + else: + mask = torch.ones(score.shape, device=score.device) + elif self.framework == 'keras': + flattern_score = tf.reshape(score, [-1]).numpy() + k = int(exact_sparsity_ratio * flattern_score.size) + threshold = np.partition(flattern_score, kth=k)[k] + if not k < 1: + zero = tf.convert_to_tensor([0.]) + one = tf.convert_to_tensor([1.]) + mask = tf.where(score <= threshold, zero, one) + else: + mask = tf.ones_like(score.shape) + return mask def get_block_size_dict(self, data): @@ -446,6 +473,7 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): pre_mask = pre_masks[key] zero_cnt += torch.sum(pre_mask == 0.0).data.item() total_cnt += pre_mask.numel() ##FIXME + if return_dict: return {"sparsity_ratio": float(zero_cnt) / total_cnt, "zero_cnt": zero_cnt, "total_cnt": total_cnt} else: @@ -469,6 +497,7 @@ def get_sparsity_ratio_progressive(self, pre_masks, return_dict=False): # progressive masks are unstructured, therefore directly find zeros zero_cnt += float(torch.sum(pre_masks[key] == 0).data.item()) total_cnt += float(pre_masks[key].numel()) + return (zero_cnt / total_cnt) def get_pattern_lock_masks(self, modules): @@ -487,6 +516,7 @@ def get_pattern_lock_masks(self, modules): mask = torch.ones(shape) mask[weight == 0] = 0.0 pattern_lock_masks[key] = mask.to(weight.device) + return pattern_lock_masks def check_layer_validity(self): @@ -529,17 +559,34 @@ def get_sparsity_ratio_each_layer(self, masks): infos = {} zero_cnts = 0 total_cnts = 0 - for key in masks.keys(): - if key in self.invalid_layers: - continue - reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key) - zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item())) - total_cnt = int(reduced_mask.numel()) - sparsity_ratio = float(zero_cnt) / total_cnt - val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio) - infos[key] = val - zero_cnts += zero_cnt - total_cnts += 
total_cnt + + if self.framework == 'pytorch': + for key in masks.keys(): + if key in self.invalid_layers: + continue + reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key) + zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item())) + total_cnt = int(reduced_mask.numel()) + sparsity_ratio = float(zero_cnt) / total_cnt + val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio) + infos[key] = val + zero_cnts += zero_cnt + total_cnts += total_cnt + elif self.framework == 'keras': + for key in masks.keys(): + if key in self.invalid_layers: + continue + if not isinstance(masks[key], np.ndarray): + masks[key] = masks[key].numpy() + reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key) + zero_cnt = int(np.sum(reduced_mask == 0.0)) + total_cnt = int(reduced_mask.size) + sparsity_ratio = float(zero_cnt) / total_cnt + val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio) + infos[key] = val + zero_cnts += zero_cnt + total_cnts += total_cnt + sparsity_ratio = float(zero_cnts) / total_cnts return infos, SparsityInfo(zero_cnts, total_cnts, sparsity_ratio) @@ -629,9 +676,9 @@ class PatternNxM(BasePattern): because PyTorch's tensor matmul has a hidden transpose operation. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize the basic pruning unit of NXM pattern.""" - super(PatternNxM, self).__init__(config, modules) + super(PatternNxM, self).__init__(config, modules, framework) pattern = self.pattern.split('_')[-1] self.N = pattern.split('x')[0] self.M = pattern.split('x')[1] @@ -662,7 +709,7 @@ def get_block_size_dict(self): block_sizes_dict[key] = self.block_size if not (self.N == "channel" or self.M == "channel"): continue - if isinstance(datas[key], torch.nn.Module): + if self.framework == 'pytorch' and isinstance(datas[key], torch.nn.Module): shape = datas[key].weight.shape else: shape = datas[key].shape @@ -677,14 +724,24 @@ def check_layer_validity(self): """Check if a layer is valid for this block_size.""" block_sizes = self.block_size datas = self.modules - for key in datas.keys(): - data = datas[key].weight - data = self._reshape_orig_to_2dims(data) - shape = data.shape - block_size = block_sizes[key] - if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel - self.invalid_layers.append(key) - logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}") + if self.framework == 'pytorch': + for key in datas.keys(): + data = datas[key].weight + data = self._reshape_orig_to_2dims(data) + shape = data.shape + block_size = block_sizes[key] + if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel + self.invalid_layers.append(key) + logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}") + elif self.framework == 'keras': + for key in datas.keys(): + data = datas[key].get_weights()[0] + data = self._reshape_orig_to_2dims(data) + shape = data.shape + block_size = block_sizes[key] + if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel + self.invalid_layers.append(key) + logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}") def get_reduced_masks_from_data(self, data, key): """Obtain the unpruned weights and reshape according to the block_size. 
@@ -718,12 +775,22 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False): """ zero_cnt = 0 total_cnt = 0 - for key in pre_masks.keys(): - if key in self.invalid_layers: - continue - reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key) - zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item())) - total_cnt += int(reduced_mask.numel()) + if self.framework == 'pytorch': + for key in pre_masks.keys(): + if key in self.invalid_layers: + continue + reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key) + zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item())) + total_cnt += int(reduced_mask.numel()) + elif self.framework == 'keras': + for key in pre_masks.keys(): + if key in self.invalid_layers: + continue + if not isinstance(pre_masks[key], np.ndarray): + pre_masks[key] = pre_masks[key].numpy() + reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key) + zero_cnt += int(np.sum(reduced_mask == 0.0)) + total_cnt += int(reduced_mask.size) if total_cnt == 0: sparsity_ratio = 0.0 else: @@ -744,7 +811,10 @@ def _reshape_orig_to_2dims(self, data): """ ##TODO need to verify whether it's ok for transposed conv if len(data.shape) == 4: - data = data.permute(0, 2, 3, 1) ##cout,k,k,cin + if isinstance(data, np.ndarray): + data = np.transpose(data, (0, 2, 3, 1)) + else: + data = data.permute(0, 2, 3, 1) ##cout,k,k,cin data = data.reshape(data.shape[0], -1) return data @@ -761,7 +831,10 @@ def _reshape_2dims_to_orig(self, data, orig_shape): if len(orig_shape) == 4: data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) - data = data.permute(0, 3, 1, 2) + if isinstance(data, np.ndarray): + data = np.transpose(data, (0, 3, 1, 2)) + else: + data = data.permute(0, 3, 1, 2) return data def reshape_orig_to_pattern(self, data, key): @@ -823,11 +896,20 @@ def reduce_scores(self, scores): def get_mask_per_threshold(self, score, threshold, block_size): """Get the mask per threshold.""" - zero = torch.tensor([0.]).to(score.device) - one = torch.tensor([1.]).to(score.device) - mask = torch.where(score <= threshold, zero, one) - if not self.block: - mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) + if self.framework == 'pytorch': + zero = torch.tensor([0.]).to(score.device) + one = torch.tensor([1.]).to(score.device) + mask = torch.where(score <= threshold, zero, one) + if not self.block: + mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) + elif self.framework == 'keras': + zero = tf.convert_to_tensor([0.]) + one = tf.convert_to_tensor([1.]) + mask = tf.where(score <= threshold, zero, one) + if not self.block: + mask = tf.repeat(mask, repeats=block_size[0], axis=0) + mask = tf.repeat(mask, repeats=block_size[1], axis=-1) + mask = mask.numpy() return mask def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, @@ -837,6 +919,30 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, Gather all layer's scores together and calculate a common threshold. This threshold will be applied to all layers. + Args: + scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. + cur_target_sparsity_ratio: A float representing the model's sparsity after pruning. + pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at the last pruning step. 
+ max_sparsity_ratio_per_op: A float representing the maximum sparsity that one layer can reach. + keep_pre_masks: A bool representing if the masks should remain unchanged. + + Returns: + A dict with the identical size as pre_masks and its 0/1 values are updated. + 1 means unpruned and 0 means pruned. + """ + if self.framework == 'pytorch': + return self.get_masks_global_pytorch(scores, cur_target_sparsity_ratio, pre_masks, \ + keep_exact_sparsity_ratio) + elif self.framework == 'keras': + return self.get_masks_global_tf(scores, cur_target_sparsity_ratio, pre_masks, keep_exact_sparsity_ratio) + + def get_masks_global_pytorch(self, scores, cur_target_sparsity_ratio, pre_masks, + keep_exact_sparsity_ratio=True): + """Generate masks for layers. + + Gather all layer's scores together and calculate a common threshold. + This threshold will be applied to all layers. + Args: scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. cur_target_sparsity_ratio: A float representing the model's sparsity after pruning. @@ -907,6 +1013,84 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks, logger.info(f'{key} sparsity is {layer_ratio}') return masks + def get_masks_global_tf(self, scores, cur_target_sparsity_ratio, pre_masks, + keep_exact_sparsity_ratio=True): + """Generate masks for layers. + + Gather all layer's scores together and calculate a common threshold. + This threshold will be applied to all layers. + + Args: + scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights. + cur_target_sparsity_ratio: A float representing the model's sparsity after pruning. + pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at the last pruning step. + max_sparsity_ratio_per_op: A float representing the maximum sparsity that one layer can reach. + keep_pre_masks: A bool representing if the masks should remain unchanged. + + Returns: + A dict with the identical size as pre_masks and its 0/1 values are updated. + 1 means unpruned and 0 means pruned. 
+ """ + ##keep the masks if the layer exceed max sparsity ratio + + masks = pre_masks + k_blockwise = self.update_residual_cnt(masks, cur_target_sparsity_ratio) + if k_blockwise <= 0: + return masks + new_scores = scores if self.block else self.reduce_scores(scores) + not_exceed_layers = [] + residual_k = k_blockwise + if self.min_sparsity_ratio_per_op > 0: + sparsity_infos_perlayer, _ = self.get_sparsity_ratio_each_layer(masks) + + while True: + new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)] + if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0: + break + not_exceed_layers = new_not_exceed_layers + global_scores = np.concatenate([tf.reshape(new_scores[key], [-1]).numpy() for key in not_exceed_layers]) + threshold = np.partition(global_scores, kth=residual_k)[residual_k] + + for key in not_exceed_layers: + block_size = self.block_size[key] + score = new_scores[key] + mask = self.get_mask_per_threshold(score, threshold, block_size) + info = self.get_sparsity_ratio({key: mask}, return_dict=True) + zero_cnt = info["zero_cnt"] + total_cnt = info["total_cnt"] + current_sparsity_ratio = float(zero_cnt) / total_cnt + key_new_sparsity = SparsityInfo(zero_cnt, total_cnt, current_sparsity_ratio) + need_adjust, adjust_ratio = self.adjust_ratio(masks, key, key_new_sparsity, + self.max_sparsity_ratio_per_op, + self.min_sparsity_ratio_per_op, + self.target_sparsity_ratio) + if need_adjust: + # uptade status + self.keep_mask_layers[key] = True + masks[key] = self.get_single_mask_per_target_ratio(new_scores[key], adjust_ratio) + if not self.block: + masks[key] = tf.repeat(masks[key], repeats=block_size[0], axis=0) + masks[key] = tf.repeat(masks[key], repeats=block_size[1], axis=-1) + if keep_exact_sparsity_ratio: + zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"] + residual_k -= zero_cnt + else: + masks[key] = mask + if not keep_exact_sparsity_ratio: + break + + for key in masks.keys(): + if key in self.invalid_layers: + continue + if len(scores[key].shape) == 4: ## need to permute + mask = masks[key] + orig_shape = scores[key].shape + mask = self._reshape_2dims_to_orig(mask, orig_shape) + masks[key] = mask + layer_ratio = np.sum(masks[key] == 0.0) / masks[key].size + logger.info(f'{key} sparsity is {layer_ratio}') + return masks + def get_pattern_lock_masks(self, modules): """Obtain masks from original weight map by masking the zero-valued weights. @@ -927,6 +1111,7 @@ def get_pattern_lock_masks(self, modules): reduced_mask = self.get_reduced_masks_from_data(weight, key) mask = self.reshape_reduced_to_orig(reduced_mask, key, ori_shape) pattern_lock_masks[key] = mask + return pattern_lock_masks def register_block_masks(self, modules): @@ -951,6 +1136,7 @@ def register_block_masks(self, modules): block_mask = torch.nn.Parameter(self.get_reduced_masks_from_data(weight, key).to(dtype=weight.dtype)) module.register_parameter("block_mask", block_mask) masks[key] = modules[key].block_mask.data + return masks def remove_block_masks(self): @@ -1011,9 +1197,9 @@ class PatternNInM(BasePattern): M: The size of the weight sequence. 
""" - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize the basic pruning unit of N:M pattern.""" - super(PatternNInM, self).__init__(config, modules) + super(PatternNInM, self).__init__(config, modules, framework) pattern = self.pattern.split('_')[-1] self.N = int(pattern.split(':')[0]) self.M = int(pattern.split(':')[1]) ##m is bigger @@ -1345,7 +1531,8 @@ class PatternMHA(BasePattern): M: The size of the weight sequence. """ - def __init__(self, config, modules = None): + def __init__(self, config, modules = None, framework='pytorch'): + self.framework = framework self.is_global = config.pruning_scope == "global" # only implement three method: get_masks, get_masks_local, get_masks_global diff --git a/neural_compressor/compression/pruner/pruners.py b/neural_compressor/compression/pruner/pruners.py index 36e53ace1b0..5ac68944de9 100644 --- a/neural_compressor/compression/pruner/pruners.py +++ b/neural_compressor/compression/pruner/pruners.py @@ -16,18 +16,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import re -from .utils import torch, F +import copy +import numpy as np from functools import partial from .patterns import get_pattern from .schedulers import get_scheduler from .criteria import get_criterion, CRITERIA from .regs import get_reg from .utils import logger -# auto slim related: head pruning objects -from .model_slim.pattern_analyzer import SelfMHASearcher -from .model_slim.weight_slim import MHACompression + +from ...utils.utility import LazyImport +torch = LazyImport('torch') +tf = LazyImport('tensorflow') +F = LazyImport('torch.nn.functional') PRUNERS = {} @@ -62,7 +64,7 @@ def parse_valid_pruner_types(): valid_pruner_types.append("pattern_lock") return valid_pruner_types -def get_pruner(config, modules): +def get_pruner(config, modules, framework='pytorch'): """Get registered pruner class. Get a Pruner object from PRUNERS. @@ -102,7 +104,7 @@ def get_pruner(config, modules): if name not in PRUNERS.keys(): assert False, f"does not support {name}, currently only support {parse_valid_pruner_types()}" - return PRUNERS[name](config, modules) + return PRUNERS[name](config, modules, framework) class BasePruner: """Pruning Pruner. @@ -131,10 +133,11 @@ class BasePruner: max_sparsity_ratio_per_op: A float showing the maximum sparsity ratio for every module. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" self.modules = modules self.config = config + self.framework = framework self.masks = {} self.global_step = 0 self.handled_global_step = -1 @@ -150,10 +153,15 @@ def __init__(self, config, modules): self.total_prune_cnt = 1 self.completed_pruned_cnt = 1 - for key in self.modules.keys(): - module = self.modules[key] - self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device) ##TODO support bias or others - + if self.framework == 'pytorch': + for key in self.modules.keys(): + module = self.modules[key] + ##TODO support bias or others + self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device) + elif self.framework == 'keras': + for key in self.modules.keys(): + module = self.modules[key] + self.masks[key] = np.ones(module.get_weights()[0].shape) self.target_sparsity_ratio = self.config['target_sparsity'] self.current_sparsity_ratio = 0.0 self.init_sparsity_ratio = 0.0 @@ -172,10 +180,15 @@ def mask_weights(self): Weights are multipled with masks. 
This is the formal pruning process. """ - with torch.no_grad(): + if self.framework == 'pytorch': + with torch.no_grad(): + for key in self.modules.keys(): + module = self.modules[key] + module.weight.data = module.weight.data * self.masks[key] + elif self.framework == 'keras': for key in self.modules.keys(): module = self.modules[key] - module.weight.data = module.weight.data * self.masks[key] + module.set_weights([module.get_weights()[0] * self.masks[key]] + module.get_weights()[1:]) def mask_weights_general(self, input_masks): """Apply input masks to corresponding modules' weights. @@ -185,10 +198,15 @@ def mask_weights_general(self, input_masks): Args: input_masks: A dict {"module_name": Tensor} that stores the masks for modules' weights. """ - with torch.no_grad(): + if self.framework == 'pytorch': + with torch.no_grad(): + for key in self.modules.keys(): + module = self.modules[key] + module.weight.data = module.weight.data * input_masks[key] + elif self.framework == 'keras': for key in self.modules.keys(): module = self.modules[key] - module.weight.data = module.weight.data * input_masks[key] + module.set_weights([module.get_weights()[0] * input_masks[key]] + module.get_weights()[1:]) def on_step_begin(self, local_step): """Implement at the start of each step.""" @@ -254,6 +272,7 @@ def check_is_pruned_step(self, step): def rewrite_forward(self): """Rewrite forward to implement block mask operation""" + assert self.framework != 'keras', "This pruning method is not supported by Keras now." def forward(self, input): block_size = [self.weight.shape[0]//self.block_mask.shape[0], \ self.weight.shape[1]//self.block_mask.shape[1]] @@ -269,6 +288,7 @@ def forward(self, input): def recover_forward(self): """Restore the forward format at the end of pruning""" + assert self.framework != 'keras', "This pruning method is not supported by Keras now." with torch.no_grad(): for key in self.modules.keys(): if not hasattr(self.modules[key], 'block_mask'): @@ -296,18 +316,18 @@ class BasicPruner(BasePruner): reg: A Reg object that defines regulization terms. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" # self.modules = modules # self.config = config # self.masks = {} - super(BasicPruner, self).__init__(config, modules) + super(BasicPruner, self).__init__(config, modules, framework) def _init(self): """Auxiliary function for initializing.""" - self.pattern = get_pattern(self.config, self.modules) + self.pattern = get_pattern(self.config, self.modules, self.framework) self.scheduler = get_scheduler(self.config) - self.criterion = get_criterion(self.config, self.modules) + self.criterion = get_criterion(self.config, self.modules, self.framework) self.reg = get_reg(self.config, self.modules, self.pattern) # if switch off progressive but use per-channel pruning, give a warn if "channel" in self.pattern.pattern: @@ -388,9 +408,9 @@ class PatternLockPruner(BasePruner): Inherit from parent class Pruner. 
""" - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" - super(PatternLockPruner, self).__init__(config, modules) + super(PatternLockPruner, self).__init__(config, modules, framework) self.pattern = get_pattern(self.config, modules) assert self.config.end_step == self.config.start_step, "pattern_lock pruner only supports one shot mode" @@ -428,17 +448,17 @@ class BlockMaskPruner(BasePruner): scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds. reg: A Reg object that defines regulization terms. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" - super(BlockMaskPruner, self).__init__(config, modules) + super(BlockMaskPruner, self).__init__(config, modules, framework) def _init(self): """Initialize.""" - self.pattern = get_pattern(self.config, self.modules) + self.pattern = get_pattern(self.config, self.modules, self.framework) self.masks = self.pattern.register_block_masks(self.modules) self.rewrite_forward() self.scheduler = get_scheduler(self.config) - self.criterion = get_criterion(self.config, self.modules) + self.criterion = get_criterion(self.config, self.modules, self.framework) self.reg = get_reg(self.config, self.modules, self.pattern) if "channel" not in self.pattern.pattern: @@ -553,17 +573,17 @@ class RetrainFreePruner(BasePruner): scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds. reg: A Reg object that defines regulization terms. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" - super(RetrainFreePruner, self).__init__(config, modules) + super(RetrainFreePruner, self).__init__(config, modules, framework) def _init(self): """Initialize.""" - self.pattern = get_pattern(self.config, self.modules) + self.pattern = get_pattern(self.config, self.modules, self.framework) self.masks = self.pattern.register_block_masks(self.modules) self.rewrite_forward() self.scheduler = get_scheduler(self.config) - self.criterion = get_criterion(self.config, self.modules) + self.criterion = get_criterion(self.config, self.modules, self.framework) self.reg = get_reg(self.config, self.modules, self.pattern) logger.warning("Retrain-free pruner fixed the weights, please DO NOT turn on gradient update.") @@ -703,15 +723,15 @@ class ProgressivePruner(BasicPruner): Inherit from parent class Pruner. """ - def __init__(self, config, modules): + def __init__(self, config, modules, framework='pytorch'): """Initialize.""" - super(ProgressivePruner, self).__init__(config, modules) + super(ProgressivePruner, self).__init__(config, modules, framework) def _init(self): """Auxiliary function for initialization.""" - self.pattern = get_pattern(self.config, self.modules) + self.pattern = get_pattern(self.config, self.modules, self.framework) self.scheduler = get_scheduler(self.config) - self.criterion = get_criterion(self.config, self.modules) + self.criterion = get_criterion(self.config, self.modules, self.framework) self.reg = get_reg(self.config, self.modules, self.pattern) # progressive pruning set up, including check up paramters. self.use_progressive = self.config["progressive"] @@ -992,6 +1012,8 @@ def _init_mha_attrs(self): # similar to original mha slim process, but only hook mha modules and their attributes, # do not call slim main functions. 
""" + # auto slim related: head pruning objects + from .model_slim.weight_slim import MHACompression for mha_module in self.mha_modules: # initialize self.mha_compressions mha_comp = MHACompression(mha_module) diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py index 34595ab0260..8945998393f 100644 --- a/neural_compressor/compression/pruner/utils.py +++ b/neural_compressor/compression/pruner/utils.py @@ -18,6 +18,7 @@ import re import yaml +import numpy as np from ...config import WeightPruningConfig as WeightPruningConf try: @@ -29,6 +30,7 @@ from neural_compressor.conf.config import Pruner LazyImport('torch.nn') torch = LazyImport('torch') + tf = LazyImport('tensorflow') F = LazyImport('torch.nn.functional') except: import torch @@ -141,6 +143,62 @@ def get_sparsity_ratio(pruners, model): return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv +def get_sparsity_ratio_tf(pruners, model): + """Calculate sparsity ratio of a module/layer. + + Returns: + Three floats. + elementwise_over_matmul_gemm_conv refers to zero elements' ratio in pruning layers. + elementwise_over_all refers to zero elements' ratio in all layers in the model. + blockwise_over_matmul_gemm_conv refers to all-zero blocks' ratio in pruning layers. + """ + pattern_sparsity_cnt = 0 + element_sparsity_cnt = 0 + if hasattr(model, 'model'): + model = model.model + for pruner in pruners: + modules = pruner.modules + sparsity_ratio = pruner.pattern.get_sparsity_ratio(pruner.masks) + cnt = 0 + for key in modules.keys(): + cnt += modules[key].get_weights()[0].size + pattern_sparsity_cnt += int(cnt * sparsity_ratio) + for key in pruner.masks.keys(): + block_num = 1 + if pruner.pattern.block: + block_size = pruner.pattern.block_size[key] + block_num = block_size[0] * block_size[1] + element_sparsity_cnt += np.sum(pruner.masks[key] == 0) * block_num + + linear_conv_cnt = 0 + param_cnt = 0 + for layer in model.layers: + if layer.__class__.__name__ in ["Dense"] or re.search(r'Conv.d', layer.__class__.__name__) != None: + linear_conv_cnt += layer.get_weights()[0].size + + for layer in model.layers: + if bool(layer.weights): + weights = layer.get_weights()[0] + param_cnt += weights.size + if linear_conv_cnt == 0: + blockwise_over_matmul_gemm_conv = 0 + elementwise_over_matmul_gemm_conv = 0 + else: + blockwise_over_matmul_gemm_conv = float(pattern_sparsity_cnt) / linear_conv_cnt + elementwise_over_matmul_gemm_conv = float(element_sparsity_cnt) / linear_conv_cnt + if param_cnt == 0: + elementwise_over_all = 0 + else: + elementwise_over_all = float( + element_sparsity_cnt) / param_cnt + + logger.info( + f"elementwise_over_matmul_gemm_conv:{elementwise_over_matmul_gemm_conv}," + f" elementwise_over_all:{elementwise_over_all}," + f"blockwise_over_matmul_gemm_conv:{blockwise_over_matmul_gemm_conv}") + + return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv + def check_config(prune_config): """Check if the configuration dict is valid for running Pruning object. @@ -424,6 +482,19 @@ def parse_last_linear(model): layer = searcher.search(return_name=True) return layer +def parse_last_linear_tf(model): + """Locate the last linear layers of the model. + While pruning, the final linear often acts like classifier head, which might cause + accuracy drop. + + Args: + model(tf.keras.Model): The model to be pruned. 
+ """ + from .model_slim.pattern_analyzer import ClassifierHeadSearcherTF + searcher = ClassifierHeadSearcherTF(model) + layer = searcher.search(return_name=True) + return layer + def parse_to_prune(config, model): """Keep target pruned layers. @@ -462,6 +533,40 @@ def parse_to_prune(config, model): new_modules[name] = modules[name] return new_modules +def parse_to_prune_tf(config, model): + """Keep target pruned layers. + + Args: + config(string): A string representing the path to the configuration file. + model(tf.keras.Model): The model to be pruned. + """ + modules = {} + # additional function: exclude last layer (often a classifier head and not suitable to be pruned) + classifier_head_name = parse_last_linear_tf(model) + if classifier_head_name != None: + config["excluded_op_names"].append(classifier_head_name) + # locate target layers + if config["op_names"] == None or config["op_names"] == []: + config["op_names"] = [".*"] + + for layer in model.layers: + for layer_type in config["pruning_op_types"]: + if layer_type in layer.__class__.__name__ and bool(layer.weights): + modules[layer.name] = layer + + ##remove not to prune layers + """Drop non-pruned layers.""" + exclude_names = config["excluded_op_names"] + patterns = [re.compile(s) for s in exclude_names] + if len(patterns) <= 0: + return modules + new_modules = {} + for name in modules.keys(): + if any([p.search(name) for p in patterns]): + continue + new_modules[name] = modules[name] + return new_modules + def generate_pruner_config(info): """Generate pruner config object from prune information. diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 935aff1a526..7e05f6e8ca0 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -1372,12 +1372,13 @@ class WeightPruningConfig: def __init__(self, pruning_configs=[{}], ##empty dict will use global values target_sparsity=0.9, pruning_type="snip_momentum", pattern="4x1", op_names=[], - excluded_op_names=[], + excluded_op_names=[], backend=None, start_step=0, end_step=0, pruning_scope="global", pruning_frequency=1, min_sparsity_ratio_per_op=0.0, max_sparsity_ratio_per_op=0.98, sparsity_decay_type="exp", pruning_op_types=['Conv', 'Linear'], **kwargs): """Init a WeightPruningConfig object.""" + self.backend = backend self.pruning_configs = pruning_configs self._weight_compression = DotDict({ 'target_sparsity': target_sparsity, diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 0d0aa2b374a..e73efbb4feb 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -167,10 +167,10 @@ def __new__(cls, root, **kwargs): if isinstance(root, BaseModel): if conf != "NA" and conf.framework is None: conf.framework = list(MODELS.keys())[list(MODELS.values()).index(type(root))] - if conf.backend == "ipex": + if hasattr(conf, 'backend') and conf.backend == "ipex": assert conf.framework == "pytorch_ipex",\ "Please wrap the model with correct Model class!" - if conf.backend == "itex": + if hasattr(conf, 'backend') and conf.backend == "itex": if get_model_type(root.model) == 'keras': assert conf.framework == "keras",\ "Please wrap the model with KerasModel class!" 
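Reviewer note: `parse_to_prune_tf` and `parse_last_linear_tf` above decide which Keras layers become pruning targets. The following is a rough standalone sketch of that selection logic — illustrative only, not part of the patch; the toy model and layer names are invented, and the classifier head is excluded by name here instead of via `ClassifierHeadSearcherTF`.

```python
# Sketch: pick layers whose class name matches pruning_op_types and that actually
# hold weights, then drop anything matching the excluded_op_names regexes.
import re
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, name="conv_a"),
    tf.keras.layers.Flatten(name="flatten"),
    tf.keras.layers.Dense(16, name="dense_a"),
    tf.keras.layers.Dense(10, name="classifier"),   # typically excluded as the head
])

pruning_op_types = ["Conv", "Dense"]
excluded_op_names = ["classifier"]

# collect candidate layers by type, skipping weight-less layers such as Flatten
modules = {}
for layer in model.layers:
    if any(t in layer.__class__.__name__ for t in pruning_op_types) and bool(layer.weights):
        modules[layer.name] = layer

# drop layers whose names match any exclusion pattern
patterns = [re.compile(s) for s in excluded_op_names]
to_prune = {name: layer for name, layer in modules.items()
            if not any(p.search(name) for p in patterns)}

print(sorted(to_prune.keys()))   # expected: ['conv_a', 'dense_a']
```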
@@ -213,10 +213,10 @@ def __new__(cls, root, **kwargs): return MODELS[framework](root, **kwargs) else: conf.framework = framework - if conf.backend == "default": + if hasattr(conf, 'backend') and conf.backend == "default": if framework == "pytorch": conf.framework = "pytorch_fx" - elif conf.backend == "ipex": + elif hasattr(conf, 'backend') and conf.backend == "ipex": conf.framework = "pytorch_ipex" if 'tensorflow' in conf.framework: @@ -227,7 +227,7 @@ def __new__(cls, root, **kwargs): model_type = kwargs['modelType'] else: model_type = get_model_type(root) - if conf.backend == "itex": + if hasattr(conf, 'backend') and conf.backend == "itex": if model_type == 'keras': conf.framework = "keras" model = MODELS[conf.framework](root, **kwargs) @@ -243,7 +243,7 @@ def __new__(cls, root, **kwargs): root, **kwargs) else: model = MODELS[conf.framework](root, **kwargs) - if 'tensorflow' in conf.framework: + if 'tensorflow' in conf.framework and hasattr(conf, 'model_name'): model.name = conf.model_name model.output_tensor_names = conf.outputs model.input_tensor_names = conf.inputs diff --git a/neural_compressor/training.py b/neural_compressor/training.py index 5d6d90e5e63..28905f1bd78 100644 --- a/neural_compressor/training.py +++ b/neural_compressor/training.py @@ -80,7 +80,7 @@ def __init__(self, model: Callable, confs: Union[Callable, List], **kwargs): if isinstance(confs, List) and len(confs) > 1: for conf in confs: - if isinstance(conf, QuantizationAwareTrainingConfig): + if isinstance(conf, QuantizationAwareTrainingConfig) or isinstance(conf, WeightPruningConfig): self.model = Model(model, conf=conf) if self.model is None: self.model = Model(model) @@ -150,7 +150,7 @@ def __init__(self, model: Callable, confs: Union[Callable, List], **kwargs): callbacks_list.append(QuantizationAwareTrainingCallbacks(confs, adaptor=self.adaptor)) self.conf = _Config(quantization=confs, benchmark=None, pruning=None, distillation=None, nas=None) elif isinstance(confs, WeightPruningConfig): - self.model = Model(model) + self.model = Model(model, conf=confs) callbacks_list.append(PruningCallbacks(confs, model=self.model)) self.conf = _Config(quantization=None, benchmark=None, pruning=confs, distillation=None, nas=None) elif isinstance(confs, DistillationConfig): diff --git a/test/pruning_2.x/test_pruning.py b/test/pruning_2.x/test_pruning.py index c401a4cd269..2398e57c99d 100644 --- a/test/pruning_2.x/test_pruning.py +++ b/test/pruning_2.x/test_pruning.py @@ -9,7 +9,11 @@ from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader from neural_compressor import WeightPruningConfig from neural_compressor.training import prepare_compression - +from neural_compressor.data import DataLoader +from neural_compressor.adaptor import FRAMEWORKS +from neural_compressor.conf.dotdict import DotDict +from neural_compressor.utils import create_obj_from_config +from neural_compressor.conf.config import default_workspace class TestPruning(unittest.TestCase): model = torchvision.models.resnet18() @@ -72,6 +76,58 @@ def test_pruning_basic(self): compression_manager.callbacks.on_before_eval() compression_manager.callbacks.on_after_eval() + def test_pruning_keras(self): + import tensorflow as tf + model = tf.keras.applications.ResNet50V2(weights='imagenet') + def train(model, adaptor, compression_manager, train_dataloader): + train_cfg = { + 'epoch': 1, + 'start_epoch': 0, + 'execution_mode': 'eager', + 'criterion': {'SparseCategoricalCrossentropy': {'reduction': 'sum_over_batch_size'}}, + 'optimizer': {'SGD': 
{'learning_rate': 1e-03, 'momentum': 0.9, 'nesterov': True}}, + } + train_cfg = DotDict(train_cfg) + train_func = create_obj_from_config.create_train_func( + 'tensorflow', \ + train_dataloader, \ + adaptor, \ + train_cfg, \ + hooks=compression_manager.callbacks.callbacks_list[0].hooks, \ + callbacks=compression_manager.callbacks.callbacks_list[0]) + train_func(model) + + tf_datasets = Datasets('tensorflow') + dummy_dataset = tf_datasets['dummy'](shape=(100, 224, 224, 3), low=0., high=1., label=True) + train_dataloader = DataLoader(dataset=dummy_dataset, batch_size=32, + framework='tensorflow', distributed=False) + + framework_specific_info = { + 'device': 'cpu', 'random_seed': 9527, + 'workspace_path': default_workspace, + 'q_dataloader': None, 'format': 'default', + 'backend': 'default', 'inputs': [], 'outputs': [] + } + adaptor = FRAMEWORKS['keras'](framework_specific_info) + + configs = WeightPruningConfig( + backend='itex', + pruning_type='magnitude', + pattern='3x1', + target_sparsity=0.5, + start_step=1, + end_step=10, + pruning_op_types=['Conv', 'Dense'] + ) + compression_manager = prepare_compression(model, confs=configs) + compression_manager.callbacks.on_train_begin() + model = compression_manager.model + + train(model, adaptor, compression_manager, train_dataloader) + + compression_manager.callbacks.on_train_end() + stats, sparsity = model.report_sparsity() + if __name__ == "__main__": unittest.main()
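Reviewer note: the test above prunes with a `3x1` pattern, and the Keras branch of `get_mask_per_threshold` expands each per-block decision back to the weight shape with `tf.repeat`. Below is a rough standalone illustration of that block-mask idea — not the patch's exact reshape logic; the 6x4 weight and the median threshold are invented for the example.

```python
# Sketch: score 3x1 blocks of a weight, threshold them, then repeat the block mask
# along both axes so it matches the full weight shape.
import numpy as np
import tensorflow as tf

block_size = (3, 1)                  # the "3x1" pattern from the test config
weight = tf.random.uniform((6, 4))   # toy weight, divisible by the block size

# one score per block: group 3 consecutive rows per column, reduce inside blocks
blocks = tf.reshape(weight, (6 // block_size[0], block_size[0],
                             4 // block_size[1], block_size[1]))
block_scores = tf.reduce_mean(tf.abs(blocks), axis=[1, 3])        # shape (2, 4)

threshold = np.median(block_scores.numpy())                        # prune ~half the blocks
block_mask = tf.where(block_scores <= threshold,
                      tf.zeros_like(block_scores), tf.ones_like(block_scores))

# expand the per-block mask to the full weight shape, as the Keras branch does
mask = tf.repeat(block_mask, repeats=block_size[0], axis=0)
mask = tf.repeat(mask, repeats=block_size[1], axis=-1)
assert mask.shape == weight.shape

print(float(np.mean(mask.numpy() == 0.0)))                          # roughly 0.5 sparsity
```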