Mixtral scaling: Reduce perplexity from 4.294 to 4.269 #301

Open
wants to merge 10 commits into base: main
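A minimal usage sketch of quantizing Mixtral on this branch (checkpoint name, output path, and config values are illustrative assumptions, not taken from this PR). Because the Mixtral model class now pins modules_to_not_convert = ["gate"], the router gate stays in FP16 without passing any extra arguments to quantize():

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"   # assumed checkpoint name
quant_path = "mixtral-awq"                   # assumed output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and its tokenizer.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# The gate is excluded from quantization via the class-level modules_to_not_convert.
model.quantize(tokenizer, quant_config=quant_config)

# Persist the quantized weights and tokenizer.
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)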
28 changes: 24 additions & 4 deletions awq/models/base.py
@@ -35,6 +35,7 @@

from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.modules.moe import ScaledMixtralSparseMoeBlock
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
@@ -96,11 +97,13 @@ def quantize(
split="train",
text_column="text",
duo_scaling=True,
modules_to_not_convert=None,
export_compatible=False,
):
self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

if hasattr(self, "modules_to_not_convert"):
self.quant_config.modules_to_not_convert = self.modules_to_not_convert

self.quantizer = AwqQuantizer(
self,
self.model,
@@ -112,7 +115,7 @@
split,
text_column,
duo_scaling,
modules_to_not_convert=modules_to_not_convert,
modules_to_not_convert=self.quant_config.modules_to_not_convert,
export_compatible=export_compatible,
)
self.quantizer.quantize()
@@ -402,6 +405,9 @@ def _load_quantized_modules(
# Replace activation functions
self._scale_activations(self, layer)

# Replace mixture of experts
self._scale_moe(self, layer)

# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if use_exllama:
@@ -436,5 +442,19 @@ def _scale_activations(self, layer):
)

# scale activation
scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
set_op_by_name(layer, scale_dict["scale_name"], scaled_act)
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)

def _scale_moe(self, layer):
if hasattr(self, "get_moe_for_scaling") and hasattr(layer.block_sparse_moe, "scales"):
scale_dict: dict = self.get_moe_for_scaling(layer)

if not isinstance(scale_dict['scale_layer'], ScaledMixtralSparseMoeBlock):
param = next(layer.parameters())

# get activation scale
scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

# scale moe
scaled_act = ScaledMixtralSparseMoeBlock(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
39 changes: 30 additions & 9 deletions awq/models/mixtral.py
@@ -6,13 +6,26 @@
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
MixtralForCausalLM as OldMixtralForCausalLM,
MixtralBLockSparseTop2MLP as OldMixtralBLockSparseTop2MLP,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm

def _transformers_version_check():
import transformers
tv = transformers.__version__.split('.')
if len(tv) == 4:
major, minor, patch, dev = tv
else:
major, minor, patch = tv

if int(major) == 4 and int(minor) < 37:
raise Exception("Mixtral requires a minimum of transformers 4.37.0.dev0: pip install git+https://github.com/huggingface/transformers.git")

class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
modules_to_not_convert = ["gate"]

@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
@@ -21,6 +34,7 @@ def fuse_layers(model: OldMixtralForCausalLM):

@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
_transformers_version_check()
return model.model.layers

@staticmethod
@@ -29,6 +43,14 @@ def get_act_for_scaling(module):
is_scalable=False
)

@staticmethod
def get_moe_for_scaling(module: OldMixtralDecoderLayer):
return dict(
scale_name="block_sparse_moe",
scale_layer=module.block_sparse_moe,
scale_shape=(module.block_sparse_moe.num_experts, module.block_sparse_moe.hidden_dim),
)

@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@@ -53,19 +75,18 @@ def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))

# linear in

# NOTE: Scaled in awq.quantize.scale.scale_moe_experts, awq.modules.moe.ScaledMixtralSparseMoeBlock
# Experts: Not a linear layer, special handling is introduced in awq.quantize.quantizer
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
prev_op=module.block_sparse_moe,
layers=module.block_sparse_moe.experts,
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))

# linear out
# scaling w2
expert: OldMixtralBLockSparseTop2MLP
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
75 changes: 75 additions & 0 deletions awq/modules/moe.py
@@ -0,0 +1,75 @@
import torch
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

class ScaledMixtralSparseMoeBlock(torch.nn.Module):
"""
This is a modified sparse MoE that scales experts individually.

Modified version of:
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock
"""

def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor):
super().__init__()
self.hidden_dim = prev_op.hidden_dim
self.ffn_dim = prev_op.ffn_dim
self.num_experts = prev_op.num_experts
self.top_k = prev_op.top_k

# gating
self.gate = prev_op.gate

# experts
self.experts = prev_op.experts

# [expert_num, hidden_dim]
self.scales = torch.nn.Parameter(scales.data)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

if top_x.shape[0] == 0:
continue

# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)

### NOTE: We scale weights here, modified from original MoE.
current_state = hidden_states[None, top_x_list].reshape(
-1, hidden_dim) / self.scales[expert_idx] # <-- scales

current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

# However `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
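For reference, the identity this block relies on (a sketch of my own, assuming the standard SwiGLU expert used by Mixtral): awq.quantize.scale.scale_moe_experts multiplies the input columns of w1 and w3 of expert e by the per-expert scale s_e, and the division by self.scales[expert_idx] above undoes it, so the FP16 output is unchanged while the rescaled weights quantize with lower error; w2 consumes the intermediate activation rather than hidden_states, so it needs no rescaling:

W_1' = W_1\,\mathrm{diag}(s_e), \qquad W_3' = W_3\,\mathrm{diag}(s_e), \qquad x' = \mathrm{diag}(s_e)^{-1} x

\mathrm{expert}'(x') = W_2\big(\mathrm{SiLU}(W_1' x') \odot (W_3' x')\big) = W_2\big(\mathrm{SiLU}(W_1 x) \odot (W_3 x)\big) = \mathrm{expert}(x)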
54 changes: 39 additions & 15 deletions awq/quantize/quantizer.py
@@ -4,9 +4,10 @@
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List
from collections import defaultdict
from typing import Dict, List, Union
from awq.utils.utils import clear_memory
from transformers import PreTrainedModel
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
@@ -17,14 +18,15 @@
set_op_by_name,
exclude_layers_to_not_quantize
)
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP


class AwqQuantizer:
def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
calib_data, split, text_column, duo_scaling, modules_to_not_convert=None,
export_compatible=False) -> None:
self.awq_model = awq_model
self.model = model
self.model: PreTrainedModel = model
self.tokenizer = tokenizer
self.w_bit = w_bit
self.group_size = group_size
@@ -162,7 +164,8 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
clear_memory()

@torch.no_grad()
def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], List[MixtralBLockSparseTop2MLP]],
inp: torch.Tensor, module2inspect=None, kwargs={}):
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
@@ -174,13 +177,23 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
inp = inp.to(next(module2inspect.parameters()).device)

# [STEP 1]: Compute maximum of weight
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)
def _get_w_max(layer_weights):
weight = torch.cat([_m.weight for _m in layer_weights], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)

return w_max

if type(layers[0]) == nn.Linear:
w_max = _get_w_max(layers)
else:
# FIXME: Specific to Mixtral
weights = [[expert.w1, expert.w3] for expert in layers]
w_max = [_get_w_max(weight) for weight in weights]

# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
@@ -194,12 +207,23 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
fp16_output = fp16_output[0]

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, module_kwargs
)
if type(layers[0]) == nn.Linear:
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, module_kwargs
)
else:
best_scales = [
self._compute_best_scale(
inp, w_max[i], x_max, module2inspect,
experts, fp16_output, module_kwargs
) for i, experts in enumerate(weights)
]

prev_op_name = get_op_name(module, prev_op)
layer_names = tuple([get_op_name(module, m) for m in layers])

return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
return (prev_op_name, layer_names, best_scales)

def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
fp16_output, kwargs={}):
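To make the new branching concrete, a small sketch (entry names follow this diff; tensor sizes and the dense example are illustrative assumptions) of the two return shapes of _search_best_scale, which is exactly what apply_scale in awq/quantize/scale.py switches on:

import torch

hidden_dim, n_experts = 4096, 8

# Group of plain nn.Linear layers -> one shared scale tensor, as before.
dense_entry = (
    "self_attn.v_proj",                                   # prev_op name (illustrative)
    ("self_attn.o_proj",),                                # layer names in the group
    torch.ones(hidden_dim),                               # best_scales: [hidden_dim]
)

# Mixtral expert group -> one scale tensor per expert; apply_scale detects the
# list, stacks it to [num_experts, hidden_dim], and wraps the MoE block in
# ScaledMixtralSparseMoeBlock.
moe_entry = (
    "block_sparse_moe",
    tuple(f"block_sparse_moe.experts.{i}" for i in range(n_experts)),
    [torch.ones(hidden_dim) for _ in range(n_experts)],
)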
44 changes: 42 additions & 2 deletions awq/quantize/scale.py
@@ -2,13 +2,19 @@
import torch.nn as nn
from typing import Tuple, List
from awq.modules.act import ScaledActivation
from awq.modules.moe import ScaledMixtralSparseMoeBlock
from awq.utils.module import get_op_by_name, set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
from transformers.models.mixtral.modeling_mixtral import (
MixtralSparseMoeBlock,
MixtralBLockSparseTop2MLP,
)

allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUActivation]
allowed_moe = [MixtralSparseMoeBlock]

@torch.no_grad()
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
@@ -31,7 +37,12 @@ def apply_scale(module, scales_list, input_feat_dict=None):
prev_op.cuda()
for layer in layers:
layer.cuda()
scales.cuda()

if type(scales) == list:
for scale in scales:
scale.cuda()
else:
scales.cuda()

if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
scale_fc_fcs(prev_op, layers, scales)
@@ -48,6 +59,15 @@ def apply_scale(module, scales_list, input_feat_dict=None):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)

elif any(isinstance(prev_op,t) for t in allowed_moe):
# scales: [best_scale_expert_0, best_scale_expert_1, ...] -> [expert_index, scales]
scales = torch.stack(scales).cuda()

# apply scales
new_module = ScaledMixtralSparseMoeBlock(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_moe_experts(prev_op, layers, scales)

else:
raise NotImplementedError(
@@ -64,7 +84,12 @@ def apply_scale(module, scales_list, input_feat_dict=None):
prev_op.cpu()
for layer in layers:
layer.cpu()
scales.cpu()

if type(scales) == list:
for scale in scales:
scale.cpu()
else:
scales.cpu()

@torch.no_grad()
def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
@@ -133,3 +158,18 @@ def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):

for p in fc.parameters():
assert torch.isnan(p).sum() == 0

@torch.no_grad()
def scale_moe_experts(moe: MixtralSparseMoeBlock, experts: List[MixtralBLockSparseTop2MLP], scales: torch.Tensor):
assert any(isinstance(moe, allowed_module) for allowed_module in allowed_moe)
assert all(isinstance(m, MixtralBLockSparseTop2MLP) for m in experts)

# One scale for each expert, applied to w1 and w3 only
# Not applied to w2 because it does not take hidden_states as input
for i, expert in enumerate(experts):
expert.w1.weight.mul_(scales[i].view(1, -1))
expert.w3.weight.mul_(scales[i].view(1, -1))

for expert in experts:
for p in expert.parameters():
assert torch.isnan(p).sum() == 0
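A standalone sanity check of the equivalence scale_moe_experts relies on (toy dimensions; a sketch of my own, not part of this PR): multiplying w1/w3 by a per-channel scale while the scaled MoE block divides the routed tokens by the same scale leaves the SwiGLU expert output unchanged, which is also why w2 is left untouched:

import torch
import torch.nn as nn

hidden_dim, ffn_dim = 16, 32
w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
act = nn.SiLU()

def expert(x):
    # Same structure as MixtralBLockSparseTop2MLP: w2(silu(w1(x)) * w3(x))
    return w2(act(w1(x)) * w3(x))

x = torch.randn(4, hidden_dim)
reference = expert(x)

scales = torch.rand(hidden_dim) + 0.5  # per-channel scale, as in scale_moe_experts
with torch.no_grad():
    w1.weight.mul_(scales.view(1, -1))
    w3.weight.mul_(scales.view(1, -1))

# ScaledMixtralSparseMoeBlock divides the routed hidden states by the scales.
scaled_output = expert(x / scales)
assert torch.allclose(reference, scaled_output, atol=1e-5)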