helpers.py
import json
import logging
import pathlib
import re
from typing import Union, Sequence
from ordered_set import OrderedSet
from aenum import extend_enum
from flags import RDFExampleFormat, DatasetChoice, RDF_SEPARATOR
from flags import Templates, ModelChoices, ConsistencyTemplateNames

def load_and_validate_config(config_file):
    """Load a JSON config file and convert its string fields into the corresponding Enum members from flags.py."""
with open(config_file, 'r') as f:
config = json.load(f)
config['file_name'] = pathlib.Path(config_file).stem
# Validate the values against the Enums
config['dataset_choice'] = DatasetChoice[config['dataset_choice'].upper()]
config['initial_template'] = Templates[config['initial_template'].upper()]
config['llm_stack'] = _parse_model_choices(config['llm_stack'], config['file_name'])
config['consist_val_model'] = _parse_model_choices([config['consist_val_model']], "CV")[0]
config['example_format'] = RDFExampleFormat[config['example_format'].upper()]
config['cv_template'] = ConsistencyTemplateNames[config['cv_template'].upper()]
return config
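
# Illustrative config shape (a sketch only; the exact Enum member names live in flags.py and may differ,
# and the file path below is hypothetical):
# {
#     "dataset_choice": "webnlg",
#     "initial_template": "<name of a Templates member>",
#     "llm_stack": ["http://localhost:8000"],
#     "consist_val_model": "<ModelChoices member name, URL, or local path>",
#     "example_format": "table",
#     "cv_template": "<name of a ConsistencyTemplateNames member>"
# }
# load_and_validate_config("configs/example.json") would then return this dict with the string
# fields replaced by their Enum counterparts (and "file_name" set to "example").
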
def _validate_url(input_string):
# Remove 'http://' or 'https://' prefix if present
input_string = re.sub(r'^(http://|https://)', '', input_string)
# Regex for a valid IPv4 address and port number
ipv4_port_regex = r'^(\d{1,3}\.){3}\d{1,3}:\d{1,5}$'
# Regex for 'localhost' and a port number
localhost_port_regex = r'^(localhost):\d{1,5}$'
# Check for 'localhost:port' or 'IPv4:port'
if re.match(localhost_port_regex, input_string) or re.match(ipv4_port_regex, input_string):
return True
# If neither, return False
return False
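
# A few illustrative inputs (not part of the original file):
#   _validate_url("http://127.0.0.1:8080")  -> True   (IPv4 + port, scheme stripped first)
#   _validate_url("localhost:5000")         -> True
#   _validate_url("my-model-7b")            -> False  (treated as a model name/path elsewhere)
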
def _parse_model_choices(model_list: list[str], model_name: str):
    llm_stack = []
    # `model_name` is the member name used when a URL or local path has to be
    # registered as a new ModelChoices member on the fly
    name_upper = model_name.upper()
    for model in model_list:
        if _validate_url(model):
            url = model
            if not hasattr(ModelChoices, name_upper):
                extend_enum(ModelChoices, name_upper, url)
            llm_stack.append(ModelChoices[name_upper])
        elif pathlib.Path(model).exists():
            path_to_data = pathlib.Path(model)
            if not hasattr(ModelChoices, name_upper):
                extend_enum(ModelChoices, name_upper, path_to_data)
            llm_stack.append(ModelChoices[name_upper])
        elif model.upper() in [m.name for m in ModelChoices]:
            # match an existing member by name
            llm_stack.append(ModelChoices[model.upper()])
        elif model.lower() in ModelChoices:
            # match an existing member by value
            llm_stack.append(ModelChoices(model.lower()))
        else:
            raise NotImplementedError("Given model is not a valid choice. Refer to class `ModelChoices` in flags.py")
    return llm_stack
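
# Illustrative resolution (a sketch; the actual ModelChoices members are defined in flags.py):
#   _parse_model_choices(["http://localhost:8000"], "llama_server")
#     -> registers a new ModelChoices.LLAMA_SERVER member whose value is the URL and returns [ModelChoices.LLAMA_SERVER]
#   _parse_model_choices(["some_existing_member"], "ignored")
#     -> returns [ModelChoices.SOME_EXISTING_MEMBER] if a member with that name already exists
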
def setup_logger(name=__name__, loglevel=logging.DEBUG, handlers=None,
                 output_log_file: Union[pathlib.Path, str] = None):
if handlers is None:
handlers = [logging.StreamHandler()]
if output_log_file:
file_handler = logging.FileHandler(output_log_file, mode="w", encoding="utf-8")
handlers.append(file_handler)
logger = logging.getLogger(name)
logger.setLevel(loglevel)
formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%d/%m/%Y %I:%M:%S %p')
for handler in handlers:
handler.setLevel(loglevel)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
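
# Typical usage (illustrative; the log file name is hypothetical):
#   logger = setup_logger(name="rdf_pipeline", loglevel=logging.INFO, output_log_file="run.log")
#   logger.info("hello")  # goes to stderr via the default StreamHandler and to run.log via the FileHandler
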
def _uppercase_sequence(sequence: Union[Sequence[str], OrderedSet[str]], tp):
if not isinstance(sequence, tp):
return sequence
new_sequence = []
for ent in sequence:
try:
new_sequence.append(ent.upper())
except AttributeError:
raise AttributeError("Entity in sequence is not a string!")
return tp(new_sequence)

def uppercase(f):
def wrap(entry, *args, **kwargs):
if isinstance(entry, str):
entry = entry.upper()
else:
entry = _uppercase_sequence(entry, OrderedSet)
entry = _uppercase_sequence(entry, list)
entry = _uppercase_sequence(entry, tuple)
return f(entry, *args, **kwargs)
return wrap
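
# Illustrative use of the decorator (hypothetical function, not from the original file):
#   @uppercase
#   def lookup(key, table):
#       return table[key]
#   lookup("capital", {"CAPITAL": "Berlin"})  -> "Berlin"
#   a list/tuple/OrderedSet first argument is uppercased element-wise, e.g. ["a", "b"] -> ["A", "B"]
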
def format_rdf_entries(entries: list[tuple[str, str, str]], output_type: RDFExampleFormat, sep: str):
if len(entries) == 0:
return None
if output_type == RDFExampleFormat.TABLE:
return "\\n".join([f"Table: {e[0]} {sep} {e[1]} {sep} {e[2]}" for e in entries])
# entries.append(f"Table: {s_label} {sep} {r_label} {sep} {o_label}")
elif output_type == RDFExampleFormat.JSON:
return json.dumps(entries)
else: # RDFExampleFormat.DEFAULT
return entries
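
# Illustrative outputs (sketch only):
#   format_rdf_entries([("Berlin", "capital of", "Germany")], RDFExampleFormat.TABLE, "|")
#     -> 'Table: Berlin | capital of | Germany'
#   format_rdf_entries(entries, RDFExampleFormat.JSON, "|")  -> json.dumps of the triple list
#   format_rdf_entries([], output_type, sep)                 -> None
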
def parse_rdf_list_to_examples(rdf_list: list[dict[str, str]], max_examples: int, output_type: RDFExampleFormat,
                               dataset_choice: DatasetChoice, sep="|") -> tuple[str, list[str], list[str], list[str]]:
    """
    :param rdf_list: (list[dict]) [{'s': '', 'r': '', 'o': ''}, {'s': '', 'r': '', 'o': ''}, ...]
    :param max_examples: (int) maximum number of examples to parse into the final string
    :param output_type: (RDFExampleFormat) how examples are formatted at the output
    :param dataset_choice: (DatasetChoice[Enum]) which dataset the examples are loaded from
    :param sep: (str) separator between entities for parsed Table string entries
    :return: (tuple[str, list[str], list[str], list[str]])
        ("Table: s_label | r_label | o_label\nTable: ...", s_labels, r_labels, o_labels)
    """
entries = []
s_labels = []
r_labels = []
o_labels = []
if output_type not in RDFExampleFormat:
raise NotImplementedError(f"output_type expected to be instance of `RDFExampleFormat` enum (got {output_type})")
    for rdf in rdf_list[:max_examples]:
        # all currently supported datasets store the triple under the same keys
        if dataset_choice in [DatasetChoice.DART, DatasetChoice.DART_TEST,
                              DatasetChoice.REL2TEXT, DatasetChoice.REL2TEXT_TEST,
                              DatasetChoice.WEBNLG, DatasetChoice.WEBNLG_TEST,
                              DatasetChoice.WIKIDATA, DatasetChoice.WIKIDATA_TEST]:
            s_label = rdf["s"]
            r_label = rdf["r"]
            o_label = rdf["o"]
        else:
            raise NotImplementedError("Chosen `DatasetChoice` is not supported yet.")
entries.append((s_label, r_label, o_label))
s_labels.append(s_label)
r_labels.append(r_label)
o_labels.append(o_label)
output = format_rdf_entries(entries, output_type, sep)
return output, s_labels, r_labels, o_labels
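
# Illustrative call (a sketch; assumes the dataset stores triples under the "s"/"r"/"o" keys as above):
#   parse_rdf_list_to_examples([{"s": "Berlin", "r": "capital of", "o": "Germany"}],
#                              max_examples=1, output_type=RDFExampleFormat.TABLE,
#                              dataset_choice=DatasetChoice.WEBNLG)
#     -> ('Table: Berlin | capital of | Germany', ['Berlin'], ['capital of'], ['Germany'])
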
def load_examples(path_to_example_json: pathlib.Path, max_examples_per_pid: int, output_type: RDFExampleFormat,
                  dataset_choice: DatasetChoice):
    if path_to_example_json.exists():
        with path_to_example_json.open() as f:
            fetched_example_dict = json.load(f)
    else:
        raise FileNotFoundError(f"{path_to_example_json.name} is missing! Please run the `scripts/dataset_builders/build` script for the respective dataset first")
    pid_examples_dict = {pid: parse_rdf_list_to_examples(rdf_list, max_examples_per_pid, output_type, dataset_choice, sep=RDF_SEPARATOR)
                         for pid, rdf_list in fetched_example_dict.items()}
    return pid_examples_dict
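
# The example JSON is expected to map a property/relation id to a list of triples, e.g. (illustrative):
#   {"P36": [{"s": "Germany", "r": "capital", "o": "Berlin"}, ...]}
# load_examples(...) then returns {pid: parse_rdf_list_to_examples(...) output} for every pid in that file.
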
def make_json_compliant(json_str: str):
    """Coerce a raw (often slightly malformed) LLM output string into a parsable JSON object string."""
    # # An earlier, regex-based attempt: ensure that keys and string values are wrapped with double quotes
    # json_str = re.sub(r'(?:(?<=\{)|(?<=,))\s*(\'[^:\']+\')\s*:', r' \1 :',
    #                   json_str)  # if the key is in single quotes
    # json_str = re.sub(r'(?:(?<=\{)|(?<=,))\s*([^:\'"]+)\s*:', r' "\1" :',
    #                   json_str)  # if the key is without any quotes
    # json_str = re.sub(r':\s*\'([^,\']*)\'', r': "\1"', json_str)  # if the value is in single quotes
    # json_str = re.sub(r':\s*([^\',"{}]*)\s*(?=[,\}])', r': "\1"', json_str)  # if the value is without any quotes
    #
    # # Remove trailing commas
    # json_str = re.sub(r',\s*}', '}', json_str)
    # json_str = re.sub(r',\s*]', ']', json_str)

    # Strip leading '{', backticks, quotes and (escaped or real) newlines, plus trailing '.', '}', backticks,
    # quotes and (escaped or real) newlines ...
    json_str = json_str.lstrip('{`\\n\n"').rstrip('.`\\n\n"}')
    # ... then re-wrap the remainder as a single JSON object with a quoted, period-terminated value
    json_str = '{"' + json_str
    json_str = json_str + '."}'
    # json_str = json_str[0] + json_str[1:-1].replace('{', '').replace('}', '') + json_str[-1]  # drop inner braces
    return json_str
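
# Illustrative transformation (sketch only):
#   make_json_compliant('{"verdict": "consistent"}\n')  ->  '{"verdict": "consistent."}'
# i.e. stray wrapping characters are stripped and the string is re-wrapped so that json.loads() can parse it.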