add sparseGPT pruner feature, refactor pruning class (#1111)
Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 authored Jul 31, 2023
1 parent be393ee commit 88adfc9
Showing 29 changed files with 1,873 additions and 555 deletions.
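The central change in the example script diffed below is that the callbacks-based prepare_compression flow (on_train_begin / on_step_begin / on_before_optimizer_step / on_train_end wrapped around a manual training loop) is replaced by a single prepare_pruning call fed with a calibration dataloader. The following is a minimal sketch of the new flow, not a verbatim excerpt: everything outside the visible hunks is an assumption, in particular the WeightPruningConfig wrapper, the "sparse_gpt" pruning_type string, the target_sparsity value, and the facebook/opt-125m placeholder model.

    from torch.utils.data import DataLoader
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from neural_compressor.config import WeightPruningConfig           # not shown in the visible hunks; assumed wrapper
    from neural_compressor.compression.pruner import prepare_pruning   # the call introduced by this commit

    model_name = "facebook/opt-125m"  # placeholder model, not taken from this commit
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # A tiny calibration set standing in for the script's tokenized training split.
    texts = ["a short calibration sample for pruning"] * 8
    enc = tokenizer(texts, return_tensors="pt")
    samples = [
        {"input_ids": ids, "attention_mask": mask, "labels": ids}
        for ids, mask in zip(enc["input_ids"], enc["attention_mask"])
    ]
    train_dataloader = DataLoader(samples, batch_size=2)

    pruning_configs = [
        {
            "pruning_type": "sparse_gpt",       # assumed identifier for the new pruner
            "pattern": "channelx1",             # the new default pattern in this commit
            "pruning_op_types": ["Linear"],
            "max_sparsity_ratio_per_op": 0.98,
        },
    ]
    configs = WeightPruningConfig(
        pruning_configs,
        target_sparsity=0.5,  # illustrative value only
        start_step=0,
        end_step=1,
    )

    # One call replaces the old prepare_compression/callbacks training loop.
    pruning = prepare_pruning(configs, model, dataloader=train_dataloader)
    model.eval()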
@@ -4,5 +4,6 @@ sentencepiece
transformers
torch
tqdm
cupy
optimum
einops

@@ -276,33 +276,12 @@ def parse_args():
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
@@ -342,18 +321,6 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
@@ -395,7 +362,7 @@ def parse_args():
)
parser.add_argument(
"--pruning_pattern",
type=str, default="4x1",
type=str, default="channelx1",
help="pruning pattern type, we support NxM and N:M."
)
parser.add_argument(
@@ -673,64 +640,25 @@ def group_texts(examples):
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# DataLoaders creation:
train_dataset = train_dataset.shuffle(seed=42)
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
max_sample_num = args.max_pruning_steps * total_batch_size
train_dataset = train_dataset.shuffle(seed=42).select(range(max_sample_num))
train_dataloader = DataLoader(
train_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
)
eval_dataloader = DataLoader(
eval_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
)

# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Scheduler and math around the number of training steps.
args.max_train_steps = args.max_pruning_steps
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True

lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
model, train_dataloader, eval_dataloader = accelerator.prepare(
model, train_dataloader, eval_dataloader
)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and checkpointing_steps.isdigit():
checkpointing_steps = int(checkpointing_steps)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if args.with_tracking:
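A worked example of the calibration budget set near the top of the hunk above: the dataset is shuffled and then trimmed to exactly max_pruning_steps worth of global batches, so pruning consumes a fixed number of samples rather than running full training epochs. The numbers below are illustrative, not taken from the diff.

    per_device_train_batch_size = 8       # illustrative values
    num_processes = 2
    gradient_accumulation_steps = 1
    max_pruning_steps = 100

    total_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
    max_sample_num = max_pruning_steps * total_batch_size
    print(total_batch_size, max_sample_num)  # 16 1600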
@@ -739,49 +667,16 @@ def group_texts(examples):
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("clm_no_trainer", experiment_config)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
# Pruning!
logger.info("***** Running Pruning *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
# Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0]

if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)

# update the progress_bar if load from checkpoint
progress_bar.update(starting_epoch * num_update_steps_per_epoch)
completed_steps = starting_epoch * num_update_steps_per_epoch
logger.info(f" Total pruning steps = {args.max_pruning_steps}")

# Pruning preparation
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_iterations = num_update_steps_per_epoch
num_warm = args.num_warmup_steps
total_iterations = args.max_pruning_steps
@@ -790,7 +685,7 @@ def group_texts(examples):
pruning_start = max(num_warm, 1)
pruning_end = max(total_iterations - 1, pruning_start)
if not args.do_prune:
pruning_start = num_iterations * args.num_train_epochs + 1
pruning_start = args.max_pruning_steps + 1
pruning_end = pruning_start

if not args.auto_config:
@@ -804,7 +699,7 @@ def group_texts(examples):
"pattern": "channelx1",
"pruning_op_types": ["Linear"],
"max_sparsity_ratio_per_op": 0.98,
}
},
]
else:
# auto config
@@ -826,52 +721,9 @@ def group_texts(examples):
start_step=pruning_start,
end_step=pruning_end,
)
compression_manager = prepare_compression(model=model, confs=configs)
compression_manager.callbacks.on_train_begin()
model = compression_manager.model.model

for epoch in range(starting_epoch, args.num_train_epochs):
# model.train()
model.eval()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
completed_steps += 1
continue
compression_manager.callbacks.on_step_begin(step)
with accelerator.accumulate(model):
outputs = model(return_dict=True, **batch)
# outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
accelerator.backward(loss)
compression_manager.callbacks.on_before_optimizer_step()
# optimizer.step()
compression_manager.callbacks.on_after_optimizer_step()
# lr_scheduler.step()
optimizer.zero_grad()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1

if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
if completed_steps >= args.max_pruning_steps:
break
compression_manager.callbacks.on_train_end()
from neural_compressor.compression.pruner import prepare_pruning
pruning = prepare_pruning(configs, model, dataloader=train_dataloader)

model.eval()
if args.evaluation_dataset_name != None:
@@ -892,19 +744,22 @@ def eval_func(model):
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
output_dir = args.output_dir
if args.auto_slim:
output_dir += "/before_slim"
unwrapped_model.save_pretrained(
args.output_dir+"/noslim", is_main_process=accelerator.is_main_process, save_function=accelerator.save
output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
)
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir+"/noslim")
tokenizer.save_pretrained(output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of pruning", auto_lfs_prune=True)

if not args.auto_slim:
# only eval
logger.info(f"***** Running Evaluation *****")
acc, _ = eval_func(model)
logger.info(f"total_steps:{completed_steps} accuracy:{acc}")
logger.info(f"total_steps:{args.max_pruning_steps} accuracy:{acc}")
else:
logger.info(f"***** Running Evaluation before ffn auto slim*****")
accuracy, avg_latency = eval_func(model)
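Not part of the commit, but a small sanity check one could run after prepare_pruning and before the final evaluation: report the fraction of zeroed weights in each Linear layer to confirm the channelx1 sparsity actually landed. This uses plain PyTorch only; no Neural Compressor API is assumed.

    import torch

    def report_sparsity(model):
        """Print per-layer and overall zero-weight ratios for Linear modules."""
        total, zeros = 0, 0
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                w = module.weight.detach()
                layer_zeros = int((w == 0).sum())
                zeros += layer_zeros
                total += w.numel()
                print(f"{name}: {layer_zeros / w.numel():.2%} sparse")
        print(f"overall Linear sparsity: {zeros / max(total, 1):.2%}")

    # usage after pruning: report_sparsity(accelerator.unwrap_model(model))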