diff --git a/examples/README.md b/examples/README.md
index a144627c5db..960279e7864 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -303,6 +303,32 @@ Intel® Neural Compressor validated examples with multiple compression technique
+## Pruning
+
+| Model | Domain | Approach | Examples |
+|-----------|-------------------|------------------------|----------|
+| ResNet V2 | Image Recognition | Structured (4x1, 2in4) | [keras](./tensorflow/image_recognition/resnet_v2/pruning/magnitude) |
+| ViT | Image Recognition | Structured (4x1, 2in4) | [keras](./tensorflow/image_recognition/ViT/pruning/magnitude) |
+
## Model Export
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md
new file mode 100644
index 00000000000..3e3f419ad33
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/README.md
@@ -0,0 +1,38 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the Intel® Neural Compressor magnitude pruning feature on a ViT model.
+
+
+# Prerequisite
+
+## 1. Environment
+
+### Install Intel® Neural Compressor
+```shell
+pip install neural-compressor
+```
+### Install requirements
+```shell
+pip install -r requirements.txt
+```
+
+## 2. Prepare Model
+Run the script to save a baseline model to the directory './ViT_Model'.
+```shell
+python prepare_model.py
+```
+
+# Run
+Run the command to prune the baseline model and save it into a given path.
+The CIFAR100 dataset will be automatically loaded.
+
+```shell
+python main.py --output_model=/path/to/output_model/
+```
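+
+For reference, main.py builds its pruning setup roughly as follows (a simplified excerpt of this example's code):
+
+```python
+from neural_compressor.training import WeightPruningConfig, prepare_compression
+
+configs = WeightPruningConfig(
+    backend='itex',
+    pruning_type='magnitude',
+    target_sparsity=0.7,
+    pruning_op_types=['Conv', 'Dense']
+)
+compression_manager = prepare_compression(model='./ViT_Model', confs=configs)
+compression_manager.callbacks.on_train_begin()
+model = compression_manager.model
+# ... fine-tune the model with the pruning callbacks, then:
+compression_manager.callbacks.on_train_end()
+compression_manager.save('/path/to/output_model/')
+```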
+
+If you want to accelerate pruning with multi-node distributed training and evaluation, you only need to add two arguments and use Horovod to run main.py. Run the command below to get the pruned model with multi-node distributed training and evaluation.
+
+```shell
+horovodrun -np <num_of_processes> -H <hosts> python main.py --output_model=/path/to/output_model/ --train_distributed --evaluation_distributed
+```
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py
new file mode 100644
index 00000000000..d8607593969
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/main.py
@@ -0,0 +1,147 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import tensorflow as tf
+import tensorflow_addons as tfa
+from neural_compressor.utils import logger
+from neural_compressor.data import DataLoader
+from neural_compressor.adaptor import FRAMEWORKS
+from neural_compressor.conf.dotdict import DotDict
+from neural_compressor.training import WeightPruningConfig
+from neural_compressor.training import prepare_compression
+from neural_compressor.utils import create_obj_from_config
+from neural_compressor.conf.config import default_workspace
+
+flags = tf.compat.v1.flags
+FLAGS = flags.FLAGS
+
+## Required parameters
+flags.DEFINE_string(
+ 'output_model', None, 'The output pruned model.')
+
+flags.DEFINE_integer(
+ 'start_step', 0, 'The start step of pruning process.')
+
+flags.DEFINE_integer(
+ 'end_step', 9, 'The end step of pruning process.')
+
+flags.DEFINE_bool(
+ 'train_distributed', False, 'Whether to perform distributed training.')
+
+flags.DEFINE_bool(
+ 'evaluation_distributed', False, 'Whether to perform distributed evaluation.')
+
+# Prepare dataset
+def prepare_dataset():
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
+ y_train = tf.keras.utils.to_categorical(y_train, 100)
+ y_test = tf.keras.utils.to_categorical(y_test, 100)
+ logger.info(f"Training set: x_shape-{x_train.shape}, y_shape-{y_train.shape}")
+ logger.info(f"Test set: x_shape-{x_test.shape}, y_shape-{y_test.shape}")
+ return TrainDataset(x_train, y_train), EvalDataset(x_test, y_test)
+
+# Build TrainDataset and EvalDataset
+class TrainDataset(object):
+ def __init__(self, x_train, y_train):
+ self.x_train = x_train
+ self.y_train = y_train
+
+ def __len__(self):
+ return len(self.x_train)
+
+ def __getitem__(self, idx):
+ return self.x_train[idx], self.y_train[idx]
+
+class EvalDataset(object):
+ def __init__(self, x_test, y_test):
+ self.x_test = x_test
+ self.y_test = y_test
+
+ def __len__(self):
+ return len(self.x_test)
+
+ def __getitem__(self, idx):
+ return self.x_test[idx], self.y_test[idx]
+
+def train(model, adaptor, compression_manager, train_dataloader):
+ train_cfg = {
+ 'epoch': 15,
+ 'start_epoch': 0,
+ 'execution_mode': 'eager',
+ 'criterion': {'CrossEntropyLoss': {'reduction': 'sum_over_batch_size', 'from_logits': True}},
+ 'optimizer': {'AdamW': {'learning_rate': 1e-03, 'weight_decay': 1e-04}},
+ }
+ train_cfg = DotDict(train_cfg)
+ train_func = create_obj_from_config.create_train_func('tensorflow', \
+ train_dataloader, \
+ adaptor, \
+ train_cfg, \
+ hooks=compression_manager.callbacks.callbacks_list[0].hooks, \
+ callbacks=compression_manager.callbacks.callbacks_list[0])
+ train_func(model)
+
+def evaluate(model, adaptor, eval_dataloader):
+ eval_cfg = {'accuracy': {'metric': {'topk': 1},
+ 'iteration': -1,
+ 'multi_metrics': None}
+ }
+ eval_cfg = DotDict(eval_cfg)
+ eval_func = create_obj_from_config.create_eval_func('tensorflow', \
+ eval_dataloader, \
+ adaptor, \
+ eval_cfg.accuracy.metric, \
+ eval_cfg.accuracy.postprocess, \
+ fp32_baseline = False)
+ return eval_func(model)
+
+if __name__ == '__main__':
+ training_set, test_set = prepare_dataset()
+ train_dataloader = DataLoader(dataset=training_set, batch_size=128,
+ framework='tensorflow', distributed=FLAGS.train_distributed)
+ eval_dataloader = DataLoader(dataset=test_set, batch_size=256,
+ framework='tensorflow', distributed=FLAGS.evaluation_distributed)
+
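+    # Build the Keras adaptor that drives the training and evaluation loops below.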
+ framework_specific_info = {
+ 'device': 'cpu', 'random_seed': 9527,
+ 'workspace_path': default_workspace,
+ 'q_dataloader': None, 'format': 'default',
+ 'backend': 'default', 'inputs': [], 'outputs': []
+ }
+ adaptor = FRAMEWORKS['keras'](framework_specific_info)
+
+ configs = WeightPruningConfig(
+ backend='itex',
+ pruning_type='magnitude',
+ target_sparsity=0.7,
+ start_step=FLAGS.start_step,
+ end_step=FLAGS.end_step,
+ pruning_op_types=['Conv', 'Dense']
+ )
+ compression_manager = prepare_compression(model='./ViT_Model', confs=configs)
+ compression_manager.callbacks.on_train_begin()
+ model = compression_manager.model
+
+ train(model, adaptor, compression_manager, train_dataloader)
+ print("Pruned model score is ",evaluate(model, adaptor, eval_dataloader))
+
+
+ compression_manager.callbacks.on_train_end()
+ compression_manager.save(FLAGS.output_model)
+ stats, sparsity = model.report_sparsity()
+ logger.info(stats)
+ logger.info(sparsity)
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py
new file mode 100644
index 00000000000..c89a2e00bf8
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/prepare_model.py
@@ -0,0 +1,180 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import tensorflow_addons as tfa
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 256
+num_epochs = 100
+image_size = 72 # We'll resize input images to this size
+patch_size = 6  # Size of the patches to be extracted from the input images
+num_patches = (image_size // patch_size) ** 2
+projection_dim = 64
+num_heads = 4
+transformer_units = [
+ projection_dim * 2,
+ projection_dim,
+] # Size of the transformer layers
+transformer_layers = 8
+mlp_head_units = [2048, 1024] # Size of the dense layers of the final classifier
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.Resizing(image_size, image_size),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(factor=0.02),
+ layers.RandomZoom(
+ height_factor=0.2, width_factor=0.2
+ ),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=tf.nn.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+
+ def call(self, images):
+ batch_size = tf.shape(images)[0]
+ patches = tf.image.extract_patches(
+ images=images,
+ sizes=[1, self.patch_size, self.patch_size, 1],
+ strides=[1, self.patch_size, self.patch_size, 1],
+ rates=[1, 1, 1, 1],
+ padding="VALID",
+ )
+ patch_dims = patches.shape[-1]
+ patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+ return patches
+
+class PatchEncoder(layers.Layer):
+ def __init__(self, num_patches, projection_dim):
+ super().__init__()
+ self.num_patches = num_patches
+ self.projection = layers.Dense(units=projection_dim)
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+
+ def call(self, patch):
+ positions = tf.range(start=0, limit=self.num_patches, delta=1)
+ encoded = self.projection(patch) + self.position_embedding(positions)
+ return encoded
+
+
+def create_vit_classifier():
+ inputs = layers.Input(shape=input_shape)
+ # Augment data.
+ augmented = data_augmentation(inputs)
+ # Create patches.
+ patches = Patches(patch_size)(augmented)
+ # Encode patches.
+ encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+ # Create multiple layers of the Transformer block.
+ for _ in range(transformer_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # MLP.
+ x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
+ # Skip connection 2.
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Create a [batch_size, projection_dim] tensor.
+ representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ representation = layers.Flatten()(representation)
+ representation = layers.Dropout(0.5)(representation)
+ # Add MLP.
+ features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
+ # Classify outputs.
+ logits = layers.Dense(num_classes)(features)
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=logits)
+ return model
+
+def run_experiment(model):
+ optimizer = tfa.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ )
+
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ checkpoint_filepath = "/tmp/checkpoint"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[checkpoint_callback],
+ )
+
+ model.load_weights(checkpoint_filepath)
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+    model.save("./ViT_Model")  # Save the trained baseline model for later pruning
+ return history
+
+vit_classifier = create_vit_classifier()
+history = run_experiment(vit_classifier)
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt
new file mode 100644
index 00000000000..e904ddad36e
--- /dev/null
+++ b/examples/tensorflow/image_recognition/ViT/pruning/magnitude/requirements.txt
@@ -0,0 +1,5 @@
+tensorflow
+keras
+tensorflow-estimator
+tensorflow-addons
+horovod
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md
new file mode 100644
index 00000000000..a99c087ef23
--- /dev/null
+++ b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/README.md
@@ -0,0 +1,35 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the Intel® Neural Compressor magnitude pruning feature on a ResNet V2 model.
+
+
+# Prerequisite
+
+## 1. Environment
+
+### Install Intel® Neural Compressor
+```shell
+pip install neural-compressor
+```
+### Install TensorFlow
+```shell
+pip install tensorflow
+```
+
+# Run
+Run the command below to train a baseline model, which will be saved to './baseline_model'. The model is then pruned and saved to the given output path.
+The CIFAR10 dataset will be automatically loaded.
+```shell
+python main.py --output_model=/path/to/output_model/ --prune
+```
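+
+For reference, the pruning setup in main.py looks roughly like this (a simplified excerpt of this example's code):
+
+```python
+from neural_compressor.training import WeightPruningConfig, prepare_compression
+
+configs = WeightPruningConfig(
+    backend='itex',
+    pruning_type='magnitude',
+    pattern='3x1',
+    target_sparsity=0.25,
+    pruning_op_types=['Conv', 'Dense']
+)
+compression_manager = prepare_compression(model='./baseline_model', confs=configs)
+```
+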
+If you want to accelerate pruning with multi-node distributed training and evaluation, you only need to add two arguments and use Horovod to run main.py.
+Run the command below to get the pruned model with multi-node distributed training and evaluation.
+```shell
+horovodrun -np <num_of_processes> -H <hosts> python main.py --output_model=/path/to/output_model/ --train_distributed --evaluation_distributed --prune
+```
+
+Run the command to get pruned model performance.
+```shell
+python main.py --input_model=/path/to/input_model/ --benchmark
+```
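+
+The --benchmark path in main.py calls the Neural Compressor benchmark API roughly as follows (a simplified excerpt; eval_dataloader is the CIFAR10 evaluation dataloader built in main.py, and 100 is the default iteration count):
+
+```python
+from neural_compressor.benchmark import fit
+from neural_compressor.config import BenchmarkConfig
+
+conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=1, iteration=100)
+fit('/path/to/input_model/', conf, b_dataloader=eval_dataloader)
+```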
diff --git a/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py
new file mode 100644
index 00000000000..450ba18aca7
--- /dev/null
+++ b/examples/tensorflow/image_recognition/resnet_v2/pruning/magnitude/main.py
@@ -0,0 +1,453 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import print_function
+import yaml
+import numpy as np
+import tensorflow
+
+from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
+from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
+from tensorflow.keras.callbacks import LearningRateScheduler
+from tensorflow.keras.callbacks import ReduceLROnPlateau
+from tensorflow.keras.regularizers import l2
+from tensorflow.keras.models import Model
+from tensorflow.keras.datasets import cifar10
+
+
+from neural_compressor.utils import logger
+from neural_compressor.data import DataLoader
+from neural_compressor.adaptor import FRAMEWORKS
+from neural_compressor.conf.dotdict import DotDict
+from neural_compressor.training import WeightPruningConfig
+from neural_compressor.training import prepare_compression
+from neural_compressor.utils import create_obj_from_config
+from neural_compressor.conf.config import default_workspace
+
+flags = tensorflow.compat.v1.flags
+FLAGS = flags.FLAGS
+
+## Required parameters
+flags.DEFINE_bool(
+    'prune', False, 'Whether to prune the baseline model.')
+
+flags.DEFINE_bool(
+    'benchmark', False, 'Whether to benchmark the pruned model.')
+
+flags.DEFINE_string(
+ 'input_model', None, 'Run inference with specified model.')
+
+flags.DEFINE_string(
+ 'output_model', None, 'The output pruned model.')
+
+flags.DEFINE_integer(
+ 'start_step', 0, 'The start step of pruning process.')
+
+flags.DEFINE_integer(
+ 'end_step', 7, 'The end step of pruning process.')
+
+flags.DEFINE_integer(
+ 'iters', 100, 'The iteration of evaluating the performance.')
+
+flags.DEFINE_bool(
+ 'train_distributed', False, 'Whether to perform distributed training.')
+
+flags.DEFINE_bool(
+ 'evaluation_distributed', False, 'Whether to perform distributed evaluation.')
+
+
+def lr_schedule(epoch):
+ """Learning Rate Schedule
+
+ Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
+ Called automatically every epoch as part of callbacks during training.
+
+ # Arguments
+ epoch (int): The number of epochs
+
+ # Returns
+ lr (float32): learning rate
+ """
+ lr = 1e-3
+ if epoch > 180:
+ lr *= 0.5e-3
+ elif epoch > 160:
+ lr *= 1e-3
+ elif epoch > 120:
+ lr *= 1e-2
+ elif epoch > 80:
+ lr *= 1e-1
+ print('Learning rate: ', lr)
+ return lr
+
+
+def resnet_layer(inputs,
+ num_filters=16,
+ kernel_size=3,
+ strides=1,
+ activation='relu',
+ batch_normalization=True,
+ conv_first=True):
+ """2D Convolution-Batch Normalization-Activation stack builder
+
+ # Arguments
+ inputs (tensor): input tensor from input image or previous layer
+ num_filters (int): Conv2D number of filters
+ kernel_size (int): Conv2D square kernel dimensions
+ strides (int): Conv2D square stride dimensions
+ activation (string): activation name
+ batch_normalization (bool): whether to include batch normalization
+ conv_first (bool): conv-bn-activation (True) or
+ bn-activation-conv (False)
+
+ # Returns
+ x (tensor): tensor as input to the next layer
+ """
+ conv = Conv2D(num_filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding='same',
+ use_bias=True,
+ kernel_initializer='he_normal',
+ kernel_regularizer=l2(1e-4))
+
+ x = inputs
+ if conv_first:
+ x = conv(x)
+ if batch_normalization:
+ x = BatchNormalization()(x)
+ if activation is not None:
+ x = Activation(activation)(x)
+ else:
+ if batch_normalization:
+ x = BatchNormalization()(x)
+ if activation is not None:
+ x = Activation(activation)(x)
+ x = conv(x)
+ return x
+
+def resnet_v2(input_shape, depth, num_classes=10):
+ """ResNet Version 2 Model builder [b]
+
+ Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as
+ bottleneck layer
+ First shortcut connection per layer is 1 x 1 Conv2D.
+ Second and onwards shortcut connection is identity.
+ At the beginning of each stage, the feature map size is halved (downsampled)
+ by a convolutional layer with strides=2, while the number of filter maps is
+    doubled. Within each stage, the layers have the same number of filters and the
+    same filter map sizes.
+    Feature map sizes:
+ conv1 : 32x32, 16
+ stage 0: 32x32, 64
+ stage 1: 16x16, 128
+ stage 2: 8x8, 256
+
+ # Arguments
+ input_shape (tensor): shape of input image tensor
+ depth (int): number of core convolutional layers
+ num_classes (int): number of classes (CIFAR10 has 10)
+
+ # Returns
+ model (Model): Keras model instance
+ """
+ if (depth - 2) % 9 != 0:
+ raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')
+ # Start model definition.
+ num_filters_in = 16
+ num_res_blocks = int((depth - 2) / 9)
+
+ inputs = Input(shape=input_shape)
+ # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
+ x = resnet_layer(inputs=inputs,
+ num_filters=num_filters_in,
+ conv_first=True)
+
+ # Instantiate the stack of residual units
+ for stage in range(3):
+ for res_block in range(num_res_blocks):
+ activation = 'relu'
+ batch_normalization = True
+ strides = 1
+ if stage == 0:
+ num_filters_out = num_filters_in * 4
+ if res_block == 0: # first layer and first stage
+ activation = None
+ batch_normalization = False
+ else:
+ num_filters_out = num_filters_in * 2
+ if res_block == 0: # first layer but not first stage
+ strides = 2 # downsample
+
+ # bottleneck residual unit
+ y = resnet_layer(inputs=x,
+ num_filters=num_filters_in,
+ kernel_size=1,
+ strides=strides,
+ activation=activation,
+ batch_normalization=batch_normalization,
+ conv_first=False)
+ y = resnet_layer(inputs=y,
+ num_filters=num_filters_in,
+ conv_first=False)
+
+ y = resnet_layer(inputs=y,
+ num_filters=num_filters_out,
+ kernel_size=1,
+ conv_first=False)
+ if res_block == 0:
+ # linear projection residual shortcut connection to match
+ # changed dims
+ x = resnet_layer(inputs=x,
+ num_filters=num_filters_out,
+ kernel_size=1,
+ strides=strides,
+ activation=None,
+ batch_normalization=False)
+ x = tensorflow.keras.layers.add([x, y])
+
+ num_filters_in = num_filters_out
+
+ # Add classifier on top.
+ # v2 has BN-ReLU before Pooling
+ x = BatchNormalization()(x)
+ x = Activation('relu')(x)
+ x = AveragePooling2D(pool_size=8)(x)
+ y = Flatten()(x)
+ outputs = Dense(num_classes,
+ activation='softmax',
+ kernel_initializer='he_normal')(y)
+
+ # Instantiate model.
+ model = Model(inputs=inputs, outputs=outputs)
+ return model
+
+# Training parameters
+batch_size = 32 # orig paper trained all networks with batch_size=128
+epochs = 2
+num_classes = 10
+
+# Subtracting pixel mean improves accuracy
+subtract_pixel_mean = True
+
+# Model parameter
+# ----------------------------------------------------------------------------
+# | | 200-epoch | Orig Paper| 200-epoch | Orig Paper| sec/epoch
+# Model | n | ResNet v1 | ResNet v1 | ResNet v2 | ResNet v2 | GTX1080Ti
+# |v1(v2)| %Accuracy | %Accuracy | %Accuracy | %Accuracy | v1 (v2)
+# ----------------------------------------------------------------------------
+# ResNet20 | 3 (2)| 92.16 | 91.25 | ----- | ----- | 35 (---)
+# ResNet32 | 5(NA)| 92.46 | 92.49 | NA | NA | 50 ( NA)
+# ResNet44 | 7(NA)| 92.50 | 92.83 | NA | NA | 70 ( NA)
+# ResNet56 | 9 (6)| 92.71 | 93.03 | 93.01 | NA | 90 (100)
+# ResNet110 |18(12)| 92.65 | 93.39+-.16| 93.15 | 93.63 | 165(180)
+# ResNet164 |27(18)| ----- | 94.07 | ----- | 94.54 | ---(---)
+# ResNet1001| (111)| ----- | 92.39 | ----- | 95.08+-.14| ---(---)
+# ---------------------------------------------------------------------------
+n = 3
+depth = n * 9 + 2
+
+def generate_pretrained_model():
+ # Load the CIFAR10 data.
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+ # Input image dimensions.
+ input_shape = x_train.shape[1:]
+ # Normalize data.
+ x_train = x_train.astype('float32') / 255
+ x_test = x_test.astype('float32') / 255
+
+ x_train_mean = np.mean(x_train, axis=0)
+ x_train -= x_train_mean
+ x_test -= x_train_mean
+
+ print('x_train shape:', x_train.shape)
+ print(x_train.shape[0], 'train samples')
+ print(x_test.shape[0], 'test samples')
+ print('y_train shape:', y_train.shape)
+
+ # Convert class vectors to binary class matrices.
+ y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
+ y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)
+
+ model = resnet_v2(input_shape=input_shape, depth=depth)
+
+ model.compile(loss='categorical_crossentropy',
+ optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.01),
+ metrics=['accuracy'])
+ model.summary()
+
+ lr_scheduler = LearningRateScheduler(lr_schedule)
+
+ lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
+ cooldown=0,
+ patience=5,
+ min_lr=0.5e-6)
+
+ callbacks = [lr_reducer, lr_scheduler]
+
+ # Run training, with or without data augmentation.
+ model.fit(x_train, y_train,
+ batch_size=batch_size,
+ epochs=epochs,
+ validation_data=(x_test, y_test),
+ shuffle=True,
+ callbacks=callbacks)
+
+
+ # Score trained model.
+ scores = model.evaluate(x_test, y_test, verbose=1)
+ print('Test loss:', scores[0])
+ print('Test accuracy:', scores[1])
+ model.save("baseline_model")
+
+class EvalDataset(object):
+ def __init__(self, batch_size=100):
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+ x_train = x_train.astype('float32') / 255
+ x_test = x_test.astype('float32') / 255
+
+ # If subtract pixel mean is enabled
+ x_train_mean = np.mean(x_train, axis=0)
+ x_train -= x_train_mean
+ x_test -= x_train_mean
+
+ print('x_train shape:', x_train.shape)
+ print(x_train.shape[0], 'train samples')
+ print(x_test.shape[0], 'test samples')
+ print('y_train shape:', y_train.shape)
+
+ # Convert class vectors to binary class matrices.
+ y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
+ y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)
+ self.test_images = x_test
+ self.test_labels = y_test
+
+ def __len__(self):
+ return len(self.test_images)
+
+ def __getitem__(self, idx):
+ return self.test_images[idx], self.test_labels[idx]
+
+class TrainDataset(object):
+ def __init__(self, batch_size=100):
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+ x_train = x_train.astype('float32') / 255
+ x_test = x_test.astype('float32') / 255
+
+ # If subtract pixel mean is enabled
+ x_train_mean = np.mean(x_train, axis=0)
+ x_train -= x_train_mean
+ x_test -= x_train_mean
+
+ print('x_train shape:', x_train.shape)
+ print(x_train.shape[0], 'train samples')
+ print(x_test.shape[0], 'test samples')
+ print('y_train shape:', y_train.shape)
+
+ # Convert class vectors to binary class matrices.
+ y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
+ y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)
+ self.test_images = x_test
+ self.test_labels = y_test
+ self.train_images = x_train
+ self.train_labels = y_train
+
+ def __len__(self):
+ return len(self.train_images)
+
+ def __getitem__(self, idx):
+ return self.train_images[idx], self.train_labels[idx]
+
+def train(model, adaptor, compression_manager, train_dataloader):
+ train_cfg = {
+ 'epoch': 8,
+ 'start_epoch': 0,
+ 'execution_mode': 'eager',
+ 'criterion': {'CrossEntropyLoss': {'reduction': 'sum_over_batch_size'}},
+ 'optimizer': {'SGD': {'learning_rate': 1e-03, 'momentum': 0.9, 'nesterov': True}},
+ }
+ train_cfg = DotDict(train_cfg)
+ train_func = create_obj_from_config.create_train_func(
+ 'tensorflow', \
+ train_dataloader, \
+ adaptor, \
+ train_cfg, \
+ hooks=compression_manager.callbacks.callbacks_list[0].hooks, \
+ callbacks=compression_manager.callbacks.callbacks_list[0])
+ train_func(model)
+
+
+def evaluate(model, adaptor, eval_dataloader):
+ eval_cfg = {'accuracy': {'metric': {'topk': 1},
+ 'iteration': -1,
+ 'multi_metrics': None}
+ }
+ eval_cfg = DotDict(eval_cfg)
+ eval_func = create_obj_from_config.create_eval_func('tensorflow', \
+ eval_dataloader, \
+ adaptor, \
+ eval_cfg.accuracy.metric, \
+ eval_cfg.accuracy.postprocess, \
+ fp32_baseline = False)
+ return eval_func(model)
+
+if __name__ == '__main__':
+ train_dataloader = DataLoader(dataset=TrainDataset(), batch_size=32,
+ framework='tensorflow', distributed=FLAGS.train_distributed)
+ eval_dataloader = DataLoader(dataset=EvalDataset(), batch_size=32,
+ framework='tensorflow', distributed=FLAGS.evaluation_distributed)
+
+ if FLAGS.prune:
+ generate_pretrained_model()
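+        # Create the Keras adaptor that the train() and evaluate() helpers use to run the model.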
+ framework_specific_info = {
+ 'device': 'cpu', 'random_seed': 9527,
+ 'workspace_path': default_workspace,
+ 'q_dataloader': None, 'format': 'default',
+ 'backend': 'default', 'inputs': [], 'outputs': []
+ }
+ adaptor = FRAMEWORKS['keras'](framework_specific_info)
+
+ configs = WeightPruningConfig(
+ backend='itex',
+ pruning_type='magnitude',
+ pattern='3x1',
+ target_sparsity=0.25,
+ start_step=FLAGS.start_step,
+ end_step=FLAGS.end_step,
+ pruning_op_types=['Conv', 'Dense']
+ )
+ compression_manager = prepare_compression(model='./baseline_model', confs=configs)
+ compression_manager.callbacks.on_train_begin()
+ model = compression_manager.model
+
+ train(model, adaptor, compression_manager, train_dataloader)
+ print("Pruned model score is ",evaluate(model, adaptor, eval_dataloader))
+
+ compression_manager.callbacks.on_train_end()
+ compression_manager.save(FLAGS.output_model)
+ stats, sparsity = model.report_sparsity()
+ logger.info(stats)
+ logger.info(sparsity)
+
+ if FLAGS.benchmark:
+ from neural_compressor.benchmark import fit
+ from neural_compressor.config import BenchmarkConfig
+ conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=1, iteration=FLAGS.iters)
+ fit(FLAGS.input_model, conf, b_dataloader=eval_dataloader)
\ No newline at end of file
diff --git a/neural_compressor/adaptor/keras.py b/neural_compressor/adaptor/keras.py
index d106936c3f8..8924de8bb2e 100644
--- a/neural_compressor/adaptor/keras.py
+++ b/neural_compressor/adaptor/keras.py
@@ -748,6 +748,129 @@ def convert(self, model, source, destinatin):
'''
pass
+ def _pre_hook_for_hvd(self, dataloader=None):
+ """Pre hook for Horovod."""
+ import horovod.tensorflow as hvd
+ self.hvd = hvd
+ self.hvd.init()
+
+ @dump_elapsed_time(customized_msg="Model training")
+ def train(self, model, dataloader, optimizer_tuple,
+ criterion_tuple, hooks, postprocess, **kwargs):
+ """Model training API.
+
+ Args:
+ model ([Graph, GraphDef or Path String]): The model could be the graph,
+ graph_def object, the frozen pb or ckpt/savedmodel folder path.
+ dataloader (generator): generate the data and labels.
+ optimizer_tuple (tuple): optimizers for model training.
+ criterion_tuple (tuple): criterions for model training.
+ hooks (callback): on_epoch_begin hook on_epoch_end hook.
+ postprocess (object): process the result from the model.
+
+ Returns:
+ None.
+ """
+ # check model is savedmodel or not
+ import tensorflow as tf
+ from neural_compressor.model.tensorflow_model import get_model_type
+ tf.random.set_seed(1)
+ self.model_type = get_model_type(model._model)
+ optimizer = optimizer_tuple[0](**optimizer_tuple[1])
+ criterion = criterion_tuple[0](**criterion_tuple[1])
+ start_epochs = kwargs['kwargs'].get('start_epoch', None)
+ end_epochs = kwargs['kwargs'].get('end_epoch', None)
+ epochs = kwargs['kwargs'].get('epoch', None)
+ iters = kwargs['kwargs'].get('iteration', None)
+ callbacks = kwargs['kwargs'].get('callbacks', None)
+ execution_mode = kwargs['kwargs'].get('execution_mode', None)
+ distributed = getattr(dataloader, 'distributed', False)
+
+ if isinstance(model._model, tf.keras.Model):
+ input_model = model._model
+ else:
+ input_model = tf.keras.models.load_model(model._model)
+ # hooks = callbacks['tf_pruning'](model, input_model, hooks)
+ hooks['on_train_begin']() # on_train_begin hook
+ train_loss_results = []
+ if distributed:
+ try:
+ len_dataloader = len(dataloader)
+ except:
+ logger.info("The length of the distributed training dataloader is unknown."
+ "When the iteration of training dataloader in each process is "
+ "inconsistent, an error may occur.")
+ else:
+ list_len_dataloader = self.hvd.allgather_object(len_dataloader)
+ if self.hvd.rank() == 0:
+ for i in range(len(list_len_dataloader)-1):
+ if list_len_dataloader[i] != list_len_dataloader[i+1]:
+ raise AttributeError("The traning dataloader's iteration is"
+ "different between processes, please reset dataloader's batch_size.")
+
+ def training_step(x, y, first_batch):
+ with tf.GradientTape() as tape:
+ tape.watch(input_model.trainable_variables)
+ y_ = input_model(x, training=True)
+ loss_value = criterion(y, y_)
+
+ tape = self.hvd.DistributedGradientTape(tape) if distributed else tape
+ # Get gradient
+ grads = tape.gradient(loss_value, input_model.trainable_variables) # pylint: disable=no-member
+ # Optimize the model
+ optimizer.apply_gradients(zip(grads, input_model.trainable_variables)) # pylint: disable=no-member
+ if distributed and first_batch:
+ self.hvd.broadcast_variables(input_model.variables, root_rank=0)
+ self.hvd.broadcast_variables(optimizer.variables(), root_rank=0)
+ return loss_value
+
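+        # Run the step eagerly when execution_mode is 'eager'; otherwise compile it with tf.function.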
+ training_step = training_step if execution_mode=='eager' else tf.function(training_step)
+ if start_epochs is not None and end_epochs is not None:
+ epochs = end_epochs - start_epochs
+
+ for epoch in range(epochs):
+ cnt = 0
+ epoch_loss_avg = tf.keras.metrics.Mean()
+ # Training loop
+ for iter, data in enumerate(dataloader):
+ x, y = postprocess(data) if postprocess is not None else data
+ hooks['on_step_begin'](iter) # on_step_begin hook
+ cnt += 1
+ loss_value = training_step(x, y, iter==0)
+ # Track progress
+ epoch_loss_avg.update_state(loss_value) # Add current batch loss
+ hooks['on_before_optimizer_step']()
+ hooks['on_after_optimizer_step']()
+ if iters is not None and cnt >= iters:
+ break
+ model._sess = None
+ # End epoch
+ train_loss_results.append(epoch_loss_avg.result())
+ if distributed:
+ logger.info("Epoch-{:03d} training on rank {!s} have been done." \
+ .format(epoch+1, self.hvd.allgather_object(self.hvd.rank())))
+ logger.info("Epoch {:03d}: Loss: {:.3f}".format(epoch+1, epoch_loss_avg.result()))
+
+ hooks['on_train_end']() # on_train_end hook
+ model._sess = None
+
+ if distributed:
+ if self.hvd.rank() == 0:
+ # Update the input model with pruned weights manually due to keras API limitation.
+ if isinstance(model._model, tf.keras.Model):
+ model._model = input_model
+ else:
+ input_model.save(model._model)
+ rank_list = self.hvd.allgather_object(self.hvd.rank())
+ logger.info(f"rank 0 has saved the pruned model to '{model._model}',"
+ f"all ranks {rank_list} ready.")
+ else:
+ if isinstance(model._model, tf.keras.Model):
+ model._model = input_model
+ else:
+ input_model.save(model._model)
+
+
class KerasQuery(QueryBackendCapability):
def __init__(self, local_config_file=None):
super().__init__()
diff --git a/neural_compressor/compression/callbacks.py b/neural_compressor/compression/callbacks.py
index 5c6b44fa0c4..94398b1c97f 100644
--- a/neural_compressor/compression/callbacks.py
+++ b/neural_compressor/compression/callbacks.py
@@ -27,13 +27,12 @@
from ..model import BaseModel, Model
from ..model.model import MODELS
from .pruner.utils import process_config, parse_to_prune, get_sparsity_ratio
+from .pruner.utils import parse_to_prune_tf, get_sparsity_ratio_tf
from .pruner.pruners import get_pruner, PRUNERS
-# model auto slim related
-from .pruner.model_slim.pattern_analyzer import SelfMHASearcher
-
LazyImport('torch.nn')
torch = LazyImport('torch')
+tf = LazyImport('tensorflow')
class BaseCallbacks(object):
"""This is base class of Neural Compressor Callbacks.
@@ -225,8 +224,10 @@ def on_train_end(self):
"""Be called after the end of training."""
for on_train_end_hook in self.hooks_dict['on_train_end']:
on_train_end_hook()
- if isinstance(self.model.model, torch.nn.Module):
+ if self.conf.framework == 'pytorch' and isinstance(self.model.model, torch.nn.Module):
get_sparsity_ratio(self.pruners, self.model)
+ elif self.conf.framework == 'keras' and isinstance(self.model.model, tf.keras.Model):
+ get_sparsity_ratio_tf(self.pruners, self.model)
def __repr__(self):
"""Return the class's string representation."""
@@ -241,7 +242,9 @@ def generate_hooks(self):
def _generate_pruners(self):
"""Obtain Pruner objects."""
- if isinstance(self.model.model, torch.nn.Module):
+ if self.conf.framework == 'pytorch' and isinstance(self.model.model, torch.nn.Module):
+ # model auto slim related
+ from .pruner.model_slim.pattern_analyzer import SelfMHASearcher
for info in self.pruners_info:
if 'mha' in info['pattern']:
# head pruning
@@ -262,6 +265,19 @@ def _generate_pruners(self):
info['modules'] = [key for key in modules.keys()]
info['len_of_modules'] = len(info['modules'])
logger.info(info)
+ elif self.conf.framework == 'keras' and isinstance(self.model.model, tf.keras.Model):
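+            # Enable NumPy-style behavior on TF tensors; the TF pruning patterns mix numpy and TF ops.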
+ from tensorflow.python.ops.numpy_ops import np_config
+ np_config.enable_numpy_behavior()
+ for info in self.pruners_info:
+ # original pruning types, e.g NxM or N:M
+ modules = parse_to_prune_tf(info, self.model.model)
+ if modules == {}:
+ logger.warning("one pruner hooks no layers, please have a check")
+
+ self.pruners.append(get_pruner(info, modules, 'keras'))
+ info['modules'] = [key for key in modules.keys()]
+ info['len_of_modules'] = len(info['modules'])
+ logger.info(info)
else:
assert False, 'now only support {}'.format(PRUNERS.keys())
diff --git a/neural_compressor/compression/pruner/criteria.py b/neural_compressor/compression/pruner/criteria.py
index 4b69fe8bd38..25cfcee0693 100644
--- a/neural_compressor/compression/pruner/criteria.py
+++ b/neural_compressor/compression/pruner/criteria.py
@@ -15,7 +15,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .utils import torch
+
+import numpy as np
+from ...utils.utility import LazyImport
+torch = LazyImport('torch')
+tf = LazyImport('tensorflow')
CRITERIA = {}
@@ -31,12 +35,12 @@ def register(criterion):
return register
-def get_criterion(config, modules):
+def get_criterion(config, modules, framework='pytorch'):
"""Get registered criterion class."""
name = config["criterion_type"]
if name not in CRITERIA.keys():
assert False, f"criteria does not support {name}, currently only support {CRITERIA.keys()}"
- return CRITERIA[name](modules, config)
+ return CRITERIA[name](modules, config, framework)
class PruningCriterion:
@@ -50,11 +54,12 @@ class PruningCriterion:
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a pruning criterion."""
self.scores = {}
self.modules = modules
self.config = config
+        self.framework = framework
def on_step_begin(self):
"""Calculate and store the pruning scores of pruning modules at the beginning of a step."""
@@ -84,17 +89,21 @@ class MagnitudeCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a magnitude pruning criterion."""
- super(MagnitudeCriterion, self).__init__(modules, config)
+ super(MagnitudeCriterion, self).__init__(modules, config, framework)
def on_step_begin(self):
"""Calculate and store the pruning scores based on a magnitude criterion."""
- with torch.no_grad():
+ if self.framework == 'pytorch':
+ with torch.no_grad():
+ for key in self.modules.keys():
+ p = self.modules[key].weight.data
+ self.scores[key] = torch.abs(p)
+ elif self.framework == 'keras':
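+            # Keras layers expose weights via get_weights(); index 0 is the kernel used for magnitude scores.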
for key in self.modules.keys():
- p = self.modules[key].weight.data
- self.scores[key] = torch.abs(p)
-
+ p = self.modules[key].get_weights()[0]
+ self.scores[key] = np.abs(p)
@register_criterion('gradient')
class GradientCriterion(PruningCriterion):
@@ -111,13 +120,14 @@ class GradientCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a gradient pruning criterion."""
- super(GradientCriterion, self).__init__(modules, config)
+ super(GradientCriterion, self).__init__(modules, config, framework)
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on gradient criterion."""
+ assert self.framework != 'keras', "This pruning criterion is not supported by Keras now."
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight
@@ -141,18 +151,20 @@ class SnipCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a snip pruning criterion."""
- super(SnipCriterion, self).__init__(modules, config)
+ super(SnipCriterion, self).__init__(modules, config, framework)
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on snip criterion."""
##self.mask_weights()
+ assert self.framework != 'keras', "This pruning criterion is not supported by Keras now."
with torch.no_grad():
for key in self.modules.keys():
p = self.modules[key].weight
self.scores[key] = torch.abs(p * p.grad)
+
@register_criterion('snip_momentum')
@@ -172,9 +184,10 @@ class SnipMomentumCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a snip_momentum pruning criterion."""
- super(SnipMomentumCriterion, self).__init__(modules, config)
+ super(SnipMomentumCriterion, self).__init__(modules, config, framework)
+ assert self.framework != 'keras', "This pruning criterion is not supported by Keras now."
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
for key in modules.keys():
p = modules[key].weight
@@ -209,9 +222,10 @@ class SnipMomentumBlockCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a block_mask pruning criterion."""
- super(SnipMomentumBlockCriterion, self).__init__(modules, config)
+ super(SnipMomentumBlockCriterion, self).__init__(modules, config, framework)
+ assert self.framework != 'keras', "This pruning criterion is not supported by Keras now."
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
for key in self.modules.keys():
if not hasattr(self.modules[key], 'block_mask'):
@@ -248,9 +262,10 @@ class RetrainFreeCriterion(PruningCriterion):
scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
"""
- def __init__(self, modules, config):
+ def __init__(self, modules, config, framework='pytorch'):
"""Initiliaze a block_mask pruning criterion."""
- super(RetrainFreeCriterion, self).__init__(modules, config)
+ super(RetrainFreeCriterion, self).__init__(modules, config, framework)
+ assert self.framework != 'keras', "This pruning criterion is not supported by Keras now."
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"
self.collected_grads = {}
for key in self.modules.keys():
diff --git a/neural_compressor/compression/pruner/model_slim/auto_slim.py b/neural_compressor/compression/pruner/model_slim/auto_slim.py
index f85487250f6..0e4b6ccd09f 100644
--- a/neural_compressor/compression/pruner/model_slim/auto_slim.py
+++ b/neural_compressor/compression/pruner/model_slim/auto_slim.py
@@ -16,8 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# model slim related
-from .pattern_analyzer import Linear2LinearSearcher, RecipeSearcher, SelfMHASearcher
-from .weight_slim import LinearCompressionIterator, MHACompression
+
from ..utils import logger
def model_slim(model, dataloader=None, round_multiplier=32):
@@ -39,6 +38,8 @@ def model_slim_ffn2(model, dataloader = None, round_multiplier=32):
model: a sprase model.
round_multiplier(int): the channel number after slimming should be multiple of this number.
"""
+ from .pattern_analyzer import Linear2LinearSearcher
+ from .weight_slim import LinearCompressionIterator
logger.warning(f"You are using model slim methods, some weight channels will be removed permanently.")
pa_obj = Linear2LinearSearcher(model, dataloader)
layers = pa_obj.search()
@@ -53,6 +54,8 @@ def model_slim_mha(model, dataloader = None):
Args:
model: a sprase model.
"""
+ from .weight_slim import MHACompression
+ from .pattern_analyzer import SelfMHASearcher
logger.warning(f"You are using model slim methods, some attention heads will be removed permanently.")
pa_obj = SelfMHASearcher(model, dataloader)
layers, _ = pa_obj.search(split_qkv_ffn = False)
@@ -75,6 +78,7 @@ def parse_auto_slim_config(model, dataloader = None, ffn2_sparsity = .0, mha_spa
def generate_ffn2_pruning_config(model, dataloader, ffn2_sparsity, **kwargs):
"""Get consecutive linear layers pruning configs."""
+ from .pattern_analyzer import Linear2LinearSearcher
searcher = Linear2LinearSearcher(model, dataloader)
layers = searcher.search()
# extract the second linear layer
@@ -106,6 +110,7 @@ def generate_mha_pruning_config(model, dataloader, mha_sparsity, **kwargs):
return mha_pruning_config
# method 2: apply experimental mha pruning
+ # from .pattern_analyzer import SelfMHASearcher
# searcher = SelfMHASearcher(model, dataloader)
# qkv_pattern, ffn_pattern = searcher.get_head_pattern()
# qkv_layers, ffn_layers = searcher.search()
diff --git a/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py b/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py
index 91b673fab09..7df005fbbc6 100644
--- a/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py
+++ b/neural_compressor/compression/pruner/model_slim/pattern_analyzer.py
@@ -16,8 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ..utils import torch, logger
+from ..utils import logger
import re
+from ....utils.utility import LazyImport
+torch = LazyImport('torch')
+tf = LazyImport('tensorflow')
JIT_SUPPORT_OPS = ['linear', 'dropout', 'gelu', 'silu', 'relu', 'mul', 'add']
@@ -792,3 +795,42 @@ def search(self, return_name=True):
last_lc = all_lc_modules[-1]
if last_lc == all_modules[-1]: return last_lc
else: return None
+
+class ClassifierHeadSearcherTF(object):
+ """Static graph searcher for multi-head attention modules.
+
+ Use the static graph to detect final classifier head in a module, there is no need for user to define layer name.
+ Automatically search multi-head attention modules which can be optimized.
+
+ Args:
+ model (tf.keras.Model): The Keras model for searching.
+
+ Attributes:
+ model: The Keras model for searching.
+ device: The model's current device type.
+ static_graph: The static graph of original model.
+ flatten_static_graph: A list of string with the model's static graph inference details.
+ """
+
+ def __init__(self, model):
+ """Initialize."""
+ assert isinstance(model, tf.keras.Model)
+ super(ClassifierHeadSearcherTF, self).__init__()
+ self.model = model
+ self.pruning_ops = ["Dense", "Conv2d"]
+ self.excluded_ops = ["Dropout"] # to be extended
+
+ def search(self, return_name=True):
+ all_modules = []
+ all_lc_modules = []
+ for layer in self.model.layers:
+ if layer.__class__.__name__ not in self.excluded_ops:
+ all_modules.append(layer.name)
+ if layer.__class__.__name__ in self.pruning_ops:
+ all_lc_modules.append(layer.name)
+ else:
+ continue
+ last_lc = all_lc_modules[-1]
+ if last_lc == all_modules[-1]:
+ return last_lc
+ return None
\ No newline at end of file
diff --git a/neural_compressor/compression/pruner/patterns.py b/neural_compressor/compression/pruner/patterns.py
index 3b68f36b695..96e17a4f99a 100644
--- a/neural_compressor/compression/pruner/patterns.py
+++ b/neural_compressor/compression/pruner/patterns.py
@@ -16,8 +16,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .utils import torch, logger
+import numpy as np
+from .utils import logger
from collections import namedtuple
+from ...utils.utility import LazyImport
+torch = LazyImport('torch')
+tf = LazyImport('tensorflow')
PATTERNS = {}
@@ -42,7 +46,7 @@ def register(pattern):
return register
-def get_pattern(config, modules):
+def get_pattern(config, modules, framework='pytorch'):
"""Get registered pattern class.
Get a Pattern object from PATTERNS.
@@ -60,7 +64,7 @@ def get_pattern(config, modules):
name = config.pattern
name = name.split('_')[-1]
if "x" in name:
- return PATTERNS["NxM"](config, modules)
+ return PATTERNS["NxM"](config, modules, framework)
if ":" in name:
return PATTERNS["N:M"](config, modules)
if "mha" in name:
@@ -318,7 +322,7 @@ class BasePattern:
target_sparsity: A float representing the sparsity ratio of the modules after pruning.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize the basic pruning unit of a pattern."""
self.pattern = config.pattern
self.is_global = config.pruning_scope == "global"
@@ -326,12 +330,14 @@ def __init__(self, config, modules):
self.invalid_layers = []
self.modules = modules
self.config = config
+ self.framework = framework
self.max_sparsity_ratio_per_op = self.config['max_sparsity_ratio_per_op']
self.min_sparsity_ratio_per_op = self.config['min_sparsity_ratio_per_op']
self.target_sparsity_ratio = self.config['target_sparsity']
self.block = bool('block' in self.config['pruning_type'] or 'free' in self.config['pruning_type'])
# Not using deterministic_algorithms for all examples
- torch.use_deterministic_algorithms(False)
+ if self.framework == 'pytorch':
+ torch.use_deterministic_algorithms(False)
def reduce_tensor(self, data, dim):
"""Reduce the data along the given dimension.
@@ -345,11 +351,20 @@ def reduce_tensor(self, data, dim):
"""
name = self.config['criterion_reduce_type']
if name == "mean":
- return torch.mean(data, dim=dim)
+ if self.framework == 'pytorch':
+ return torch.mean(data, dim=dim)
+ elif self.framework == 'keras':
+ return tf.math.reduce_mean(data, dim)
elif name == "sum":
- return torch.sum(data, dim=dim)
+ if self.framework == 'pytorch':
+ return torch.sum(data, dim=dim)
+ elif self.framework == 'keras':
+ return tf.math.reduce_sum(data, dim)
elif name == "max":
- return torch.max(data, dim=dim)[0]
+ if self.framework == 'pytorch':
+ return torch.max(data, dim=dim)[0]
+ elif self.framework == 'keras':
+ return tf.math.reduce_max(data, dim)
else:
assert False, "currently only support mean, sum and max reduce type"
@@ -406,15 +421,27 @@ def get_single_mask_per_target_ratio(self, score, exact_sparsity_ratio):
Returns:
A Tensor with the identical size as score. a new mask.
"""
- flattern_score = torch.flatten(score)
- k = int(exact_sparsity_ratio * flattern_score.numel())
- threshold, _ = torch.kthvalue(flattern_score, k)
- if not k < 1:
- zero = torch.tensor([0.]).to(score.device)
- one = torch.tensor([1.]).to(score.device)
- mask = torch.where(score <= threshold, zero, one)
- else:
- mask = torch.ones(score.shape, device=score.device)
+ if self.framework == 'pytorch':
+ flattern_score = torch.flatten(score)
+ k = int(exact_sparsity_ratio * flattern_score.numel())
+ threshold, _ = torch.kthvalue(flattern_score, k)
+ if not k < 1:
+ zero = torch.tensor([0.]).to(score.device)
+ one = torch.tensor([1.]).to(score.device)
+ mask = torch.where(score <= threshold, zero, one)
+ else:
+ mask = torch.ones(score.shape, device=score.device)
+        elif self.framework == 'keras':
+            flattern_score = tf.reshape(score, [-1]).numpy()
+            k = int(exact_sparsity_ratio * flattern_score.size)
+            if not k < 1:
+                # Use the k-th smallest score as the pruning threshold (1-based k, matching the torch branch).
+                threshold = np.partition(flattern_score, kth=k - 1)[k - 1]
+                zero = tf.convert_to_tensor([0.])
+                one = tf.convert_to_tensor([1.])
+                mask = tf.where(score <= threshold, zero, one)
+            else:
+                mask = tf.ones(score.shape)
+
return mask
def get_block_size_dict(self, data):
@@ -446,6 +473,7 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False):
pre_mask = pre_masks[key]
zero_cnt += torch.sum(pre_mask == 0.0).data.item()
total_cnt += pre_mask.numel() ##FIXME
+
if return_dict:
return {"sparsity_ratio": float(zero_cnt) / total_cnt, "zero_cnt": zero_cnt, "total_cnt": total_cnt}
else:
@@ -469,6 +497,7 @@ def get_sparsity_ratio_progressive(self, pre_masks, return_dict=False):
# progressive masks are unstructured, therefore directly find zeros
zero_cnt += float(torch.sum(pre_masks[key] == 0).data.item())
total_cnt += float(pre_masks[key].numel())
+
return (zero_cnt / total_cnt)
def get_pattern_lock_masks(self, modules):
@@ -487,6 +516,7 @@ def get_pattern_lock_masks(self, modules):
mask = torch.ones(shape)
mask[weight == 0] = 0.0
pattern_lock_masks[key] = mask.to(weight.device)
+
return pattern_lock_masks
def check_layer_validity(self):
@@ -529,17 +559,34 @@ def get_sparsity_ratio_each_layer(self, masks):
infos = {}
zero_cnts = 0
total_cnts = 0
- for key in masks.keys():
- if key in self.invalid_layers:
- continue
- reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key)
- zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item()))
- total_cnt = int(reduced_mask.numel())
- sparsity_ratio = float(zero_cnt) / total_cnt
- val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio)
- infos[key] = val
- zero_cnts += zero_cnt
- total_cnts += total_cnt
+
+ if self.framework == 'pytorch':
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key)
+ zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item()))
+ total_cnt = int(reduced_mask.numel())
+ sparsity_ratio = float(zero_cnt) / total_cnt
+ val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio)
+ infos[key] = val
+ zero_cnts += zero_cnt
+ total_cnts += total_cnt
+ elif self.framework == 'keras':
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ if not isinstance(masks[key], np.ndarray):
+ masks[key] = masks[key].numpy()
+ reduced_mask = masks[key] if self.block else self.get_reduced_masks_from_data(masks[key], key)
+ zero_cnt = int(np.sum(reduced_mask == 0.0))
+ total_cnt = int(reduced_mask.size)
+ sparsity_ratio = float(zero_cnt) / total_cnt
+ val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio)
+ infos[key] = val
+ zero_cnts += zero_cnt
+ total_cnts += total_cnt
+
sparsity_ratio = float(zero_cnts) / total_cnts
return infos, SparsityInfo(zero_cnts, total_cnts, sparsity_ratio)
@@ -629,9 +676,9 @@ class PatternNxM(BasePattern):
because PyTorch's tensor matmul has a hidden transpose operation.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize the basic pruning unit of NXM pattern."""
- super(PatternNxM, self).__init__(config, modules)
+ super(PatternNxM, self).__init__(config, modules, framework)
pattern = self.pattern.split('_')[-1]
self.N = pattern.split('x')[0]
self.M = pattern.split('x')[1]
@@ -662,7 +709,7 @@ def get_block_size_dict(self):
block_sizes_dict[key] = self.block_size
if not (self.N == "channel" or self.M == "channel"):
continue
- if isinstance(datas[key], torch.nn.Module):
+ if self.framework == 'pytorch' and isinstance(datas[key], torch.nn.Module):
shape = datas[key].weight.shape
else:
shape = datas[key].shape
@@ -677,14 +724,24 @@ def check_layer_validity(self):
"""Check if a layer is valid for this block_size."""
block_sizes = self.block_size
datas = self.modules
- for key in datas.keys():
- data = datas[key].weight
- data = self._reshape_orig_to_2dims(data)
- shape = data.shape
- block_size = block_sizes[key]
- if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel
- self.invalid_layers.append(key)
- logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}")
+ if self.framework == 'pytorch':
+ for key in datas.keys():
+ data = datas[key].weight
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ block_size = block_sizes[key]
+ if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel
+ self.invalid_layers.append(key)
+ logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}")
+ elif self.framework == 'keras':
+ for key in datas.keys():
+ data = datas[key].get_weights()[0]
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ block_size = block_sizes[key]
+ if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel
+ self.invalid_layers.append(key)
+ logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}")
def get_reduced_masks_from_data(self, data, key):
"""Obtain the unpruned weights and reshape according to the block_size.
@@ -718,12 +775,22 @@ def get_sparsity_ratio(self, pre_masks, return_dict=False):
"""
zero_cnt = 0
total_cnt = 0
- for key in pre_masks.keys():
- if key in self.invalid_layers:
- continue
- reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key)
- zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item()))
- total_cnt += int(reduced_mask.numel())
+ if self.framework == 'pytorch':
+ for key in pre_masks.keys():
+ if key in self.invalid_layers:
+ continue
+ reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key)
+ zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item()))
+ total_cnt += int(reduced_mask.numel())
+ elif self.framework == 'keras':
+ for key in pre_masks.keys():
+ if key in self.invalid_layers:
+ continue
+ if not isinstance(pre_masks[key], np.ndarray):
+ pre_masks[key] = pre_masks[key].numpy()
+ reduced_mask = pre_masks[key] if self.block else self.get_reduced_masks_from_data(pre_masks[key], key)
+ zero_cnt += int(np.sum(reduced_mask == 0.0))
+ total_cnt += int(reduced_mask.size)
if total_cnt == 0:
sparsity_ratio = 0.0
else:
@@ -744,7 +811,10 @@ def _reshape_orig_to_2dims(self, data):
"""
##TODO need to verify whether it's ok for transposed conv
if len(data.shape) == 4:
- data = data.permute(0, 2, 3, 1) ##cout,k,k,cin
+ if isinstance(data, np.ndarray):
+ data = np.transpose(data, (0, 2, 3, 1))
+ else:
+ data = data.permute(0, 2, 3, 1) ##cout,k,k,cin
data = data.reshape(data.shape[0], -1)
return data
@@ -761,7 +831,10 @@ def _reshape_2dims_to_orig(self, data, orig_shape):
if len(orig_shape) == 4:
data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3],
orig_shape[1])
- data = data.permute(0, 3, 1, 2)
+ if isinstance(data, np.ndarray):
+ data = np.transpose(data, (0, 3, 1, 2))
+ else:
+ data = data.permute(0, 3, 1, 2)
return data
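The NumPy branches added above reproduce the original permute/reshape pair: a 4-D kernel stored channels-first (cout, cin, k, k) is moved to channels-last before being flattened to 2-D, and the inverse restores the original layout. A small round-trip sketch with an arbitrary shape:

```python
import numpy as np

data = np.random.rand(8, 3, 2, 2)        # (cout, cin, k, k), channels-first layout
orig_shape = data.shape

# orig -> 2dims: move cin to the last axis, then flatten each output channel to a row
two_d = np.transpose(data, (0, 2, 3, 1)).reshape(data.shape[0], -1)   # shape (8, 12)

# 2dims -> orig: rebuild (cout, k, k, cin), then permute back to (cout, cin, k, k)
restored = two_d.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
restored = np.transpose(restored, (0, 3, 1, 2))

assert np.allclose(restored, data)
```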
def reshape_orig_to_pattern(self, data, key):
@@ -823,11 +896,20 @@ def reduce_scores(self, scores):
def get_mask_per_threshold(self, score, threshold, block_size):
"""Get the mask per threshold."""
- zero = torch.tensor([0.]).to(score.device)
- one = torch.tensor([1.]).to(score.device)
- mask = torch.where(score <= threshold, zero, one)
- if not self.block:
- mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
+ if self.framework == 'pytorch':
+ zero = torch.tensor([0.]).to(score.device)
+ one = torch.tensor([1.]).to(score.device)
+ mask = torch.where(score <= threshold, zero, one)
+ if not self.block:
+ mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
+ elif self.framework == 'keras':
+ zero = tf.convert_to_tensor([0.])
+ one = tf.convert_to_tensor([1.])
+ mask = tf.where(score <= threshold, zero, one)
+ if not self.block:
+ mask = tf.repeat(mask, repeats=block_size[0], axis=0)
+ mask = tf.repeat(mask, repeats=block_size[1], axis=-1)
+ mask = mask.numpy()
return mask
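For the Keras path, the 0/1 mask is built with TensorFlow ops and expanded back to the full block shape before being converted to NumPy. A standalone sketch of that step (threshold and block size are illustrative):

```python
import tensorflow as tf

score = tf.constant([[0.1, 0.8],
                     [0.9, 0.05]])        # one reduced score per block
threshold = 0.5
block_size = (2, 2)                       # each score covers a 2x2 block of weights

mask = tf.where(score <= threshold, 0., 1.)             # 0 prunes the block, 1 keeps it
mask = tf.repeat(mask, repeats=block_size[0], axis=0)   # expand rows to block height
mask = tf.repeat(mask, repeats=block_size[1], axis=-1)  # expand cols to block width
print(mask.numpy())                       # 4x4 mask expanded from the 2x2 block scores
```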
def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks,
@@ -837,6 +919,30 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks,
Gather all layers' scores together and calculate a common threshold.
This threshold will be applied to all layers.
+ Args:
+ scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights.
+ cur_target_sparsity_ratio: A float representing the model's sparsity after pruning.
+ pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at the last pruning step.
+            keep_exact_sparsity_ratio: A bool indicating whether the pruning budget is adjusted so that
+                the exact target sparsity ratio is reached.
+
+        Returns:
+            A dict of the same size as pre_masks whose 0/1 values are updated.
+            1 means unpruned and 0 means pruned.
+ """
+ if self.framework == 'pytorch':
+ return self.get_masks_global_pytorch(scores, cur_target_sparsity_ratio, pre_masks, \
+ keep_exact_sparsity_ratio)
+ elif self.framework == 'keras':
+ return self.get_masks_global_tf(scores, cur_target_sparsity_ratio, pre_masks, keep_exact_sparsity_ratio)
+
+ def get_masks_global_pytorch(self, scores, cur_target_sparsity_ratio, pre_masks,
+ keep_exact_sparsity_ratio=True):
+ """Generate masks for layers.
+
+        Gather all layers' scores together and calculate a common threshold.
+ This threshold will be applied to all layers.
+
Args:
scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights.
cur_target_sparsity_ratio: A float representing the model's sparsity after pruning.
@@ -907,6 +1013,84 @@ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks,
logger.info(f'{key} sparsity is {layer_ratio}')
return masks
+ def get_masks_global_tf(self, scores, cur_target_sparsity_ratio, pre_masks,
+ keep_exact_sparsity_ratio=True):
+ """Generate masks for layers.
+
+        Gather all layers' scores together and calculate a common threshold.
+ This threshold will be applied to all layers.
+
+ Args:
+ scores: A dict{"layer_name": Tensor} that stores the pruning scores of weights.
+ cur_target_sparsity_ratio: A float representing the model's sparsity after pruning.
+ pre_masks: A dict{"layer_name": Tensor} that stores the masks generated at the last pruning step.
+            keep_exact_sparsity_ratio: A bool indicating whether the pruning budget is adjusted so that
+                the exact target sparsity ratio is reached.
+
+        Returns:
+            A dict of the same size as pre_masks whose 0/1 values are updated.
+            1 means unpruned and 0 means pruned.
+ """
+        ## keep the masks if the layer exceeds the max sparsity ratio
+
+ masks = pre_masks
+ k_blockwise = self.update_residual_cnt(masks, cur_target_sparsity_ratio)
+ if k_blockwise <= 0:
+ return masks
+ new_scores = scores if self.block else self.reduce_scores(scores)
+ not_exceed_layers = []
+ residual_k = k_blockwise
+ if self.min_sparsity_ratio_per_op > 0:
+ sparsity_infos_perlayer, _ = self.get_sparsity_ratio_each_layer(masks)
+
+ while True:
+ new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)]
+ if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0:
+ break
+ not_exceed_layers = new_not_exceed_layers
+ global_scores = np.concatenate([tf.reshape(new_scores[key], [-1]).numpy() for key in not_exceed_layers])
+ threshold = np.partition(global_scores, kth=residual_k)[residual_k]
+
+ for key in not_exceed_layers:
+ block_size = self.block_size[key]
+ score = new_scores[key]
+ mask = self.get_mask_per_threshold(score, threshold, block_size)
+ info = self.get_sparsity_ratio({key: mask}, return_dict=True)
+ zero_cnt = info["zero_cnt"]
+ total_cnt = info["total_cnt"]
+ current_sparsity_ratio = float(zero_cnt) / total_cnt
+ key_new_sparsity = SparsityInfo(zero_cnt, total_cnt, current_sparsity_ratio)
+ need_adjust, adjust_ratio = self.adjust_ratio(masks, key, key_new_sparsity,
+ self.max_sparsity_ratio_per_op,
+ self.min_sparsity_ratio_per_op,
+ self.target_sparsity_ratio)
+ if need_adjust:
+                # update status
+ self.keep_mask_layers[key] = True
+ masks[key] = self.get_single_mask_per_target_ratio(new_scores[key], adjust_ratio)
+ if not self.block:
+ masks[key] = tf.repeat(masks[key], repeats=block_size[0], axis=0)
+ masks[key] = tf.repeat(masks[key], repeats=block_size[1], axis=-1)
+ if keep_exact_sparsity_ratio:
+ zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"]
+ residual_k -= zero_cnt
+ else:
+ masks[key] = mask
+ if not keep_exact_sparsity_ratio:
+ break
+
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ if len(scores[key].shape) == 4: ## need to permute
+ mask = masks[key]
+ orig_shape = scores[key].shape
+ mask = self._reshape_2dims_to_orig(mask, orig_shape)
+ masks[key] = mask
+ layer_ratio = np.sum(masks[key] == 0.0) / masks[key].size
+ logger.info(f'{key} sparsity is {layer_ratio}')
+ return masks
+
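get_masks_global_tf keeps the recipe of the PyTorch version: pool every layer's reduced scores, take the k-th smallest value as a common threshold, then re-check per-layer limits. The threshold selection itself is just a NumPy partition; a minimal sketch with made-up scores:

```python
import numpy as np

scores = {"dense_1": np.array([0.9, 0.1, 0.4, 0.7]),
          "dense_2": np.array([0.2, 0.8, 0.05, 0.6])}
target_sparsity = 0.5

global_scores = np.concatenate([s.reshape(-1) for s in scores.values()])
k = int(target_sparsity * global_scores.size)          # pruning budget in blocks

# np.partition places the k smallest values before index k; index k is the threshold
threshold = np.partition(global_scores, kth=k)[k]

# blocks whose score is <= threshold are pruned (mask value 0), the rest are kept
masks = {name: np.where(s <= threshold, 0., 1.) for name, s in scores.items()}
print(threshold, masks)
```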
def get_pattern_lock_masks(self, modules):
"""Obtain masks from original weight map by masking the zero-valued weights.
@@ -927,6 +1111,7 @@ def get_pattern_lock_masks(self, modules):
reduced_mask = self.get_reduced_masks_from_data(weight, key)
mask = self.reshape_reduced_to_orig(reduced_mask, key, ori_shape)
pattern_lock_masks[key] = mask
+
return pattern_lock_masks
def register_block_masks(self, modules):
@@ -951,6 +1136,7 @@ def register_block_masks(self, modules):
block_mask = torch.nn.Parameter(self.get_reduced_masks_from_data(weight, key).to(dtype=weight.dtype))
module.register_parameter("block_mask", block_mask)
masks[key] = modules[key].block_mask.data
+
return masks
def remove_block_masks(self):
@@ -1011,9 +1197,9 @@ class PatternNInM(BasePattern):
M: The size of the weight sequence.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize the basic pruning unit of N:M pattern."""
- super(PatternNInM, self).__init__(config, modules)
+ super(PatternNInM, self).__init__(config, modules, framework)
pattern = self.pattern.split('_')[-1]
self.N = int(pattern.split(':')[0])
self.M = int(pattern.split(':')[1]) ##m is bigger
@@ -1345,7 +1531,8 @@ class PatternMHA(BasePattern):
M: The size of the weight sequence.
"""
- def __init__(self, config, modules = None):
+ def __init__(self, config, modules = None, framework='pytorch'):
+ self.framework = framework
self.is_global = config.pruning_scope == "global"
# only implement three method: get_masks, get_masks_local, get_masks_global
diff --git a/neural_compressor/compression/pruner/pruners.py b/neural_compressor/compression/pruner/pruners.py
index 36e53ace1b0..5ac68944de9 100644
--- a/neural_compressor/compression/pruner/pruners.py
+++ b/neural_compressor/compression/pruner/pruners.py
@@ -16,18 +16,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import re
-from .utils import torch, F
+import copy
+import numpy as np
from functools import partial
from .patterns import get_pattern
from .schedulers import get_scheduler
from .criteria import get_criterion, CRITERIA
from .regs import get_reg
from .utils import logger
-# auto slim related: head pruning objects
-from .model_slim.pattern_analyzer import SelfMHASearcher
-from .model_slim.weight_slim import MHACompression
+
+from ...utils.utility import LazyImport
+torch = LazyImport('torch')
+tf = LazyImport('tensorflow')
+F = LazyImport('torch.nn.functional')
PRUNERS = {}
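Replacing the eager `from .utils import torch, F` with LazyImport keeps this module importable in a TensorFlow-only environment, since the framework module is only loaded when one of its attributes is first accessed. Roughly (a simplified stand-in, not the actual LazyImport implementation):

```python
import importlib

class LazyModule:
    """Defer importing a module until one of its attributes is used."""
    def __init__(self, module_name):
        self._name = module_name
        self._module = None

    def __getattr__(self, attr):
        if self._module is None:          # the real import happens here, not at file load
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)

torch = LazyModule("torch")               # no ImportError if torch is absent...
# ...until an attribute such as torch.ones is actually touched.
```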
@@ -62,7 +64,7 @@ def parse_valid_pruner_types():
valid_pruner_types.append("pattern_lock")
return valid_pruner_types
-def get_pruner(config, modules):
+def get_pruner(config, modules, framework='pytorch'):
"""Get registered pruner class.
Get a Pruner object from PRUNERS.
@@ -102,7 +104,7 @@ def get_pruner(config, modules):
if name not in PRUNERS.keys():
assert False, f"does not support {name}, currently only support {parse_valid_pruner_types()}"
- return PRUNERS[name](config, modules)
+ return PRUNERS[name](config, modules, framework)
class BasePruner:
"""Pruning Pruner.
@@ -131,10 +133,11 @@ class BasePruner:
max_sparsity_ratio_per_op: A float showing the maximum sparsity ratio for every module.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
self.modules = modules
self.config = config
+ self.framework = framework
self.masks = {}
self.global_step = 0
self.handled_global_step = -1
@@ -150,10 +153,15 @@ def __init__(self, config, modules):
self.total_prune_cnt = 1
self.completed_pruned_cnt = 1
- for key in self.modules.keys():
- module = self.modules[key]
- self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device) ##TODO support bias or others
-
+ if self.framework == 'pytorch':
+ for key in self.modules.keys():
+ module = self.modules[key]
+ ##TODO support bias or others
+ self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device)
+ elif self.framework == 'keras':
+ for key in self.modules.keys():
+ module = self.modules[key]
+ self.masks[key] = np.ones(module.get_weights()[0].shape)
self.target_sparsity_ratio = self.config['target_sparsity']
self.current_sparsity_ratio = 0.0
self.init_sparsity_ratio = 0.0
@@ -172,10 +180,15 @@ def mask_weights(self):
Weights are multiplied by masks. This is the formal pruning process.
"""
- with torch.no_grad():
+ if self.framework == 'pytorch':
+ with torch.no_grad():
+ for key in self.modules.keys():
+ module = self.modules[key]
+ module.weight.data = module.weight.data * self.masks[key]
+ elif self.framework == 'keras':
for key in self.modules.keys():
module = self.modules[key]
- module.weight.data = module.weight.data * self.masks[key]
+ module.set_weights([module.get_weights()[0] * self.masks[key]] + module.get_weights()[1:])
def mask_weights_general(self, input_masks):
"""Apply input masks to corresponding modules' weights.
@@ -185,10 +198,15 @@ def mask_weights_general(self, input_masks):
Args:
input_masks: A dict {"module_name": Tensor} that stores the masks for modules' weights.
"""
- with torch.no_grad():
+ if self.framework == 'pytorch':
+ with torch.no_grad():
+ for key in self.modules.keys():
+ module = self.modules[key]
+ module.weight.data = module.weight.data * input_masks[key]
+ elif self.framework == 'keras':
for key in self.modules.keys():
module = self.modules[key]
- module.weight.data = module.weight.data * input_masks[key]
+ module.set_weights([module.get_weights()[0] * input_masks[key]] + module.get_weights()[1:])
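Keras layers expose parameters through get_weights()/set_weights() rather than a .weight tensor, so the branches above initialize and apply masks on the kernel (index 0) and pass any remaining weights (e.g. the bias) through unchanged. A self-contained sketch on a single Dense layer:

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build(input_shape=(None, 8))         # creates kernel (8, 4) and bias (4,)

# the mask shares the kernel's shape; all ones means nothing is pruned yet
mask = np.ones(layer.get_weights()[0].shape)
mask[:, ::2] = 0.0                         # pretend half of the columns were pruned

# apply: multiply only the kernel, keep the bias (and any other weights) as-is
weights = layer.get_weights()
layer.set_weights([weights[0] * mask] + weights[1:])

print(int(np.sum(layer.get_weights()[0] == 0.0)))   # pruned kernel entries
```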
def on_step_begin(self, local_step):
"""Implement at the start of each step."""
@@ -254,6 +272,7 @@ def check_is_pruned_step(self, step):
def rewrite_forward(self):
"""Rewrite forward to implement block mask operation"""
+        assert self.framework != 'keras', "This pruning method is not yet supported for Keras."
def forward(self, input):
block_size = [self.weight.shape[0]//self.block_mask.shape[0], \
self.weight.shape[1]//self.block_mask.shape[1]]
@@ -269,6 +288,7 @@ def forward(self, input):
def recover_forward(self):
"""Restore the forward format at the end of pruning"""
+        assert self.framework != 'keras', "This pruning method is not yet supported for Keras."
with torch.no_grad():
for key in self.modules.keys():
if not hasattr(self.modules[key], 'block_mask'):
@@ -296,18 +316,18 @@ class BasicPruner(BasePruner):
reg: A Reg object that defines regularization terms.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
# self.modules = modules
# self.config = config
# self.masks = {}
- super(BasicPruner, self).__init__(config, modules)
+ super(BasicPruner, self).__init__(config, modules, framework)
def _init(self):
"""Auxiliary function for initializing."""
- self.pattern = get_pattern(self.config, self.modules)
+ self.pattern = get_pattern(self.config, self.modules, self.framework)
self.scheduler = get_scheduler(self.config)
- self.criterion = get_criterion(self.config, self.modules)
+ self.criterion = get_criterion(self.config, self.modules, self.framework)
self.reg = get_reg(self.config, self.modules, self.pattern)
# if switch off progressive but use per-channel pruning, give a warn
if "channel" in self.pattern.pattern:
@@ -388,9 +408,9 @@ class PatternLockPruner(BasePruner):
Inherit from parent class Pruner.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
- super(PatternLockPruner, self).__init__(config, modules)
+ super(PatternLockPruner, self).__init__(config, modules, framework)
self.pattern = get_pattern(self.config, modules)
assert self.config.end_step == self.config.start_step, "pattern_lock pruner only supports one shot mode"
@@ -428,17 +448,17 @@ class BlockMaskPruner(BasePruner):
scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds.
reg: A Reg object that defines regularization terms.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
- super(BlockMaskPruner, self).__init__(config, modules)
+ super(BlockMaskPruner, self).__init__(config, modules, framework)
def _init(self):
"""Initialize."""
- self.pattern = get_pattern(self.config, self.modules)
+ self.pattern = get_pattern(self.config, self.modules, self.framework)
self.masks = self.pattern.register_block_masks(self.modules)
self.rewrite_forward()
self.scheduler = get_scheduler(self.config)
- self.criterion = get_criterion(self.config, self.modules)
+ self.criterion = get_criterion(self.config, self.modules, self.framework)
self.reg = get_reg(self.config, self.modules, self.pattern)
if "channel" not in self.pattern.pattern:
@@ -553,17 +573,17 @@ class RetrainFreePruner(BasePruner):
scheduler: A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds.
reg: A Reg object that defines regularization terms.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
- super(RetrainFreePruner, self).__init__(config, modules)
+ super(RetrainFreePruner, self).__init__(config, modules, framework)
def _init(self):
"""Initialize."""
- self.pattern = get_pattern(self.config, self.modules)
+ self.pattern = get_pattern(self.config, self.modules, self.framework)
self.masks = self.pattern.register_block_masks(self.modules)
self.rewrite_forward()
self.scheduler = get_scheduler(self.config)
- self.criterion = get_criterion(self.config, self.modules)
+ self.criterion = get_criterion(self.config, self.modules, self.framework)
self.reg = get_reg(self.config, self.modules, self.pattern)
logger.warning("Retrain-free pruner keeps the weights fixed, please DO NOT turn on gradient update.")
@@ -703,15 +723,15 @@ class ProgressivePruner(BasicPruner):
Inherit from parent class Pruner.
"""
- def __init__(self, config, modules):
+ def __init__(self, config, modules, framework='pytorch'):
"""Initialize."""
- super(ProgressivePruner, self).__init__(config, modules)
+ super(ProgressivePruner, self).__init__(config, modules, framework)
def _init(self):
"""Auxiliary function for initialization."""
- self.pattern = get_pattern(self.config, self.modules)
+ self.pattern = get_pattern(self.config, self.modules, self.framework)
self.scheduler = get_scheduler(self.config)
- self.criterion = get_criterion(self.config, self.modules)
+ self.criterion = get_criterion(self.config, self.modules, self.framework)
self.reg = get_reg(self.config, self.modules, self.pattern)
# progressive pruning setup, including parameter checks.
self.use_progressive = self.config["progressive"]
@@ -992,6 +1012,8 @@ def _init_mha_attrs(self):
# similar to original mha slim process, but only hook mha modules and their attributes,
# do not call slim main functions.
"""
+ # auto slim related: head pruning objects
+ from .model_slim.weight_slim import MHACompression
for mha_module in self.mha_modules:
# initialize self.mha_compressions
mha_comp = MHACompression(mha_module)
diff --git a/neural_compressor/compression/pruner/utils.py b/neural_compressor/compression/pruner/utils.py
index 34595ab0260..8945998393f 100644
--- a/neural_compressor/compression/pruner/utils.py
+++ b/neural_compressor/compression/pruner/utils.py
@@ -18,6 +18,7 @@
import re
import yaml
+import numpy as np
from ...config import WeightPruningConfig as WeightPruningConf
try:
@@ -29,6 +30,7 @@
from neural_compressor.conf.config import Pruner
LazyImport('torch.nn')
torch = LazyImport('torch')
+ tf = LazyImport('tensorflow')
F = LazyImport('torch.nn.functional')
except:
import torch
@@ -141,6 +143,62 @@ def get_sparsity_ratio(pruners, model):
return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv
+def get_sparsity_ratio_tf(pruners, model):
+    """Calculate the sparsity ratio of a Keras model.
+
+ Returns:
+ Three floats.
+ elementwise_over_matmul_gemm_conv refers to zero elements' ratio in pruning layers.
+ elementwise_over_all refers to zero elements' ratio in all layers in the model.
+ blockwise_over_matmul_gemm_conv refers to all-zero blocks' ratio in pruning layers.
+ """
+ pattern_sparsity_cnt = 0
+ element_sparsity_cnt = 0
+ if hasattr(model, 'model'):
+ model = model.model
+ for pruner in pruners:
+ modules = pruner.modules
+ sparsity_ratio = pruner.pattern.get_sparsity_ratio(pruner.masks)
+ cnt = 0
+ for key in modules.keys():
+ cnt += modules[key].get_weights()[0].size
+ pattern_sparsity_cnt += int(cnt * sparsity_ratio)
+ for key in pruner.masks.keys():
+ block_num = 1
+ if pruner.pattern.block:
+ block_size = pruner.pattern.block_size[key]
+ block_num = block_size[0] * block_size[1]
+ element_sparsity_cnt += np.sum(pruner.masks[key] == 0) * block_num
+
+ linear_conv_cnt = 0
+ param_cnt = 0
+ for layer in model.layers:
+        if layer.__class__.__name__ == "Dense" or re.search(r'Conv.[dD]', layer.__class__.__name__) is not None:
+ linear_conv_cnt += layer.get_weights()[0].size
+
+ for layer in model.layers:
+ if bool(layer.weights):
+ weights = layer.get_weights()[0]
+ param_cnt += weights.size
+ if linear_conv_cnt == 0:
+ blockwise_over_matmul_gemm_conv = 0
+ elementwise_over_matmul_gemm_conv = 0
+ else:
+ blockwise_over_matmul_gemm_conv = float(pattern_sparsity_cnt) / linear_conv_cnt
+ elementwise_over_matmul_gemm_conv = float(element_sparsity_cnt) / linear_conv_cnt
+ if param_cnt == 0:
+ elementwise_over_all = 0
+ else:
+ elementwise_over_all = float(
+ element_sparsity_cnt) / param_cnt
+
+ logger.info(
+ f"elementwise_over_matmul_gemm_conv:{elementwise_over_matmul_gemm_conv},"
+ f" elementwise_over_all:{elementwise_over_all},"
+        f" blockwise_over_matmul_gemm_conv:{blockwise_over_matmul_gemm_conv}")
+
+ return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv
+
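get_sparsity_ratio_tf mirrors the existing PyTorch helper but reads parameter counts through the Keras layer API. Its two denominators can be reproduced independently; a small sketch on a toy model (the name matching follows the Dense/Conv check above):

```python
import re
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, input_shape=(8, 8, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

linear_conv_cnt = 0        # kernel parameters in prunable Dense/Conv layers
param_cnt = 0              # kernel parameters in every layer that has weights
for layer in model.layers:
    name = layer.__class__.__name__
    if name == "Dense" or re.search(r"Conv.[dD]", name):
        linear_conv_cnt += layer.get_weights()[0].size
    if layer.weights:
        param_cnt += layer.get_weights()[0].size

print(linear_conv_cnt, param_cnt)
```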
def check_config(prune_config):
"""Check if the configuration dict is valid for running Pruning object.
@@ -424,6 +482,19 @@ def parse_last_linear(model):
layer = searcher.search(return_name=True)
return layer
+def parse_last_linear_tf(model):
+    """Locate the last linear layer of the model.
+    While pruning, the final linear layer often acts as the classifier head; pruning it
+    might cause an accuracy drop.
+
+ Args:
+ model(tf.keras.Model): The model to be pruned.
+ """
+ from .model_slim.pattern_analyzer import ClassifierHeadSearcherTF
+ searcher = ClassifierHeadSearcherTF(model)
+ layer = searcher.search(return_name=True)
+ return layer
+
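ClassifierHeadSearcherTF is expected to return the name of the model's final Dense layer so it can be appended to excluded_op_names. The intent can be approximated by walking the layer list backwards (an illustrative stand-in, not the searcher's actual logic):

```python
import tensorflow as tf

def last_dense_layer_name(model):
    """Return the name of the last Dense layer, or None if the model has none."""
    for layer in reversed(model.layers):
        if layer.__class__.__name__ == "Dense":
            return layer.name
    return None

model = tf.keras.applications.ResNet50V2(weights=None)   # any Keras model works here
print(last_dense_layer_name(model))                      # e.g. 'predictions'
```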
def parse_to_prune(config, model):
"""Keep target pruned layers.
@@ -462,6 +533,40 @@ def parse_to_prune(config, model):
new_modules[name] = modules[name]
return new_modules
+def parse_to_prune_tf(config, model):
+ """Keep target pruned layers.
+
+ Args:
+ config(string): A string representing the path to the configuration file.
+ model(tf.keras.Model): The model to be pruned.
+ """
+ modules = {}
+ # additional function: exclude last layer (often a classifier head and not suitable to be pruned)
+ classifier_head_name = parse_last_linear_tf(model)
+    if classifier_head_name is not None:
+ config["excluded_op_names"].append(classifier_head_name)
+ # locate target layers
+    if config["op_names"] is None or config["op_names"] == []:
+ config["op_names"] = [".*"]
+
+ for layer in model.layers:
+ for layer_type in config["pruning_op_types"]:
+ if layer_type in layer.__class__.__name__ and bool(layer.weights):
+ modules[layer.name] = layer
+
+    ## remove layers that should not be pruned
+    ## (drop modules whose names match any excluded_op_names pattern)
+ exclude_names = config["excluded_op_names"]
+ patterns = [re.compile(s) for s in exclude_names]
+ if len(patterns) <= 0:
+ return modules
+ new_modules = {}
+ for name in modules.keys():
+ if any([p.search(name) for p in patterns]):
+ continue
+ new_modules[name] = modules[name]
+ return new_modules
+
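parse_to_prune_tf keeps layers whose class name contains one of pruning_op_types and then drops anything whose name matches an excluded_op_names pattern. The filtering is plain string/regex matching over model.layers; a compact sketch with a toy model and made-up config values:

```python
import re
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, name="conv_a", input_shape=(8, 8, 3)),
    tf.keras.layers.Flatten(name="flatten"),
    tf.keras.layers.Dense(16, name="fc1"),
    tf.keras.layers.Dense(10, name="classifier"),
])

pruning_op_types = ["Conv", "Dense"]
excluded_op_names = ["classifier"]          # e.g. the detected classifier head

modules = {layer.name: layer for layer in model.layers
           if layer.weights and any(t in layer.__class__.__name__ for t in pruning_op_types)}

patterns = [re.compile(s) for s in excluded_op_names]
modules = {name: m for name, m in modules.items()
           if not any(p.search(name) for p in patterns)}

print(list(modules))                        # ['conv_a', 'fc1']
```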
def generate_pruner_config(info):
"""Generate pruner config object from prune information.
diff --git a/neural_compressor/config.py b/neural_compressor/config.py
index 935aff1a526..7e05f6e8ca0 100644
--- a/neural_compressor/config.py
+++ b/neural_compressor/config.py
@@ -1372,12 +1372,13 @@ class WeightPruningConfig:
def __init__(self, pruning_configs=[{}], ##empty dict will use global values
target_sparsity=0.9, pruning_type="snip_momentum", pattern="4x1", op_names=[],
- excluded_op_names=[],
+ excluded_op_names=[], backend=None,
start_step=0, end_step=0, pruning_scope="global", pruning_frequency=1,
min_sparsity_ratio_per_op=0.0, max_sparsity_ratio_per_op=0.98,
sparsity_decay_type="exp", pruning_op_types=['Conv', 'Linear'],
**kwargs):
"""Init a WeightPruningConfig object."""
+ self.backend = backend
self.pruning_configs = pruning_configs
self._weight_compression = DotDict({
'target_sparsity': target_sparsity,
diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py
index 0d0aa2b374a..e73efbb4feb 100644
--- a/neural_compressor/model/model.py
+++ b/neural_compressor/model/model.py
@@ -167,10 +167,10 @@ def __new__(cls, root, **kwargs):
if isinstance(root, BaseModel):
if conf != "NA" and conf.framework is None:
conf.framework = list(MODELS.keys())[list(MODELS.values()).index(type(root))]
- if conf.backend == "ipex":
+ if hasattr(conf, 'backend') and conf.backend == "ipex":
assert conf.framework == "pytorch_ipex",\
"Please wrap the model with correct Model class!"
- if conf.backend == "itex":
+ if hasattr(conf, 'backend') and conf.backend == "itex":
if get_model_type(root.model) == 'keras':
assert conf.framework == "keras",\
"Please wrap the model with KerasModel class!"
@@ -213,10 +213,10 @@ def __new__(cls, root, **kwargs):
return MODELS[framework](root, **kwargs)
else:
conf.framework = framework
- if conf.backend == "default":
+ if hasattr(conf, 'backend') and conf.backend == "default":
if framework == "pytorch":
conf.framework = "pytorch_fx"
- elif conf.backend == "ipex":
+ elif hasattr(conf, 'backend') and conf.backend == "ipex":
conf.framework = "pytorch_ipex"
if 'tensorflow' in conf.framework:
@@ -227,7 +227,7 @@ def __new__(cls, root, **kwargs):
model_type = kwargs['modelType']
else:
model_type = get_model_type(root)
- if conf.backend == "itex":
+ if hasattr(conf, 'backend') and conf.backend == "itex":
if model_type == 'keras':
conf.framework = "keras"
model = MODELS[conf.framework](root, **kwargs)
@@ -243,7 +243,7 @@ def __new__(cls, root, **kwargs):
root, **kwargs)
else:
model = MODELS[conf.framework](root, **kwargs)
- if 'tensorflow' in conf.framework:
+ if 'tensorflow' in conf.framework and hasattr(conf, 'model_name'):
model.name = conf.model_name
model.output_tensor_names = conf.outputs
model.input_tensor_names = conf.inputs
diff --git a/neural_compressor/training.py b/neural_compressor/training.py
index 5d6d90e5e63..28905f1bd78 100644
--- a/neural_compressor/training.py
+++ b/neural_compressor/training.py
@@ -80,7 +80,7 @@ def __init__(self, model: Callable, confs: Union[Callable, List], **kwargs):
if isinstance(confs, List) and len(confs) > 1:
for conf in confs:
- if isinstance(conf, QuantizationAwareTrainingConfig):
+ if isinstance(conf, QuantizationAwareTrainingConfig) or isinstance(conf, WeightPruningConfig):
self.model = Model(model, conf=conf)
if self.model is None:
self.model = Model(model)
@@ -150,7 +150,7 @@ def __init__(self, model: Callable, confs: Union[Callable, List], **kwargs):
callbacks_list.append(QuantizationAwareTrainingCallbacks(confs, adaptor=self.adaptor))
self.conf = _Config(quantization=confs, benchmark=None, pruning=None, distillation=None, nas=None)
elif isinstance(confs, WeightPruningConfig):
- self.model = Model(model)
+ self.model = Model(model, conf=confs)
callbacks_list.append(PruningCallbacks(confs, model=self.model))
self.conf = _Config(quantization=None, benchmark=None, pruning=confs, distillation=None, nas=None)
elif isinstance(confs, DistillationConfig):
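With the Model and training changes above, a WeightPruningConfig can be handed directly to prepare_compression for a Keras model; this is essentially what the new unit test below exercises. A trimmed usage sketch (the training loop itself is omitted):

```python
import tensorflow as tf
from neural_compressor import WeightPruningConfig
from neural_compressor.training import prepare_compression

model = tf.keras.applications.ResNet50V2(weights=None)

config = WeightPruningConfig(
    backend="itex",                  # route the wrapped model through the Keras path
    pruning_type="magnitude",
    pattern="4x1",
    target_sparsity=0.5,
    start_step=1,
    end_step=10,
    pruning_op_types=["Conv", "Dense"],
)

compression_manager = prepare_compression(model, confs=config)
compression_manager.callbacks.on_train_begin()
model = compression_manager.model
# ... run the training/pruning loop here ...
compression_manager.callbacks.on_train_end()
```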
diff --git a/test/pruning_2.x/test_pruning.py b/test/pruning_2.x/test_pruning.py
index c401a4cd269..2398e57c99d 100644
--- a/test/pruning_2.x/test_pruning.py
+++ b/test/pruning_2.x/test_pruning.py
@@ -9,7 +9,11 @@
from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
from neural_compressor import WeightPruningConfig
from neural_compressor.training import prepare_compression
-
+from neural_compressor.data import DataLoader
+from neural_compressor.adaptor import FRAMEWORKS
+from neural_compressor.conf.dotdict import DotDict
+from neural_compressor.utils import create_obj_from_config
+from neural_compressor.conf.config import default_workspace
class TestPruning(unittest.TestCase):
model = torchvision.models.resnet18()
@@ -72,6 +76,58 @@ def test_pruning_basic(self):
compression_manager.callbacks.on_before_eval()
compression_manager.callbacks.on_after_eval()
+ def test_pruning_keras(self):
+ import tensorflow as tf
+ model = tf.keras.applications.ResNet50V2(weights='imagenet')
+ def train(model, adaptor, compression_manager, train_dataloader):
+ train_cfg = {
+ 'epoch': 1,
+ 'start_epoch': 0,
+ 'execution_mode': 'eager',
+ 'criterion': {'SparseCategoricalCrossentropy': {'reduction': 'sum_over_batch_size'}},
+ 'optimizer': {'SGD': {'learning_rate': 1e-03, 'momentum': 0.9, 'nesterov': True}},
+ }
+ train_cfg = DotDict(train_cfg)
+ train_func = create_obj_from_config.create_train_func(
+ 'tensorflow', \
+ train_dataloader, \
+ adaptor, \
+ train_cfg, \
+ hooks=compression_manager.callbacks.callbacks_list[0].hooks, \
+ callbacks=compression_manager.callbacks.callbacks_list[0])
+ train_func(model)
+
+ tf_datasets = Datasets('tensorflow')
+ dummy_dataset = tf_datasets['dummy'](shape=(100, 224, 224, 3), low=0., high=1., label=True)
+ train_dataloader = DataLoader(dataset=dummy_dataset, batch_size=32,
+ framework='tensorflow', distributed=False)
+
+ framework_specific_info = {
+ 'device': 'cpu', 'random_seed': 9527,
+ 'workspace_path': default_workspace,
+ 'q_dataloader': None, 'format': 'default',
+ 'backend': 'default', 'inputs': [], 'outputs': []
+ }
+ adaptor = FRAMEWORKS['keras'](framework_specific_info)
+
+ configs = WeightPruningConfig(
+ backend='itex',
+ pruning_type='magnitude',
+ pattern='3x1',
+ target_sparsity=0.5,
+ start_step=1,
+ end_step=10,
+ pruning_op_types=['Conv', 'Dense']
+ )
+ compression_manager = prepare_compression(model, confs=configs)
+ compression_manager.callbacks.on_train_begin()
+ model = compression_manager.model
+
+ train(model, adaptor, compression_manager, train_dataloader)
+
+ compression_manager.callbacks.on_train_end()
+ stats, sparsity = model.report_sparsity()
+
if __name__ == "__main__":
unittest.main()