Showing 7 changed files with 1,966 additions and 0 deletions.
@@ -0,0 +1,109 @@
from torch.utils.data import Dataset
import torch
import tqdm
import json
import logging

logger = logging.getLogger(__name__)

class VAEDataset(Dataset):
    """Unconditional dataset: each non-empty line (its last tab-separated field) is one sample."""

    def __init__(self, source_path, tokenizer, device=torch.device('cuda:0')):
        self.data = []
        self.tokenizer = tokenizer
        self.device = device
        with open(source_path) as f:
            for line in tqdm.tqdm(f, desc='Loading data...'):
                line = line.strip()
                if line == '':
                    continue
                line = line.split('\t')[-1]
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.tokenizer.encode(self.data[idx])

    @staticmethod
    def create_mask(num_tokens, max_len):
        # 0/1 mask of shape (batch, max_len): 1 for real tokens, 0 for padding.
        base_position_matrix = torch.arange(
            0, max_len, dtype=num_tokens.dtype, device=num_tokens.device).view(1, -1)
        mask = (base_position_matrix < num_tokens.view(-1, 1)).type_as(num_tokens)
        return mask

    def collate_fn(self, samples):
        # Wrap each sample in bos/eos, right-pad to the batch maximum, and build the attention mask.
        samples = [[self.tokenizer.bos_id] + s + [self.tokenizer.eos_id] for s in samples]
        length_list = [len(s) for s in samples]
        max_t = max(length_list)
        new_samples = [s + [self.tokenizer.pad_id] * (max_t - len(s)) for s in samples]
        new_samples = torch.LongTensor(new_samples)
        attention_mask = self.create_mask(torch.LongTensor(length_list), max_t)
        return {
            'input_ids': new_samples.to(self.device),
            'attention_mask': attention_mask.byte().to(self.device),
        }

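For intuition, here is a minimal sketch of what create_mask produces (the toy lengths are made up; the import works because the training script below imports these classes from a module named dataset):

import torch

from dataset import VAEDataset  # module name taken from the import in the training script below

lengths = torch.LongTensor([3, 5])         # two sequences, with 3 and 5 real tokens
mask = VAEDataset.create_mask(lengths, 5)  # static method, no dataset instance needed
print(mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
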
class WPDataset(Dataset):
    """Conditional dataset: each line is a JSON object with 'source' (condition) and 'target' fields."""

    def __init__(self, source_path, tokenizer, device=torch.device('cuda:0'), max_length=700, add_prefix=False, add_special_token=False):
        self.source = []
        self.target = []
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.add_special_token = add_special_token
        self.add_prefix = add_prefix
        with open(source_path) as f:
            for line in tqdm.tqdm(f, desc='Loading data...'):
                line = json.loads(line.strip())
                source = line['source'].replace('<newline>', '\n')
                target = line['target'].replace('<newline>', '\n')
                # Drop pairs whose combined whitespace-token count reaches max_length.
                if len(source.split()) + len(target.split()) < self.max_length:
                    self.source.append(source)
                    self.target.append(target)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        source = self.tokenizer.encode(self.source[idx])
        target = self.tokenizer.encode(self.target[idx])
        return source, target

    @staticmethod
    def create_mask(num_tokens, max_len):
        # 0/1 mask of shape (batch, max_len): 1 for real tokens, 0 for padding.
        base_position_matrix = torch.arange(
            0, max_len, dtype=num_tokens.dtype, device=num_tokens.device).view(1, -1)
        mask = (base_position_matrix < num_tokens.view(-1, 1)).type_as(num_tokens)
        return mask

    def collate_fn(self, samples):
        source_initial = [item[0] for item in samples]
        target_initial = [item[1] for item in samples]
        # source: condition ids + eos; target: source + target + eos (the decoder input sequence);
        # labels: the source positions are filled with pad_id so that only target tokens carry labels.
        source = [s + [self.tokenizer.eos_id] for s in source_initial]
        target = [s + t + [self.tokenizer.eos_id] for s, t in zip(source, target_initial)]
        labels = [[self.tokenizer.pad_id] * len(s) + t + [self.tokenizer.eos_id] for s, t in zip(source, target_initial)]
        source = [item[:self.max_length] for item in source]
        target = [item[:self.max_length] for item in target]
        labels = [item[:self.max_length] for item in labels]

        source_length_list = [len(s) for s in source]
        source_max_t = max(source_length_list)
        new_source = [s + [self.tokenizer.pad_id] * (source_max_t - len(s)) for s in source]
        new_source = torch.LongTensor(new_source)
        source_attention_mask = self.create_mask(torch.LongTensor(source_length_list), source_max_t)

        target_length_list = [len(s) for s in target]
        target_max_t = max(target_length_list)
        new_target = [s + [self.tokenizer.pad_id] * (target_max_t - len(s)) for s in target]
        new_labels = [s + [self.tokenizer.pad_id] * (target_max_t - len(s)) for s in labels]
        new_target = torch.LongTensor(new_target)
        new_labels = torch.LongTensor(new_labels)
        target_attention_mask = self.create_mask(torch.LongTensor(target_length_list), target_max_t)

        return {
            'input_ids': new_target.to(self.device),
            'attention_mask': target_attention_mask.byte().to(self.device),
            'labels': new_labels.to(self.device),
            'condition': new_source.to(self.device),
            'condition_mask': source_attention_mask.byte().to(self.device),
        }
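
To make the collate layout concrete, a toy walk-through with hand-picked ids (a sketch: the integer ids and the SimpleNamespace stand-in tokenizer are invented, and __init__ is bypassed purely for illustration so no data file is needed):

import torch
from types import SimpleNamespace

from dataset import WPDataset  # module name taken from the import in the training script below

ds = object.__new__(WPDataset)  # skip __init__: we only want collate_fn
ds.tokenizer = SimpleNamespace(eos_id=2, pad_id=0)
ds.device = torch.device('cpu')
ds.max_length = 700

batch = [([11, 12], [21, 22, 23])]  # one (source ids, target ids) pair
out = ds.collate_fn(batch)
print(out['input_ids'])  # tensor([[11, 12,  2, 21, 22, 23,  2]])  source + eos + target + eos
print(out['labels'])     # tensor([[ 0,  0,  0, 21, 22, 23,  2]])  source positions padded out
print(out['condition'])  # tensor([[11, 12,  2]])                  the source alone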
@@ -0,0 +1,43 @@
import torch
import numpy as np

@torch.jit.script
def soft_clamp5(x: torch.Tensor):
    # Smoothly squash x into (-5, 5): 5 * tanh(x / 5).
    return x.div_(5.).tanh_().mul(5.)


class Normal:
    """Diagonal Gaussian parameterized by a mean and a log-variance."""

    def __init__(self, mu, log_sigma):
        self.mu = torch.clamp(mu, -5, 5)
        log_sigma = torch.clamp(log_sigma, -5, 5)
        # log_sigma is treated as a log-variance: std = exp(0.5 * log_sigma).
        self.std = log_sigma.mul(0.5).exp()

    def sample(self):
        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).
        # Note the in-place mul_/add_: the returned eps aliases z (both hold mu + std * noise).
        eps = self.mu.mul(0).normal_()
        z = eps.mul_(self.std).add_(self.mu)
        return z, eps

    @staticmethod
    def get_standard(bs, nz, device):
        # Standard normal N(0, I) of shape (bs, nz); zeros serve as both mean and log-variance.
        zeros = torch.zeros(bs, nz).to(device)
        return Normal(zeros, zeros)

    def sample_given_eps(self, eps):
        return eps * self.std + self.mu

    def log_p(self, samples):
        # Log-density of a diagonal Gaussian, summed over the latent dimension.
        normalized_samples = (samples - self.mu) / self.std
        log_p = -0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - torch.log(self.std)
        log_p = torch.sum(log_p, dim=-1)
        return log_p

    def kl(self, normal_dist):
        # Closed-form KL(self || normal_dist) for diagonal Gaussians, summed over the latent dimension.
        assert normal_dist.mu.shape == self.mu.shape
        term1 = (self.mu - normal_dist.mu) / normal_dist.std
        term2 = self.std / normal_dist.std
        loss = 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2)
        loss = torch.sum(loss, dim=-1)
        return loss

    def set_device(self, cuda_id):
        self.mu = self.mu.to(cuda_id)
        self.std = self.std.to(cuda_id)
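
As a quick sanity check of the class above (a standalone sketch; it assumes Normal is in scope, e.g. run at the bottom of this file), the KL against a standard Gaussian reduces to the familiar closed form 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2):

import torch

bs, nz = 2, 4
q = Normal(torch.randn(bs, nz), torch.randn(bs, nz))  # approximate posterior q(z|x)
p = Normal.get_standard(bs, nz, 'cpu')                # prior N(0, I)

z, _ = q.sample()                                     # reparameterized sample, shape (bs, nz)

manual = 0.5 * (q.mu ** 2 + q.std ** 2 - 1.0 - 2.0 * torch.log(q.std)).sum(dim=-1)
print(torch.allclose(q.kl(p), manual))                # True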
@@ -0,0 +1,236 @@
import argparse
import logging
import os
import json
import torch
import random
import numpy as np
import time

from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader

from dataset import VAEDataset, WPDataset
from train import train, valid, generate

from model import Della

from transformers import AutoConfig, AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_file", default='./data/yelp/yelp.train.txt', type=str,
                        help="Data path for training.")
    parser.add_argument("--valid_file", default='./data/yelp/yelp.train.txt', type=str,
                        help="Data path for validation.")
    parser.add_argument("--test_file", default='./data/yelp/yelp.train.txt', type=str,
                        help="Data path for testing.")
    parser.add_argument("--pretrained_model", type=str, default='gpt2',
                        help="Pretrained model to be loaded.")
    parser.add_argument("--dataset_type", type=str, default='vae', choices=['vae', 'wp'],
                        help="Dataset type.")
    parser.add_argument("--output_dir", default='./checkpoints', type=str,
                        help="The output directory where the model checkpoints and predictions will be written.")
    parser.add_argument("--model_name", default='della', type=str,
                        help="The model name.")
    parser.add_argument("--generation_output_dir", default='./generation_output', type=str,
                        help="The output directory where the log will be written.")
    # Other parameters
    parser.add_argument("--load_epoch", default=None, type=int,
                        help="The epoch of the trained model checkpoint to load.")
    parser.add_argument("--epochs", default=40, type=int, help="Total training epochs.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU for training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Do not use CUDA even when it is available.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=8,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="Weight decay, if any is applied.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for the Adam optimizer.")
    parser.add_argument("--kl_threshold", default=0, type=float,
                        help="The threshold of the minimum KL value, 0 by default.")
    parser.add_argument("--latent_size", default=32, type=int,
                        help="The dimension of the latent space.")
    parser.add_argument("--latent_lmf_rank", default=4, type=int,
                        help="The rank used for the low-rank fusion of the latent vector.")
    parser.add_argument("--max_length", default=200, type=int,
                        help="Max length for generation.")
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for initialization.")
    parser.add_argument('--log_step', type=int, default=100,
                        help="Number of steps between log messages.")
    parser.add_argument('--num_beams', type=int, default=10,
                        help="Beam size for beam search.")
    parser.add_argument('--greedy_decoding', action='store_true',
                        help="Use greedy decoding.")
    parser.add_argument('--top_k', type=int, default=-1, help="Set top-k sampling.")
    parser.add_argument('--top_p', type=float, default=0.9, help="Set top-p (nucleus) sampling.")
    parser.add_argument('--repetition_penalty', type=float, default=1.2)
    parser.add_argument('--model_parallel', action='store_true',
                        help="Use model parallelism, mapping the layers to different devices.")
    parser.add_argument('--eval', action='store_true', help="Evaluate the model.")
    parser.add_argument('--eval_metrics', action='store_true',
                        help="Evaluate the metrics for representation learning.")
    parser.add_argument('--generation', action='store_true', help="Run generation.")
    parser.add_argument('--use_scheduler', action='store_true',
                        help="Use a learning-rate scheduler.")
    parser.add_argument('--cycle_annealing', action='store_true',
                        help="Use cyclical annealing.")
    parser.add_argument('--cycle_iters', type=int, default=2,
                        help="Set the iterations for cyclical annealing.")
    parser.add_argument('--sample_times', type=int, default=30,
                        help="Number of samples used when computing PPL with importance-weighted sampling.")
    parser.add_argument('--use_bow', action='store_true',
                        help="Use the bag-of-words (BoW) loss.")
    parser.add_argument('--bow_weight', type=float, default=0.2,
                        help="Weight of the BoW loss term.")
    parser.add_argument("--begin_layer", default=None, type=int,
                        help="The first layer that considers the latent vector; defaults to the first layer of the model.")
    parser.add_argument("--end_layer", default=None, type=int,
                        help="The last layer that considers the latent vector; defaults to the last layer of the model.")
    args = parser.parse_args()
    return args

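For reference, a minimal sketch of how these options resolve (the argv list below is illustrative; every flag shown is defined above, and get_args() is assumed to be in scope):

import sys

sys.argv = [
    'train_script',                 # placeholder program name
    '--dataset_type', 'vae',
    '--per_gpu_train_batch_size', '4',
    '--cycle_annealing',
    '--use_bow',
]
args = get_args()
print(args.learning_rate, args.bow_weight, args.cycle_annealing)  # 5e-05 0.2 True
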
def prepare(args):
    torch.set_num_threads(3)

    if not args.eval and not args.generation:
        os.makedirs(os.path.join(args.output_dir, args.model_name), exist_ok=True)
        json.dump(args.__dict__, open(os.path.join(
            args.output_dir, args.model_name, 'train_opt.json'), 'w'), sort_keys=True, indent=2)

    if args.no_cuda:
        args.n_gpu = 1
    else:
        args.n_gpu = torch.cuda.device_count()
    args.batch_size = args.per_gpu_train_batch_size * args.n_gpu

    # Setup logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    # Set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    logger.info("Training/evaluation parameters %s", args)

    if args.no_cuda:
        args.device = torch.device('cpu')
    else:
        args.device = torch.device('cuda:0')

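A back-of-the-envelope check of the batch sizing (my arithmetic, not stated in the diff; it assumes train(), which is not shown here, applies gradient accumulation in the usual way):

per_gpu_train_batch_size = 4       # default from get_args()
n_gpu = 1                          # torch.cuda.device_count() on a single-GPU machine (assumption)
gradient_accumulation_steps = 8    # default from get_args()

batch_size = per_gpu_train_batch_size * n_gpu               # what prepare() stores in args.batch_size
effective_batch = batch_size * gradient_accumulation_steps  # samples contributing to one optimizer step
print(batch_size, effective_batch)                          # 4 32
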
def init_para_frompretrained(model, gpt2):
    logger.info('load gpt2 pretrained model parameters')
    # Assigning nn.Parameter objects directly shares the pretrained GPT-2 weights
    # with the encoder's modules rather than copying them.
    model = model.encoder
    model.wte.weight = gpt2.wte.weight
    model.wpe.weight = gpt2.wpe.weight

    for i in range(len(gpt2.h)):
        model.h[i].ln_1.weight = gpt2.h[i].ln_1.weight
        model.h[i].ln_1.bias = gpt2.h[i].ln_1.bias
        model.h[i].attn.c_attn.weight = gpt2.h[i].attn.c_attn.weight
        model.h[i].attn.c_attn.bias = gpt2.h[i].attn.c_attn.bias
        model.h[i].attn.c_proj.weight = gpt2.h[i].attn.c_proj.weight
        model.h[i].attn.c_proj.bias = gpt2.h[i].attn.c_proj.bias
        model.h[i].ln_2.weight = gpt2.h[i].ln_2.weight
        model.h[i].ln_2.bias = gpt2.h[i].ln_2.bias
        model.h[i].mlp.c_fc.weight = gpt2.h[i].mlp.c_fc.weight
        model.h[i].mlp.c_fc.bias = gpt2.h[i].mlp.c_fc.bias
        model.h[i].mlp.c_proj.weight = gpt2.h[i].mlp.c_proj.weight
        model.h[i].mlp.c_proj.bias = gpt2.h[i].mlp.c_proj.bias

    model.ln_f.weight = gpt2.ln_f.weight
    model.ln_f.bias = gpt2.ln_f.bias

def prepare_model(args):
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
    if '<s>' not in tokenizer.vocab:
        tokenizer._add_tokens(['<s>'])
    if '</s>' not in tokenizer.vocab:
        tokenizer._add_tokens(['</s>'])
    # GPT-2 has no pad token; 50256 is its <|endoftext|> id, reused here for padding.
    tokenizer.pad_id = 50256

    tokenizer.bos_id = tokenizer.convert_tokens_to_ids('<s>')
    tokenizer.eos_id = tokenizer.convert_tokens_to_ids('</s>')

    model_config = AutoConfig.from_pretrained(args.pretrained_model)
    model_config.vocab_size = len(tokenizer)
    model_config.pad_token_id = tokenizer.pad_id
    model_config.kl_threshold = args.kl_threshold
    model_config.is_cvae = (args.dataset_type == 'wp')
    model_config.use_bow = args.use_bow
    model_config.begin_layer = args.begin_layer
    model_config.end_layer = args.end_layer

    # Copy every latent_* argument (latent_size, latent_lmf_rank, ...) onto the config.
    for arg in vars(args):
        if arg.startswith('latent'):
            setattr(model_config, arg, getattr(args, arg))

    model = Della(model_config)
    pretrained_model = AutoModel.from_pretrained(args.pretrained_model)
    logging.info('loading pretrained model parameters...')
    init_para_frompretrained(model, pretrained_model)
    model.encoder.resize_token_embeddings(len(tokenizer))
    # Tie the decoder's token embeddings to the (resized) encoder embeddings.
    model.decoder.wte = model.encoder.wte
    if args.load_epoch is not None:
        model_path = os.path.join(args.output_dir, args.model_name, 'model_epoch_{}.pt'.format(args.load_epoch))
        model_state_dict = torch.load(model_path, map_location=args.device)
        model.load_state_dict(model_state_dict)
        logging.info('loaded model_epoch_{}.pt'.format(args.load_epoch))
    else:
        args.load_epoch = -1

    if args.model_parallel and torch.cuda.device_count() > 1:
        logging.info('parallelizing the model across devices...')
        model.parallelize()
    else:
        model = model.to(args.device)
        if torch.cuda.device_count() > 1:
            model = DataParallel(model)
    return model, tokenizer

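A standalone sketch of what the tokenizer setup above yields for the stock 'gpt2' checkpoint (it uses the public add_tokens API rather than the private _add_tokens, and assumes an unmodified GPT-2 vocabulary of 50,257 tokens):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_tokens(['<s>', '</s>'])                       # neither token is in the stock GPT-2 vocab

tokenizer.pad_id = 50256                                    # GPT-2's <|endoftext|> id, reused for padding
tokenizer.bos_id = tokenizer.convert_tokens_to_ids('<s>')   # 50257, the first id after the original vocab
tokenizer.eos_id = tokenizer.convert_tokens_to_ids('</s>')  # 50258
print(len(tokenizer), tokenizer.bos_id, tokenizer.eos_id)   # 50259 50257 50258
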
def prepare_data(tokenizer, args):
    dataset_class = {'vae': VAEDataset, 'wp': WPDataset}
    if args.eval or args.generation:
        logging.info("evaluating epoch {} of model {}".format(args.load_epoch, args.model_name))
        test_dataset = dataset_class[args.dataset_type](args.test_file, tokenizer, args.device)
        test_iter = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
        return test_iter
    else:
        train_dataset = dataset_class[args.dataset_type](args.train_file, tokenizer, args.device)
        valid_dataset = dataset_class[args.dataset_type](args.valid_file, tokenizer, args.device)
        train_iter = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
        valid_iter = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=valid_dataset.collate_fn)
        logging.info('training with {} samples...'.format(len(train_dataset)))
        return train_iter, valid_iter

def main():
    args = get_args()
    prepare(args)
    model, tokenizer = prepare_model(args)
    total_params = sum(p.numel() for p in model.parameters())
    logging.info('total parameters: {}'.format(total_params))
    if args.eval or args.generation:
        test_iter = prepare_data(tokenizer, args)
        if args.eval:
            valid(model, test_iter, args.load_epoch, args)
        if args.generation:
            generate(model, test_iter, tokenizer, args)
    else:
        train_iter, valid_iter = prepare_data(tokenizer, args)
        train(model, train_iter, valid_iter, args)


if __name__ == "__main__":
    main()