-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner.py
152 lines (123 loc) · 7.81 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from transformers import TrainingArguments, HfArgumentParser
from src.data_loader.query_dataloader import DataLoader
from src.cbr_trainer.cbrTrainer import cbrTrainer
from tqdm import tqdm, trange
from dataclasses import dataclass, field,asdict
import torch
from src.models.rgcn_model import RGCN
import datetime
import wandb
import logging
import os
# Root logger: configured once at import time so every module logger
# (src.*, transformers, ...) propagates through this single format.
logger = logging.getLogger()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
def main():
    """Parse CLI arguments, train a CBR-RGCN model, and report dev/test metrics.

    Side effects: creates ``train_args.output_dir``, writes ``log.txt`` there,
    and (when ``train_args.use_wandb``) streams metrics to Weights & Biases.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info("=====Parsing Arguments=====")
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CBRArguments))
    model_args, train_args, cbr_args = parser.parse_args_into_dataclasses()

    os.makedirs(train_args.output_dir, exist_ok=True)
    # BUG FIX: the file handler was created but never attached to the logger,
    # so log.txt stayed empty. Attach it so logs are mirrored to disk.
    fileHandler = logging.FileHandler(os.path.join(train_args.output_dir, "log.txt"))
    logger.addHandler(fileHandler)

    model_args.device = device

    # Timestamp used only to name the wandb run.
    current_datetime = datetime.datetime.now()
    cbr_args.formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")

    # WandB: flatten all three argument dataclasses into one flat config dict.
    # (asdict only emits declared fields, so the ad-hoc `device` attr is excluded.)
    if train_args.use_wandb:
        config = dict()
        for object_attributes in (model_args, train_args, cbr_args):
            config.update(asdict(object_attributes))
        wandb.init(project=cbr_args.data_name, config=config)
        wandb.run.name = f"{cbr_args.formatted_datetime}_{cbr_args.data_name}"

    logger.info("=====Loading Data=====")
    dataset_obj = DataLoader(cbr_args.data_dir,
                             cbr_args.data_name,
                             cbr_args.paths_file_dir,
                             cbr_args.train_batch_size,
                             cbr_args.eval_batch_size)

    logger.info("=====Loading Model=====")
    rgcn_model = RGCN(n_entities=dataset_obj.n_entities,
                      n_relations=dataset_obj.n_relations,
                      params=model_args).to(device)

    logger.info("=====Setting Training=====")
    trainer = cbrTrainer(rgcn_model,
                         dataset_obj,
                         model_args,
                         train_args,
                         cbr_args,
                         device)

    for epoch in trange(train_args.num_train_epochs, desc="[Full Loop]"):
        train_loss, results_train = trainer.train()                             # one training epoch
        results_dev = trainer.run_evaluate("dev", dataset_obj.dev_dataloader)   # dev evaluation
        if train_args.use_wandb:
            # Tracks epoch-level loss plus MRR / Hits@k on train and dev.
            wandb.log({'Loss Epoch': train_loss,
                       "MRR Train": results_train['avg_rr'],
                       "MRR Dev": results_dev['avg_rr'],
                       "Hits@1 Dev": results_dev.get('avg_hits@1', 0),
                       "Hits@3 Dev": results_dev.get('avg_hits@3', 0),
                       "Hits@5 Dev": results_dev.get('avg_hits@5', 0),
                       "Hits@10 Dev": results_dev.get('avg_hits@10', 0)})
        logger.info('[Epoch:{}]: Training Loss:{:.4} Training MRR:{:.4} Dev MRR:{:.4}'.format(
            epoch, train_loss, results_train['avg_rr'], results_dev['avg_rr']))

    # Final held-out evaluation after all epochs complete.
    results_test = trainer.run_evaluate("test", dataset_obj.test_dataloader)
    if train_args.use_wandb:
        wandb.log({
            "MRR Test": results_test['avg_rr'],
            "Hits@1 Test": results_test.get('avg_hits@1', 0),
            "Hits@3 Test": results_test.get('avg_hits@3', 0),
            "Hits@5 Test": results_test.get('avg_hits@5', 0),
            "Hits@10 Test": results_test.get('avg_hits@10', 0)})
    # Use .get(..., 0) here for consistency with the wandb branch above, so a
    # missing Hits@k key degrades to 0 instead of raising KeyError.
    logger.info("Test MRR:{:.4} Hits@1:{:.4} Hits@3:{:.4} Hits@5:{:.4} Hits@10:{:.4}".format(
        results_test['avg_rr'],
        results_test.get('avg_hits@1', 0),
        results_test.get('avg_hits@3', 0),
        results_test.get('avg_hits@5', 0),
        results_test.get('avg_hits@10', 0)))
@dataclass
class ModelArguments:
    """Hyper-parameters for the RGCN encoder.

    ``device`` is attached dynamically in ``main`` and is deliberately not a
    declared field (so it is excluded from ``asdict``).
    """

    # FIX: help-string typo "Intial" -> "Initial" (user-facing via --help).
    gcn_dim_init: int = field(default=32, metadata={"help": "Initial GCN layer dimensionality"})
    hidden_channels_gcn: int = field(default=32, metadata={"help": "Hidden GCN layer dimensionality"})
    drop_gcn: float = field(default=0.0, metadata={"help": "Dropout probability for RGCN model"})
    conv_layers: int = field(default=1, metadata={"help": "Number of GCN layers"})
    # Int flag (0/1) rather than bool, matching how HfArgumentParser is used here.
    transform_input: int = field(default=0, metadata={"help": "Linear transformation over to input model"})
@dataclass
class DataTrainingArguments(TrainingArguments):
    """Training hyper-parameters, extending HuggingFace ``TrainingArguments``.

    Fields below either add CBR-specific knobs or override inherited defaults
    (learning_rate, weight_decay, num_train_epochs, gradient_accumulation_steps).
    """

    # --- experiment tracking -------------------------------------------------
    use_wandb: int = field(default=0, metadata={"help": "use wandb if 1"})

    # --- distance / loss configuration --------------------------------------
    dist_metric: str = field(default='l2', metadata={"help": "Options [l2, cosine]"})
    dist_aggr1: str = field(
        default='mean',
        metadata={"help": "Distance aggregation function at each neighbor query. "
                          "Options: [none (no aggr), mean, sum]"})
    dist_aggr2: str = field(
        default='mean',
        metadata={"help": "Distance aggregation function across all neighbor "
                          "queries. Options: [mean, sum]"})
    sampling_loss: float = field(default=1.0, metadata={"help": "Fraction of negative samples used"})
    temperature: float = field(default=1.0, metadata={"help": "Temperature for temperature scaled cross-entropy loss"})

    # --- optimization (overrides of TrainingArguments defaults) --------------
    learning_rate: float = field(default=0.001, metadata={"help": "Starting learning rate"})
    warmup_step: int = field(default=0, metadata={"help": "scheduler warm up steps"})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW"})
    num_train_epochs: int = field(default=5, metadata={"help": "Total number of training epochs to perform."})
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
    check_steps: float = field(default=5.0, metadata={"help": "Steps to check training"})

    # --- output --------------------------------------------------------------
    # output_dir is inherited from TrainingArguments.
    res_name: str = field(default="mind_test_predictions", metadata={"help": "Output file"})
@dataclass
class CBRArguments:
    """Case-based-reasoning data and neighbor-retrieval settings."""

    # Dataset identity and on-disk locations.
    data_name: str = field(default="MIND", metadata={"help": "KG dataset"})
    data_dir: str = field(default="src/00_data/", metadata={"help": "Path to data directory (contains train, test, dev)"})
    paths_file_dir: str = field(default='MIND_cbr_subgraph_knn-5_branch-200.pkl', metadata={"help": "Paths file name"})

    # Batch sizes and k-nearest-neighbor counts for train vs. eval.
    train_batch_size: int = field(default=1, metadata={"help": "Training batch size"})
    num_neighbors_train: int = field(default=5, metadata={"help": "Number of near-neighbor entities for training"})
    eval_batch_size: int = field(default=1, metadata={"help": "Test/Dev batch size"})
    num_neighbors_eval: int = field(default=5, metadata={"help": "Number of near-neighbor entities for test"})
# Script entry point: run the full train/eval pipeline when executed directly.
if __name__ == '__main__':
    main()