-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_natural_parity.py
136 lines (119 loc) · 3.57 KB
/
run_natural_parity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
__author__ = Chadi Helwe
__version__ = 1.0
__maintainer__ = Chadi Helwe
__email__ = [email protected]
__description__ = CLI to run the natural parity experiment
"""
from typing import Optional
import click
from failBERT.eval import eval_model as eval_model_natural_parity
from failBERT.train import train_model as train_model_natural_parity
from failBERT.utils import download_pretrained_models
@click.group()
def cli():
    """CLI to run the natural parity experiment."""
@click.command()
@click.option(
    "--path_train",
    default="data/natural_parity/natural_parity_train.csv",
    help="Path of the training dataset.",
)
@click.option(
    "--path_val",
    default=None,
    help="Path of the validation dataset (optional).",
)
@click.option(
    "--passages_column",
    default="modified_sentence",
    help="Name of the passage column.",
)
@click.option(
    "--labels_column",
    default="label",
    help="Name of the label column.",
)
@click.option(
    "--path_save_model",
    default="models/best_model_natural_parity.pkl",
    help="Path to save the best model.",
)
# click infers an INT type from the integer default.
@click.option("--epochs", default=10, help="Number of epochs.")
@click.option(
    "--device",
    default="cpu",
    help="Device to run the model on [cpu/cuda].",
)
def train_model(
    path_train: str,
    path_val: Optional[str],
    passages_column: str,
    labels_column: str,
    path_save_model: str,
    epochs: int,
    device: str,
):
    """
    Command to train a RoBERTa model on the natural parity task
    :param path_train: Path of the training dataset
    :type path_train: str
    :param path_val: Path of the validation dataset
    :type path_val: Optional[str]
    :param passages_column: Passage column name
    :type passages_column: str
    :param labels_column: Label column name
    :type labels_column: str
    :param path_save_model: Path to save the best model
    :type path_save_model: str
    :param epochs: Number of epochs
    :type epochs: int
    :param device: Device to run a model [cpu/cuda]
    :type device: str
    """
    # Arguments are forwarded positionally; their order must match the
    # signature of failBERT.train.train_model.
    train_model_natural_parity(
        path_train,
        path_val,
        passages_column,
        labels_column,
        path_save_model,
        epochs,
        device,
    )
@click.command()
@click.option(
    "--url",
    # NOTE(review): the default points at a file named
    # "pizza_switch_best_model_1_15.pkl" although this command targets the
    # natural parity task; confirm this is the intended checkpoint.
    default="https://www.dropbox.com/s/c8ushxx3fow4yag/pizza_switch_best_model_1_15.pkl?dl=1",
    help="Dropbox URL of the pretrained model.",
)
@click.option(
    "--file_name",
    default="best_model_natural_parity.pkl",
    help="Name of the pretrained model file.",
)
def download_pretrained_model(url: str, file_name: str):
    """
    Command to download pretrained model for the natural parity task
    :param url: Dropbox url of the pretrained model
    :type url: str
    :param file_name: Name of the pretrained model
    :type file_name: str
    """
    download_pretrained_models(url, file_name)
@click.command()
@click.option(
    "--path_test",
    default="data/natural_parity/test_1.csv",
    help="Path of the testing dataset.",
)
@click.option(
    "--passages_column",
    default="modified_sentence",
    help="Name of the passage column.",
)
@click.option(
    "--labels_column",
    default="label",
    help="Name of the label column.",
)
@click.option(
    "--path_model",
    default="models/best_model_natural_parity.pkl",
    help="Path of the saved model.",
)
@click.option(
    "--device",
    default="cpu",
    help="Device to run the model on [cpu/cuda].",
)
def eval_model(
    path_test: str,
    passages_column: str,
    labels_column: str,
    path_model: str,
    device: str,
):
    """
    Command to evaluate a RoBERTa model on the natural parity task
    :param path_test: Path of the testing dataset
    :type path_test: str
    :param passages_column: Passage column name
    :type passages_column: str
    :param labels_column: Label column name
    :type labels_column: str
    :param path_model: Path of the saved model
    :type path_model: str
    :param device: Device to run a model [cpu/cuda]
    :type device: str
    """
    # The four returned values are discarded. Unpacking them still fails
    # loudly (ValueError) if the evaluator stops returning exactly four items.
    _, _, _, _ = eval_model_natural_parity(
        path_test,
        passages_column,
        labels_column,
        path_model,
        device,
    )
# Register every sub-command on the group, in the same order as before,
# so `--help` lists them consistently.
for _command in (train_model, download_pretrained_model, eval_model):
    cli.add_command(_command)

if __name__ == "__main__":
    cli()