embedding_of_ins.py
import argparse
import json
import os

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

def embed_texts_batched(texts, batch_size=30):
    """Embed texts in batches by mean-pooling the model's last hidden state."""
    all_embeddings = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i + batch_size]
        tokens = tokenizer(batch, return_tensors="pt", truncation=True, padding='max_length', max_length=512)
        tokens = {k: v.cuda() for k, v in tokens.items()}
        with torch.no_grad():
            outputs = model(**tokens)
        # Mean-pool over the sequence dimension; dropping the original
        # squeeze() keeps a 2-D (batch, hidden) array even when the final
        # batch holds a single text, so extend() appends one vector per text.
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        all_embeddings.extend(embeddings)
    return all_embeddings
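
# Illustrative usage (an assumption, not part of the original script; it
# requires `tokenizer` and `model` to already be initialized as in __main__):
#   vecs = embed_texts_batched(["hello", "world"])
#   assert len(vecs) == 2 and vecs[0].ndim == 1  # one 1-D vector per text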

def dataset_preprocess(raw_data):
    """Strip surrounding whitespace from every field of each record."""
    data = []
    for e in raw_data:
        data.append({"instruction": e["instruction"].strip(), "input": e["input"].strip(), "output": e["output"].strip()})
    return data

def multiple_gen_promptify(instruction, input, output):
    """Build the query prompt and the query-plus-answer string for a record."""
    if input != "":
        with_query = f"Instruction:\n{instruction}\nInput:\n{input}\nResponse:\n"
    else:
        with_query = f"Instruction:\n{instruction}\nResponse:\n"
    with_query_and_choice = f"{with_query}{output}"
    return with_query, with_query_and_choice
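
# Worked example with illustrative values (not taken from the dataset):
#   multiple_gen_promptify("Translate to French.", "Hello", "Bonjour")
# returns the pair
#   ("Instruction:\nTranslate to French.\nInput:\nHello\nResponse:\n",
#    "Instruction:\nTranslate to French.\nInput:\nHello\nResponse:\nBonjour")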

def load_sample(file_name):
    """Load an Alpaca-style JSON list of {"instruction", "input", "output"}
    records and return one prompted query-plus-answer string per record."""
    with open(file_name, "r") as f:
        data = json.load(f)
    data = dataset_preprocess(data)
    print(f"Data loaded: {file_name}.")
    ex_list = [[e["instruction"], e["input"], e["output"]] for e in data]
    ex_prompted = []
    for instruction, input, output in ex_list:
        _, line = multiple_gen_promptify(instruction, input, output)  # query, <query_with_answer>
        ex_prompted.append(line)
    return ex_prompted

# Initialize the model and tokenizer
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, help="e.g. checkpoint/LLaMA/convert_llama_7b")
    parser.add_argument('--instruction_path', type=str, help="e.g. datasets/alpaca_gpt4/alpaca_gpt4_data.json")
    parser.add_argument('--save_embedding_path', type=str, help="e.g. save/alpaca_gpt4/embeddings")
    args = parser.parse_args()

    MODEL_PATH = args.model_path
    INSTRUCTION_PATH = args.instruction_path
    SAVE_EMBEDDING_PATH = args.save_embedding_path
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModel.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        # load_in_8bit=in_8bit,
    )
    model.eval()

    # LLaMA-style tokenizers ship without a pad token; add one and resize the
    # embedding matrix so padded batches can be tokenized.
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

    # Load the prompted samples, embed them, and save the matrix.
    sample = load_sample(INSTRUCTION_PATH)
    print("START EMBEDDING ...")
    embeddings = embed_texts_batched(sample)
    print(len(embeddings))
    os.makedirs(SAVE_EMBEDDING_PATH, exist_ok=True)  # ensure the output directory exists
    np.save(f'{SAVE_EMBEDDING_PATH}/{len(sample)}.npy', np.asarray(embeddings))
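
# Example invocation, reusing the paths shown in the argparse help strings:
#   python embedding_of_ins.py \
#       --model_path checkpoint/LLaMA/convert_llama_7b \
#       --instruction_path datasets/alpaca_gpt4/alpaca_gpt4_data.json \
#       --save_embedding_path save/alpaca_gpt4/embeddings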