diff --git a/examples/text_classification/config_classifier.py b/examples/text_classification/config_classifier.py
index 85e64440b..3000603ec 100644
--- a/examples/text_classification/config_classifier.py
+++ b/examples/text_classification/config_classifier.py
@@ -1,11 +1,11 @@
-name = "bert_classifier"
-hidden_size = 768
-clas_strategy = "cls_time"
-dropout = 0.1
-num_classes = 2
-
-# This hyperparams is used in bert_with_hypertuning_main.py example
-hyperparams = {
- "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
- "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
-}
+name = "bert_classifier"
+hidden_size = 768
+clas_strategy = "cls_time"
+dropout = 0.1
+num_classes = 2
+
+# These hyperparams are used in the bert_with_hypertuning_main.py example.
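+# Each entry below defines the search range (`start` to `end`) and dtype of
+# one optimizer hyperparameter.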
+hyperparams = {
+ "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
+ "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
+}
diff --git a/examples/text_classification/config_data.py b/examples/text_classification/config_data.py
index d15379abc..493aea92b 100644
--- a/examples/text_classification/config_data.py
+++ b/examples/text_classification/config_data.py
@@ -1,68 +1,68 @@
-pickle_data_dir = "data/IMDB"
-max_seq_length = 64
-num_classes = 2
-num_train_data = 25000
-
-# used for bert executor example
-max_batch_tokens = 128
-
-train_batch_size = 32
-max_train_epoch = 5
-display_steps = 50 # Print training loss every display_steps; -1 to disable
-
-# tbx config
-tbx_logging_steps = 5 # log the metrics for tbX visualization
-tbx_log_dir = "runs/"
-exp_number = 1 # experiment number
-
-eval_steps = 100 # Eval on the dev set every eval_steps; -1 to disable
-# Proportion of training to perform linear learning rate warmup for.
-# E.g., 0.1 = 10% of training.
-warmup_proportion = 0.1
-eval_batch_size = 8
-test_batch_size = 8
-
-feature_types = {
- # Reading features from pickled data file.
- # E.g., Reading feature "input_ids" as dtype `int64`;
- # "FixedLenFeature" indicates its length is fixed for all data instances;
- # and the sequence length is limited by `max_seq_length`.
- "input_ids": ["int64", "stacked_tensor", max_seq_length],
- "input_mask": ["int64", "stacked_tensor", max_seq_length],
- "segment_ids": ["int64", "stacked_tensor", max_seq_length],
- "label_ids": ["int64", "stacked_tensor"]
-}
-
-train_hparam = {
- "allow_smaller_final_batch": False,
- "batch_size": train_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/train.pkl".format(pickle_data_dir)
- },
- "shuffle": True,
- "shuffle_buffer_size": None
-}
-
-eval_hparam = {
- "allow_smaller_final_batch": True,
- "batch_size": eval_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/eval.pkl".format(pickle_data_dir)
- },
- "shuffle": False
-}
-
-test_hparam = {
- "allow_smaller_final_batch": True,
- "batch_size": test_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/predict.pkl".format(pickle_data_dir)
- },
- "shuffle": False
-}
+pickle_data_dir = "data/IMDB"
+max_seq_length = 64
+num_classes = 2
+num_train_data = 25000
+
+# used for bert executor example
+max_batch_tokens = 128
+
+train_batch_size = 32
+max_train_epoch = 5
+display_steps = 50 # Print training loss every display_steps; -1 to disable
+
+# tbx config
+tbx_logging_steps = 5 # log the metrics for tbX visualization
+tbx_log_dir = "runs/"
+exp_number = 1 # experiment number
+
+eval_steps = 100 # Eval on the dev set every eval_steps; -1 to disable
+# Proportion of training to perform linear learning rate warmup for.
+# E.g., 0.1 = 10% of training.
+warmup_proportion = 0.1
+eval_batch_size = 8
+test_batch_size = 8
+
+feature_types = {
+    # Reading features from the pickled data file.
+    # E.g., feature "input_ids" is read as dtype `int64`;
+    # "stacked_tensor" indicates its length is fixed for all data instances
+    # (and limited by `max_seq_length`), so instances can be stacked into a
+    # batch tensor.
+ "input_ids": ["int64", "stacked_tensor", max_seq_length],
+ "input_mask": ["int64", "stacked_tensor", max_seq_length],
+ "segment_ids": ["int64", "stacked_tensor", max_seq_length],
+ "label_ids": ["int64", "stacked_tensor"]
+}
+
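+# The per-split hparams below are consumed by `tx.data.RecordData` when the
+# datasets are built in model.py.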
+train_hparam = {
+ "allow_smaller_final_batch": False,
+ "batch_size": train_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/train.pkl".format(pickle_data_dir)
+ },
+ "shuffle": True,
+ "shuffle_buffer_size": None
+}
+
+eval_hparam = {
+ "allow_smaller_final_batch": True,
+ "batch_size": eval_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/eval.pkl".format(pickle_data_dir)
+ },
+ "shuffle": False
+}
+
+test_hparam = {
+ "allow_smaller_final_batch": True,
+ "batch_size": test_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/predict.pkl".format(pickle_data_dir)
+ },
+ "shuffle": False
+}
diff --git a/examples/text_classification/download_imdb.py b/examples/text_classification/download_imdb.py
index 99f04601c..faefbac4a 100644
--- a/examples/text_classification/download_imdb.py
+++ b/examples/text_classification/download_imdb.py
@@ -1,18 +1,34 @@
-import os
-import sys
-
-def main(arguments):
- import subprocess
- if not os.path.exists("data/IMDB_raw"):
- subprocess.run("mkdir data/IMDB_raw", shell=True)
- # pylint: disable=line-too-long
- subprocess.run(
- 'wget -P data/IMDB_raw/ https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
- shell=True)
- subprocess.run(
- 'tar xzvf data/IMDB_raw/aclImdb_v1.tar.gz -C data/IMDB_raw/ && rm data/IMDB_raw/aclImdb_v1.tar.gz',
- shell=True)
-
-
-if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import subprocess
+
+
+def main():
+ if not os.path.exists("data/IMDB_raw"):
+ subprocess.run("mkdir data/IMDB_raw", shell=True, check=True)
+ # pylint: disable=line-too-long
+ subprocess.run(
+ 'wget -P data/IMDB_raw/ https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
+ shell=True, check=True)
+ subprocess.run(
+ 'tar xzvf data/IMDB_raw/aclImdb_v1.tar.gz -C data/IMDB_raw/ && rm data/IMDB_raw/aclImdb_v1.tar.gz',
+ shell=True, check=True)
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/examples/text_classification/main.py b/examples/text_classification/main.py
index 5806103e8..690156103 100644
--- a/examples/text_classification/main.py
+++ b/examples/text_classification/main.py
@@ -1,28 +1,31 @@
-# Copyright 2020 The Forte Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-from forte.models.imdb_text_classifier.model import IMDBClassifier
-import config_data
-import config_classifier
-
-def main(argv=None):
- model = IMDBClassifier(config_data, config_classifier)
- if not os.path.isfile("data/IMDB/train.pkl"):
- model.prepare_data("data/IMDB")
- model.run(do_train=True, do_eval=True, do_test=False)
-
-if __name__ == "__main__":
- main()
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
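+"""Fine-tunes a BERT classifier on the IMDB dataset using Forte's
+IMDBClassifier; pickled record data is prepared on the first run."""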
+
+import os
+
+from forte.models.imdb_text_classifier.model import IMDBClassifier
+import config_data
+import config_classifier
+
+
+def main():
+ model = IMDBClassifier(config_data, config_classifier)
+ if not os.path.isfile("data/IMDB/train.pkl"):
+ model.prepare_data("data/IMDB")
+ model.run(do_train=True, do_eval=True, do_test=False)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/text_classification/utils/imdb_format.py b/examples/text_classification/utils/imdb_format.py
index 084866600..b2d818225 100644
--- a/examples/text_classification/utils/imdb_format.py
+++ b/examples/text_classification/utils/imdb_format.py
@@ -1,115 +1,116 @@
-# coding=utf-8
-# Copyright 2019 The Google UDA Team Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Read all data in IMDB and merge them to a csv file."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import csv
-import os
-from absl import app
-from absl import flags
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("raw_data_dir", "", "raw data dir")
-flags.DEFINE_string("output_dir", "", "output_dir")
-flags.DEFINE_string("train_id_path", "", "path of id list")
-
-
-def dump_raw_data(contents, file_path):
- with open(file_path, "w", encoding="utf-8") as ouf:
- writer = csv.writer(ouf, delimiter="\t", quotechar="\"")
- for line in contents:
- writer.writerow(line)
-
-def clean_web_text(st):
-    """clean text."""
-    st = st.replace("<br />", " ")
-    st = st.replace("&quot;", "\"")
-    st = st.replace("<p>", " ")
-    if "<a href=" in st:
-        start_pos = st.find("<a href=")
-        end_pos = st.find(">", start_pos)
-        if end_pos != -1:
-            st = st[:start_pos] + st[end_pos + 1:]
-        else:
-            print("incomplete href")
-            print("before", st)
-            st = st[:start_pos] + st[start_pos + len("<a href="):]
-        st = st.replace("</a>", "")
-    st = st.replace("\\n", " ")
-    # st = st.replace("\\", " ")
-    # while "  " in st:
-    #     st = st.replace("  ", " ")
-    return st
-
-
-def load_data_by_id(sub_set, id_path):
- with open(id_path, encoding="utf-8") as inf:
- id_list = inf.readlines()
- contents = []
- for example_id in id_list:
- example_id = example_id.strip()
- label = example_id.split("_")[0]
- file_path = os.path.join(FLAGS.raw_data_dir, sub_set, label, example_id[len(label) + 1:])
- with open(file_path, encoding="utf-8") as inf:
- st_list = inf.readlines()
- assert len(st_list) == 1
- st = clean_web_text(st_list[0].strip())
- contents += [(st, label, example_id)]
- return contents
-
-
-def load_all_data(sub_set):
- contents = []
- for label in ["pos", "neg", "unsup"]:
- data_path = os.path.join(FLAGS.raw_data_dir, sub_set, label)
- if not os.path.exists(data_path):
- continue
- for filename in os.listdir(data_path):
- file_path = os.path.join(data_path, filename)
- with open(file_path, encoding="utf-8") as inf:
- st_list = inf.readlines()
- assert len(st_list) == 1
- st = clean_web_text(st_list[0].strip())
- example_id = "{}_{}".format(label, filename)
- contents += [(st, label, example_id)]
- return contents
-
-
-def main(_):
- # load train
- header = ["content", "label", "id"]
- contents = load_data_by_id("train", FLAGS.train_id_path)
- if not os.path.exists(FLAGS.output_dir):
- os.mkdir(FLAGS.output_dir)
- dump_raw_data(
- [header] + contents,
- os.path.join(FLAGS.output_dir, "train.csv"),
- )
- # load test
- contents = load_all_data("test")
- dump_raw_data(
- [header] + contents,
- os.path.join(FLAGS.output_dir, "test.csv"),
- )
-
-
-if __name__ == "__main__":
- app.run(main)
+# coding=utf-8
+# Copyright 2019 The Google UDA Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Read all data in IMDB and merge them to a csv file."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import os
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("raw_data_dir", "", "raw data dir")
+flags.DEFINE_string("output_dir", "", "output_dir")
+flags.DEFINE_string("train_id_path", "", "path of id list")
+
+
+def dump_raw_data(contents, file_path):
+ with open(file_path, "w", encoding="utf-8") as ouf:
+ writer = csv.writer(ouf, delimiter="\t", quotechar="\"")
+ for line in contents:
+ writer.writerow(line)
+
+
+def clean_web_text(st):
+    """clean text."""
+    st = st.replace("<br />", " ")
+    st = st.replace("&quot;", "\"")
+    st = st.replace("<p>", " ")
+    if "<a href=" in st:
+        start_pos = st.find("<a href=")
+        end_pos = st.find(">", start_pos)
+        if end_pos != -1:
+            st = st[:start_pos] + st[end_pos + 1:]
+        else:
+            print("incomplete href")
+            print("before", st)
+            st = st[:start_pos] + st[start_pos + len("<a href="):]
+        st = st.replace("</a>", "")
+    st = st.replace("\\n", " ")
+    # st = st.replace("\\", " ")
+    # while "  " in st:
+    #     st = st.replace("  ", " ")
+    return st
+
+
+def load_data_by_id(sub_set, id_path):
+ with open(id_path, encoding="utf-8") as inf:
+ id_list = inf.readlines()
+ contents = []
+ for example_id in id_list:
+ example_id = example_id.strip()
+ label = example_id.split("_")[0]
+ file_path = os.path.join(FLAGS.raw_data_dir, sub_set, label, example_id[len(label) + 1:])
+ with open(file_path, encoding="utf-8") as inf:
+ st_list = inf.readlines()
+ assert len(st_list) == 1
+ st = clean_web_text(st_list[0].strip())
+ contents += [(st, label, example_id)]
+ return contents
+
+
+def load_all_data(sub_set):
+ contents = []
+ for label in ["pos", "neg", "unsup"]:
+ data_path = os.path.join(FLAGS.raw_data_dir, sub_set, label)
+ if not os.path.exists(data_path):
+ continue
+ for filename in os.listdir(data_path):
+ file_path = os.path.join(data_path, filename)
+ with open(file_path, encoding="utf-8") as inf:
+ st_list = inf.readlines()
+ assert len(st_list) == 1
+ st = clean_web_text(st_list[0].strip())
+ example_id = "{}_{}".format(label, filename)
+ contents += [(st, label, example_id)]
+ return contents
+
+
+def main(_):
+ # load train
+ header = ["content", "label", "id"]
+ contents = load_data_by_id("train", FLAGS.train_id_path)
+ if not os.path.exists(FLAGS.output_dir):
+ os.mkdir(FLAGS.output_dir)
+ dump_raw_data(
+ [header] + contents,
+ os.path.join(FLAGS.output_dir, "train.csv"),
+ )
+ # load test
+ contents = load_all_data("test")
+ dump_raw_data(
+ [header] + contents,
+ os.path.join(FLAGS.output_dir, "test.csv"),
+ )
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/forte/models/imdb_text_classifier/__init__.py b/forte/models/imdb_text_classifier/__init__.py
index 24cec1562..a5dd21c1f 100644
--- a/forte/models/imdb_text_classifier/__init__.py
+++ b/forte/models/imdb_text_classifier/__init__.py
@@ -1,13 +1,13 @@
-# Copyright 2020 The Forte Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/forte/models/imdb_text_classifier/config_classifier.py b/forte/models/imdb_text_classifier/config_classifier.py
index 85e64440b..3000603ec 100644
--- a/forte/models/imdb_text_classifier/config_classifier.py
+++ b/forte/models/imdb_text_classifier/config_classifier.py
@@ -1,11 +1,11 @@
-name = "bert_classifier"
-hidden_size = 768
-clas_strategy = "cls_time"
-dropout = 0.1
-num_classes = 2
-
-# This hyperparams is used in bert_with_hypertuning_main.py example
-hyperparams = {
- "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
- "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
-}
+name = "bert_classifier"
+hidden_size = 768
+clas_strategy = "cls_time"
+dropout = 0.1
+num_classes = 2
+
+# These hyperparams are used in the bert_with_hypertuning_main.py example.
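+# Each entry below defines the search range (`start` to `end`) and dtype of
+# one optimizer hyperparameter.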
+hyperparams = {
+ "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int},
+ "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float}
+}
diff --git a/forte/models/imdb_text_classifier/config_data.py b/forte/models/imdb_text_classifier/config_data.py
index d15379abc..493aea92b 100644
--- a/forte/models/imdb_text_classifier/config_data.py
+++ b/forte/models/imdb_text_classifier/config_data.py
@@ -1,68 +1,68 @@
-pickle_data_dir = "data/IMDB"
-max_seq_length = 64
-num_classes = 2
-num_train_data = 25000
-
-# used for bert executor example
-max_batch_tokens = 128
-
-train_batch_size = 32
-max_train_epoch = 5
-display_steps = 50 # Print training loss every display_steps; -1 to disable
-
-# tbx config
-tbx_logging_steps = 5 # log the metrics for tbX visualization
-tbx_log_dir = "runs/"
-exp_number = 1 # experiment number
-
-eval_steps = 100 # Eval on the dev set every eval_steps; -1 to disable
-# Proportion of training to perform linear learning rate warmup for.
-# E.g., 0.1 = 10% of training.
-warmup_proportion = 0.1
-eval_batch_size = 8
-test_batch_size = 8
-
-feature_types = {
- # Reading features from pickled data file.
- # E.g., Reading feature "input_ids" as dtype `int64`;
- # "FixedLenFeature" indicates its length is fixed for all data instances;
- # and the sequence length is limited by `max_seq_length`.
- "input_ids": ["int64", "stacked_tensor", max_seq_length],
- "input_mask": ["int64", "stacked_tensor", max_seq_length],
- "segment_ids": ["int64", "stacked_tensor", max_seq_length],
- "label_ids": ["int64", "stacked_tensor"]
-}
-
-train_hparam = {
- "allow_smaller_final_batch": False,
- "batch_size": train_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/train.pkl".format(pickle_data_dir)
- },
- "shuffle": True,
- "shuffle_buffer_size": None
-}
-
-eval_hparam = {
- "allow_smaller_final_batch": True,
- "batch_size": eval_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/eval.pkl".format(pickle_data_dir)
- },
- "shuffle": False
-}
-
-test_hparam = {
- "allow_smaller_final_batch": True,
- "batch_size": test_batch_size,
- "dataset": {
- "data_name": "data",
- "feature_types": feature_types,
- "files": "{}/predict.pkl".format(pickle_data_dir)
- },
- "shuffle": False
-}
+pickle_data_dir = "data/IMDB"
+max_seq_length = 64
+num_classes = 2
+num_train_data = 25000
+
+# used for bert executor example
+max_batch_tokens = 128
+
+train_batch_size = 32
+max_train_epoch = 5
+display_steps = 50 # Print training loss every display_steps; -1 to disable
+
+# tbx config
+tbx_logging_steps = 5 # log the metrics for tbX visualization
+tbx_log_dir = "runs/"
+exp_number = 1 # experiment number
+
+eval_steps = 100 # Eval on the dev set every eval_steps; -1 to disable
+# Proportion of training to perform linear learning rate warmup for.
+# E.g., 0.1 = 10% of training.
+warmup_proportion = 0.1
+eval_batch_size = 8
+test_batch_size = 8
+
+feature_types = {
+    # Reading features from the pickled data file.
+    # E.g., feature "input_ids" is read as dtype `int64`;
+    # "stacked_tensor" indicates its length is fixed for all data instances
+    # (and limited by `max_seq_length`), so instances can be stacked into a
+    # batch tensor.
+ "input_ids": ["int64", "stacked_tensor", max_seq_length],
+ "input_mask": ["int64", "stacked_tensor", max_seq_length],
+ "segment_ids": ["int64", "stacked_tensor", max_seq_length],
+ "label_ids": ["int64", "stacked_tensor"]
+}
+
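+# The per-split hparams below are consumed by `tx.data.RecordData` when the
+# datasets are built in model.py.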
+train_hparam = {
+ "allow_smaller_final_batch": False,
+ "batch_size": train_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/train.pkl".format(pickle_data_dir)
+ },
+ "shuffle": True,
+ "shuffle_buffer_size": None
+}
+
+eval_hparam = {
+ "allow_smaller_final_batch": True,
+ "batch_size": eval_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/eval.pkl".format(pickle_data_dir)
+ },
+ "shuffle": False
+}
+
+test_hparam = {
+ "allow_smaller_final_batch": True,
+ "batch_size": test_batch_size,
+ "dataset": {
+ "data_name": "data",
+ "feature_types": feature_types,
+ "files": "{}/predict.pkl".format(pickle_data_dir)
+ },
+ "shuffle": False
+}
diff --git a/forte/models/imdb_text_classifier/data/download_imdb.py b/forte/models/imdb_text_classifier/data/download_imdb.py
index 2e8d24521..faefbac4a 100644
--- a/forte/models/imdb_text_classifier/data/download_imdb.py
+++ b/forte/models/imdb_text_classifier/data/download_imdb.py
@@ -1,18 +1,34 @@
-import os
-import sys
-
-def main(arguments):
- import subprocess
- if not os.path.exists("data/IMDB_raw"):
- subprocess.run("mkdir data/IMDB_raw", shell=True)
- # pylint: disable=line-too-long
- subprocess.run(
- 'wget -P data/IMDB_raw/ https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
- shell=True)
- subprocess.run(
- 'tar xzvf data/IMDB_raw/aclImdb_v1.tar.gz -C data/IMDB_raw/ && data/IMDB_raw/rm aclImdb_v1.tar.gz',
- shell=True)
-
-
-if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import subprocess
+
+
+def main():
+ if not os.path.exists("data/IMDB_raw"):
+ subprocess.run("mkdir data/IMDB_raw", shell=True, check=True)
+ # pylint: disable=line-too-long
+ subprocess.run(
+ 'wget -P data/IMDB_raw/ https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
+ shell=True, check=True)
+ subprocess.run(
+ 'tar xzvf data/IMDB_raw/aclImdb_v1.tar.gz -C data/IMDB_raw/ && rm data/IMDB_raw/aclImdb_v1.tar.gz',
+ shell=True, check=True)
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/forte/models/imdb_text_classifier/model.py b/forte/models/imdb_text_classifier/model.py
index 4d386ba77..815570a04 100644
--- a/forte/models/imdb_text_classifier/model.py
+++ b/forte/models/imdb_text_classifier/model.py
@@ -1,249 +1,250 @@
-# Copyright 2020 The Forte Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import logging
-import os
-
-import torch
-import torch.nn.functional as F
-import texar.torch as tx
-
-# pylint: disable=no-name-in-module
-from forte.models.imdb_text_classifier.utils import data_utils, model_utils
-
-
-class IMDBClassifier:
- """
- A baseline text classifier for the IMDB dataset.
- The input data should be CSV format with columns (content label id).
- An example usage can be found at examples/text_classification.
- """
-
- def __init__(self, config_data, config_classifier, checkpoint=None, pretrained_model_name="bert-base-uncased"):
- """Constructs the text classifier.
- Args:
- config_data: string, data config file.
- """
- self.config_data = config_data
- self.config_classifier = config_classifier
- self.checkpoint = checkpoint
- self.pretrained_model_name = pretrained_model_name
-
- def prepare_data(self, csv_data_dir):
- """Prepares data.
- """
- logging.info("Loading data")
-
- if self.config_data.pickle_data_dir is None:
- output_dir = csv_data_dir
- else:
- output_dir = self.config_data.pickle_data_dir
- tx.utils.maybe_create_dir(output_dir)
-
- processor = data_utils.IMDbProcessor()
-
- num_classes = len(processor.get_labels())
- num_train_data = len(processor.get_train_examples(csv_data_dir))
- logging.info(
- 'num_classes:%d; num_train_data:%d' % (num_classes, num_train_data))
-
- tokenizer = tx.data.BERTTokenizer(
- pretrained_model_name=self.pretrained_model_name)
-
- data_utils.prepare_record_data(
- processor=processor,
- tokenizer=tokenizer,
- data_dir=csv_data_dir,
- max_seq_length=self.config_data.max_seq_length,
- output_dir=output_dir,
- feature_types=self.config_data.feature_types)
-
- def run(self, do_train, do_eval, do_test, output_dir="output/"):
- """
- Builds the model and runs.
- """
- tx.utils.maybe_create_dir(output_dir)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- logging.root.setLevel(logging.INFO)
-
- # Loads data
- num_train_data = self.config_data.num_train_data
-
- # config_downstream = importlib.import_module(args.config_downstream)
- hparams = {
- k: v for k, v in self.config_classifier.__dict__.items()
- if not k.startswith('__') and k != "hyperparams"}
-
- # Builds BERT
- model = tx.modules.BERTClassifier(
- pretrained_model_name=self.pretrained_model_name,
- hparams=hparams)
- model.to(device)
-
- num_train_steps = int(num_train_data / self.config_data.train_batch_size *
- self.config_data.max_train_epoch)
- num_warmup_steps = int(num_train_steps * self.config_data.warmup_proportion)
-
- # Builds learning rate decay scheduler
- static_lr = 2e-5
-
- vars_with_decay = []
- vars_without_decay = []
- for name, param in model.named_parameters():
- if 'layer_norm' in name or name.endswith('bias'):
- vars_without_decay.append(param)
- else:
- vars_with_decay.append(param)
-
- opt_params = [{
- 'params': vars_with_decay,
- 'weight_decay': 0.01,
- }, {
- 'params': vars_without_decay,
- 'weight_decay': 0.0,
- }]
- optim = tx.core.BertAdam(
- opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr)
-
- scheduler = torch.optim.lr_scheduler.LambdaLR(
- optim, functools.partial(model_utils.get_lr_multiplier,
- total_steps=num_train_steps,
- warmup_steps=num_warmup_steps))
-
- train_dataset = tx.data.RecordData(hparams=self.config_data.train_hparam,
- device=device)
- eval_dataset = tx.data.RecordData(hparams=self.config_data.eval_hparam,
- device=device)
- test_dataset = tx.data.RecordData(hparams=self.config_data.test_hparam,
- device=device)
-
- iterator = tx.data.DataIterator(
- {"train": train_dataset, "eval": eval_dataset, "test": test_dataset}
- )
-
- def _compute_loss(logits, labels):
- r"""Compute loss.
- """
- if model.is_binary:
- loss = F.binary_cross_entropy(
- logits.view(-1), labels.view(-1), reduction='mean')
- else:
- loss = F.cross_entropy(
- logits.view(-1, model.num_classes),
- labels.view(-1), reduction='mean')
- return loss
-
- def _train_epoch():
- r"""Trains on the training set, and evaluates on the dev set
- periodically.
- """
- iterator.switch_to_dataset("train")
- model.train()
-
- for batch in iterator:
- optim.zero_grad()
- input_ids = batch["input_ids"]
- segment_ids = batch["segment_ids"]
- labels = batch["label_ids"]
-
- input_length = (1 - (input_ids == 0).int()).sum(dim=1)
-
- logits, _ = model(input_ids, input_length, segment_ids)
-
- loss = _compute_loss(logits, labels)
- loss.backward()
- optim.step()
- scheduler.step()
- step = scheduler.last_epoch
-
- dis_steps = self.config_data.display_steps
- if dis_steps > 0 and step % dis_steps == 0:
- logging.info("step: %d; loss: %f", step, loss)
-
- eval_steps = self.config_data.eval_steps
- if eval_steps > 0 and step % eval_steps == 0:
- _eval_epoch()
- model.train()
-
- @torch.no_grad()
- def _eval_epoch():
- """Evaluates on the dev set.
- """
- iterator.switch_to_dataset("eval")
- model.eval()
-
- nsamples = 0
- avg_rec = tx.utils.AverageRecorder()
- for batch in iterator:
- input_ids = batch["input_ids"]
- segment_ids = batch["segment_ids"]
- labels = batch["label_ids"]
-
- input_length = (1 - (input_ids == 0).int()).sum(dim=1)
-
- logits, preds = model(input_ids, input_length, segment_ids)
-
- loss = _compute_loss(logits, labels)
- accu = tx.evals.accuracy(labels, preds)
- batch_size = input_ids.size()[0]
- avg_rec.add([accu, loss], batch_size)
- nsamples += batch_size
- logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d",
- avg_rec.avg(0), avg_rec.avg(1), nsamples)
-
- @torch.no_grad()
- def _test_epoch():
- """Does predictions on the test set.
- """
- iterator.switch_to_dataset("test")
- model.eval()
-
- _all_preds = []
- for batch in iterator:
- input_ids = batch["input_ids"]
- segment_ids = batch["segment_ids"]
-
- input_length = (1 - (input_ids == 0).int()).sum(dim=1)
-
- _, preds = model(input_ids, input_length, segment_ids)
-
- _all_preds.extend(preds.tolist())
-
- output_file = os.path.join(args.output_dir, "test_results.tsv")
- with open(output_file, "w+") as writer:
- writer.write("\n".join(str(p) for p in _all_preds))
- logging.info("test output written to %s", output_file)
-
- if self.checkpoint:
- ckpt = torch.load(self.checkpoint)
- model.load_state_dict(ckpt['model'])
- optim.load_state_dict(ckpt['optimizer'])
- scheduler.load_state_dict(ckpt['scheduler'])
- if do_train:
- for _ in range(self.config_data.max_train_epoch):
- _train_epoch()
- states = {
- 'model': model.state_dict(),
- 'optimizer': optim.state_dict(),
- 'scheduler': scheduler.state_dict(),
- }
- torch.save(states, os.path.join(output_dir, 'model.ckpt'))
-
- if do_eval:
- _eval_epoch()
-
- if do_test:
- _test_epoch()
+# Copyright 2020 The Forte Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import os
+
+import torch
+import torch.nn.functional as F
+import texar.torch as tx
+
+# pylint: disable=no-name-in-module
+from forte.models.imdb_text_classifier.utils import data_utils, model_utils
+
+
+class IMDBClassifier:
+ """
+ A baseline text classifier for the IMDB dataset.
+    The input data should be in CSV format with columns (content, label, id).
+    An example usage can be found in examples/text_classification.
+ """
+
+ def __init__(self, config_data, config_classifier, checkpoint=None,
+ pretrained_model_name="bert-base-uncased"):
+ """Constructs the text classifier.
+ Args:
+            config_data: module, the data configuration (see config_data.py).
+            config_classifier: module, the classifier configuration.
+            checkpoint: optional path to a saved checkpoint to restore.
+            pretrained_model_name: name of the pre-trained BERT model to use.
+ """
+ self.config_data = config_data
+ self.config_classifier = config_classifier
+ self.checkpoint = checkpoint
+ self.pretrained_model_name = pretrained_model_name
+
+ def prepare_data(self, csv_data_dir):
+ """Prepares data.
+ """
+ logging.info("Loading data")
+
+ if self.config_data.pickle_data_dir is None:
+ output_dir = csv_data_dir
+ else:
+ output_dir = self.config_data.pickle_data_dir
+ tx.utils.maybe_create_dir(output_dir)
+
+ processor = data_utils.IMDbProcessor()
+
+ num_classes = len(processor.get_labels())
+ num_train_data = len(processor.get_train_examples(csv_data_dir))
+ logging.info(
+ 'num_classes:%d; num_train_data:%d', num_classes, num_train_data)
+
+ tokenizer = tx.data.BERTTokenizer(
+ pretrained_model_name=self.pretrained_model_name)
+
+ data_utils.prepare_record_data(
+ processor=processor,
+ tokenizer=tokenizer,
+ data_dir=csv_data_dir,
+ max_seq_length=self.config_data.max_seq_length,
+ output_dir=output_dir,
+ feature_types=self.config_data.feature_types)
+
+ def run(self, do_train, do_eval, do_test, output_dir="output/"):
+ """
+ Builds the model and runs.
+ """
+ tx.utils.maybe_create_dir(output_dir)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ logging.root.setLevel(logging.INFO)
+
+ # Loads data
+ num_train_data = self.config_data.num_train_data
+
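+        # Expose the public attributes of the classifier config module as
+        # BERTClassifier hparams; the `hyperparams` search space used for
+        # hypertuning is excluded.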
+ hparams = {
+ k: v for k, v in self.config_classifier.__dict__.items()
+ if not k.startswith('__') and k != "hyperparams"}
+
+ # Builds BERT
+ model = tx.modules.BERTClassifier(
+ pretrained_model_name=self.pretrained_model_name,
+ hparams=hparams)
+ model.to(device)
+
+ num_train_steps = int(num_train_data / self.config_data.train_batch_size
+ * self.config_data.max_train_epoch)
+ num_warmup_steps = int(num_train_steps
+ * self.config_data.warmup_proportion)
+
+ # Builds learning rate decay scheduler
+ static_lr = 2e-5
+
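+        # Exclude LayerNorm parameters and biases from weight decay, as in
+        # the standard BERT fine-tuning setup.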
+ vars_with_decay = []
+ vars_without_decay = []
+ for name, param in model.named_parameters():
+ if 'layer_norm' in name or name.endswith('bias'):
+ vars_without_decay.append(param)
+ else:
+ vars_with_decay.append(param)
+
+ opt_params = [{
+ 'params': vars_with_decay,
+ 'weight_decay': 0.01,
+ }, {
+ 'params': vars_without_decay,
+ 'weight_decay': 0.0,
+ }]
+ optim = tx.core.BertAdam(
+ opt_params, betas=(0.9, 0.999), eps=1e-6, lr=static_lr)
+
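+        # LambdaLR scales `static_lr` by `get_lr_multiplier`, which derives a
+        # warm-up/decay factor from the current step, `total_steps`, and
+        # `warmup_steps`.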
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
+ optim, functools.partial(model_utils.get_lr_multiplier,
+ total_steps=num_train_steps,
+ warmup_steps=num_warmup_steps))
+
+ train_dataset = tx.data.RecordData(
+ hparams=self.config_data.train_hparam, device=device)
+ eval_dataset = tx.data.RecordData(
+ hparams=self.config_data.eval_hparam, device=device)
+ test_dataset = tx.data.RecordData(
+ hparams=self.config_data.test_hparam, device=device)
+
+ iterator = tx.data.DataIterator(
+ {"train": train_dataset, "eval": eval_dataset, "test": test_dataset}
+ )
+
+ def _compute_loss(logits, labels):
+ r"""Compute loss.
+ """
+ if model.is_binary:
+ loss = F.binary_cross_entropy(
+ logits.view(-1), labels.view(-1), reduction='mean')
+ else:
+ loss = F.cross_entropy(
+ logits.view(-1, model.num_classes),
+ labels.view(-1), reduction='mean')
+ return loss
+
+ def _train_epoch():
+ r"""Trains on the training set, and evaluates on the dev set
+ periodically.
+ """
+ iterator.switch_to_dataset("train")
+ model.train()
+
+ for batch in iterator:
+ optim.zero_grad()
+ input_ids = batch["input_ids"]
+ segment_ids = batch["segment_ids"]
+ labels = batch["label_ids"]
+
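+                # Non-padding tokens have non-zero ids, so counting them
+                # recovers each sequence's actual length.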
+ input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+ logits, _ = model(input_ids, input_length, segment_ids)
+
+ loss = _compute_loss(logits, labels)
+ loss.backward()
+ optim.step()
+ scheduler.step()
+ step = scheduler.last_epoch
+
+ dis_steps = self.config_data.display_steps
+ if dis_steps > 0 and step % dis_steps == 0:
+ logging.info("step: %d; loss: %f", step, loss)
+
+ eval_steps = self.config_data.eval_steps
+ if eval_steps > 0 and step % eval_steps == 0:
+ _eval_epoch()
+ model.train()
+
+ @torch.no_grad()
+ def _eval_epoch():
+ """Evaluates on the dev set.
+ """
+ iterator.switch_to_dataset("eval")
+ model.eval()
+
+ nsamples = 0
+ avg_rec = tx.utils.AverageRecorder()
+ for batch in iterator:
+ input_ids = batch["input_ids"]
+ segment_ids = batch["segment_ids"]
+ labels = batch["label_ids"]
+
+ input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+ logits, preds = model(input_ids, input_length, segment_ids)
+
+ loss = _compute_loss(logits, labels)
+ accu = tx.evals.accuracy(labels, preds)
+ batch_size = input_ids.size()[0]
+ avg_rec.add([accu, loss], batch_size)
+ nsamples += batch_size
+ logging.info("eval accu: %.4f; loss: %.4f; nsamples: %d",
+ avg_rec.avg(0), avg_rec.avg(1), nsamples)
+
+ @torch.no_grad()
+ def _test_epoch():
+ """Does predictions on the test set.
+ """
+ iterator.switch_to_dataset("test")
+ model.eval()
+
+ _all_preds = []
+ for batch in iterator:
+ input_ids = batch["input_ids"]
+ segment_ids = batch["segment_ids"]
+
+ input_length = (1 - (input_ids == 0).int()).sum(dim=1)
+
+ _, preds = model(input_ids, input_length, segment_ids)
+
+ _all_preds.extend(preds.tolist())
+
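+            # Write one predicted label id per line, in dataset order.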
+ output_file = os.path.join(output_dir, "test_results.tsv")
+ with open(output_file, "w+") as writer:
+ writer.write("\n".join(str(p) for p in _all_preds))
+ logging.info("test output written to %s", output_file)
+
+ if self.checkpoint:
+ ckpt = torch.load(self.checkpoint)
+ model.load_state_dict(ckpt['model'])
+ optim.load_state_dict(ckpt['optimizer'])
+ scheduler.load_state_dict(ckpt['scheduler'])
+ if do_train:
+ for _ in range(self.config_data.max_train_epoch):
+ _train_epoch()
+ states = {
+ 'model': model.state_dict(),
+ 'optimizer': optim.state_dict(),
+ 'scheduler': scheduler.state_dict(),
+ }
+ torch.save(states, os.path.join(output_dir, 'model.ckpt'))
+
+ if do_eval:
+ _eval_epoch()
+
+ if do_test:
+ _test_epoch()
diff --git a/forte/models/imdb_text_classifier/utils/data_utils.py b/forte/models/imdb_text_classifier/utils/data_utils.py
index ed0de4908..8816f7d5d 100644
--- a/forte/models/imdb_text_classifier/utils/data_utils.py
+++ b/forte/models/imdb_text_classifier/utils/data_utils.py
@@ -1,495 +1,487 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This is the Data Loading Pipeline for Sentence Classifier Task from:
- `https://github.com/google-research/bert/blob/master/run_classifier.py`
-"""
-
-import os
-import csv
-import collections
-import logging
-
-import tensorflow as tf
-
-# import texar.tf as tx
-import texar.torch as tx
-
-
-class InputExample():
- """A single training/test example for simple sequence classification."""
-
- def __init__(self, guid, text_a, text_b=None, label=None):
- """Constructs a InputExample.
- Args:
- guid: Unique id for the example.
- text_a: string. The untokenized text of the first sequence.
- For single sequence tasks, only this sequence must be specified.
- text_b: (Optional) string. The untokenized text of the second
- sequence. Only must be specified for sequence pair tasks.
- label: (Optional) string. The label of the example. This should be
- specified for train and dev examples, but not for test examples.
- """
- self.guid = guid
- self.text_a = text_a
- self.text_b = text_b
- self.label = label
-
-
-class InputFeatures:
- """A single set of features of data."""
-
- def __init__(self, input_ids, input_mask, segment_ids, label_id):
- self.input_ids = input_ids
- self.input_mask = input_mask
- self.segment_ids = segment_ids
- self.label_id = label_id
-
-
-class DataProcessor(object):
- """Base class for data converters for sequence classification data sets."""
-
- def get_train_examples(self, data_dir):
- """Gets a collection of `InputExample`s for the train set."""
- raise NotImplementedError()
-
- def get_dev_examples(self, data_dir):
- """Gets a collection of `InputExample`s for the dev set."""
- raise NotImplementedError()
-
- def get_test_examples(self, data_dir):
- """Gets a collection of `InputExample`s for prediction."""
- raise NotImplementedError()
-
- def get_labels(self):
- """Gets the list of labels for this data set."""
- raise NotImplementedError()
-
- @classmethod
- def _read_tsv(cls, input_file, quotechar=None):
- """Reads a tab separated value file."""
- with tf.gfile.Open(input_file, "r") as f:
- reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
- lines = []
- for line in reader:
- lines.append(line)
- return lines
-
-
-def clean_web_text(st):
-    """clean text."""
-    st = st.replace("<br />", " ")
-    if "<a href=" in st:
-        start_pos = st.find("<a href=")
-        end_pos = st.find(">", start_pos)
-        if end_pos != -1:
-            st = st[:start_pos] + st[end_pos + 1:]
-        else:
-            print("incomplete href")
-            print("before", st)
-            st = st[:start_pos] + st[start_pos + len("<a href="):]
-        st = st.replace("</a>", "")
-    # print("after\n", st)
-    # print("")
-    st = st.replace("\\n", " ")
-    st = st.replace("\\", " ")
-    # while "  " in st:
-    #     st = st.replace("  ", " ")
-    return st
-
-
-class IMDbProcessor(DataProcessor):
- """Processor for the CoLA data set (GLUE version)."""
-
- def get_train_examples(self, raw_data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(raw_data_dir, "train.csv"),
- quotechar='"'), "train")
-
- def get_dev_examples(self, raw_data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(raw_data_dir, "test.csv"), # temporary workaround
- quotechar='"'), "test")
-
- def get_test_examples(self, raw_data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(raw_data_dir, "test.csv"),
- quotechar='"'), "test")
-
- def get_unsup_examples(self, raw_data_dir, unsup_set):
- """See base class."""
- if unsup_set == "unsup_ext":
- return self._create_examples(
- self._read_tsv(os.path.join(raw_data_dir, "unsup_ext.csv"),
- quotechar='"'), "unsup_ext", skip_unsup=False)
- elif unsup_set == "unsup_in":
- return self._create_examples(
- self._read_tsv(os.path.join(raw_data_dir, "train.csv"),
- quotechar='"'), "unsup_in", skip_unsup=False)
-
- def get_labels(self):
- """See base class."""
- return ["pos", "neg"]
-
- def _create_examples(self, lines, set_type, skip_unsup=True):
- """Creates examples for the training and dev sets."""
- examples = []
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- if skip_unsup and line[1] == "unsup":
- continue
- if line[1] == "unsup" and len(line[0]) < 500:
- # tf.logging.info("skipping short samples:{:s}".format(line[0]))
- continue
- guid = "%s-%s" % (set_type, line[2])
- text_a = line[0]
- label = tx.utils.compat_as_text(line[1])
- text_a = tx.utils.compat_as_text(clean_web_text(text_a))
- examples.append(
- InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
- return examples
-
- def get_train_size(self):
- return 25000
-
- def get_dev_size(self):
- return 25000
-
-
-class SSTProcessor(DataProcessor):
- """Processor for the MRPC data set (GLUE version)."""
-
- def get_train_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
- def get_dev_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
- def get_test_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
-
- def get_labels(self):
- """See base class."""
- return ["0", "1"]
-
- @staticmethod
- def _create_examples(lines, set_type):
- """Creates examples for the training and dev sets."""
- examples = []
- if set_type == 'train' or set_type == 'dev':
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "%s-%s" % (set_type, i)
- text_a = tx.utils.compat_as_text(line[0])
- # Single sentence classification, text_b doesn't exist
- text_b = None
- label = tx.utils.compat_as_text(line[1])
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- if set_type == 'test':
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "%s-%s" % (set_type, i)
- text_a = tx.utils.compat_as_text(line[1])
- # Single sentence classification, text_b doesn't exist
- text_b = None
- label = '0' # arbitrary set as 0
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- return examples
-
-
-class XnliProcessor(DataProcessor):
- """Processor for the XNLI data set."""
-
- def __init__(self):
- self.language = "zh"
-
- def get_train_examples(self, data_dir):
- """See base class."""
- lines = self._read_tsv(
- os.path.join(data_dir, "multinli",
- "multinli.train.%s.tsv" % self.language))
- examples = []
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "train-%d" % (i)
- text_a = tx.utils.compat_as_text(line[0])
- text_b = tx.utils.compat_as_text(line[1])
- label = tx.utils.compat_as_text(line[2])
- if label == tx.utils.compat_as_text("contradictory"):
- label = tx.utils.compat_as_text("contradiction")
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- return examples
-
- def get_dev_examples(self, data_dir):
- """See base class."""
- lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
- examples = []
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "dev-%d" % (i)
- language = tx.utils.compat_as_text(line[0])
- if language != tx.utils.compat_as_text(self.language):
- continue
- text_a = tx.utils.compat_as_text(line[6])
- text_b = tx.utils.compat_as_text(line[7])
- label = tx.utils.compat_as_text(line[1])
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- return examples
-
- def get_labels(self):
- """See base class."""
- return ["contradiction", "entailment", "neutral"]
-
-
-class MnliProcessor(DataProcessor):
- """Processor for the MultiNLI data set (GLUE version)."""
-
- def get_train_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
- def get_dev_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
- "dev_matched")
-
- def get_test_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
- "test")
-
- def get_labels(self):
- """See base class."""
- return ["contradiction", "entailment", "neutral"]
-
- @staticmethod
- def _create_examples(lines, set_type):
- """Creates examples for the training and dev sets."""
- examples = []
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "%s-%s" % (set_type,
- tx.utils.compat_as_text(line[0]))
- text_a = tx.utils.compat_as_text(line[8])
- text_b = tx.utils.compat_as_text(line[9])
- if set_type == "test":
- label = "contradiction"
- else:
- label = tx.utils.compat_as_text(line[-1])
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- return examples
-
-
-class MrpcProcessor(DataProcessor):
- """Processor for the MRPC data set (GLUE version)."""
-
- def get_train_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "train.tsv")),
- "train")
-
- def get_dev_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "dev.tsv")),
- "dev")
-
- def get_test_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "test.tsv")),
- "test")
-
- def get_labels(self):
- """See base class."""
- return ["0", "1"]
-
- @staticmethod
- def _create_examples(lines, set_type):
- """Creates examples for the training and dev sets."""
- examples = []
- for (i, line) in enumerate(lines):
- if i == 0:
- continue
- guid = "%s-%s" % (set_type, i)
- text_a = tx.utils.compat_as_text(line[3])
- text_b = tx.utils.compat_as_text(line[4])
- if set_type == "test":
- label = "0"
- else:
- label = tx.utils.compat_as_text(line[0])
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=text_b, label=label))
- return examples
-
-
-class ColaProcessor(DataProcessor):
- """Processor for the CoLA data set (GLUE version)."""
-
- def get_train_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "train.tsv")),
- "train")
-
- def get_dev_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "dev.tsv")),
- "dev")
-
- def get_test_examples(self, data_dir):
- """See base class."""
- return self._create_examples(
- self._read_tsv(os.path.join(data_dir, "test.tsv")),
- "test")
-
- def get_labels(self):
- """See base class."""
- return ["0", "1"]
-
- @staticmethod
- def _create_examples(lines, set_type):
- """Creates examples for the training and dev sets."""
- examples = []
- for (i, line) in enumerate(lines):
- # Only the test set has a header
- if set_type == "test" and i == 0:
- continue
- guid = "%s-%s" % (set_type, i)
- if set_type == "test":
- text_a = tx.utils.compat_as_text(line[1])
- label = "0"
- else:
- text_a = tx.utils.compat_as_text(line[3])
- label = tx.utils.compat_as_text(line[1])
- examples.append(InputExample(guid=guid, text_a=text_a,
- text_b=None, label=label))
- return examples
-
-
-def convert_single_example(ex_index, example, label_list, max_seq_length,
- tokenizer):
- r"""Converts a single `InputExample` into a single `InputFeatures`."""
- label_map = {}
- for (i, label) in enumerate(label_list):
- label_map[label] = i
-
- input_ids, segment_ids, input_mask = \
- tokenizer.encode_text(text_a=example.text_a,
- text_b=example.text_b,
- max_seq_length=max_seq_length)
-
- label_id = label_map[example.label]
-
- # here we disable the verbose printing of the data
- if ex_index < 0:
- logging.info("*** Example ***")
- logging.info("guid: %s", example.guid)
- logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
- logging.info("input_ids length: %d", len(input_ids))
- logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
- logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
- logging.info("label: %s (id = %d)", example.label, label_id)
-
- feature = InputFeatures(input_ids=input_ids,
- input_mask=input_mask,
- segment_ids=segment_ids,
- label_id=label_id)
- return feature
-
-
-def convert_examples_to_features_and_output_to_files(
- examples, label_list, max_seq_length, tokenizer, output_file,
- feature_types):
- r"""Convert a set of `InputExample`s to a pickled file."""
-
- with tx.data.RecordData.writer(output_file, feature_types) as writer:
- for (ex_index, example) in enumerate(examples):
- feature = convert_single_example(ex_index, example, label_list,
- max_seq_length, tokenizer)
-
- features = {
- "input_ids": feature.input_ids,
- "input_mask": feature.input_mask,
- "segment_ids": feature.segment_ids,
- "label_ids": feature.label_id
- }
- writer.write(features)
-
-
-def prepare_record_data(processor, tokenizer,
- data_dir, max_seq_length, output_dir,
- feature_types):
- r"""Prepare record data.
- Args:
- processor: Data Preprocessor, which must have get_labels,
- get_train/dev/test/examples methods defined.
- tokenizer: The Sentence Tokenizer. Generally should be
- SentencePiece Model.
- data_dir: The input data directory.
- max_seq_length: Max sequence length.
- output_dir: The directory to save the pickled file in.
- feature_types: The original type of the feature.
- """
- label_list = processor.get_labels()
-
- train_examples = processor.get_train_examples(data_dir)
- train_file = os.path.join(output_dir, "train.pkl")
- convert_examples_to_features_and_output_to_files(
- train_examples, label_list, max_seq_length,
- tokenizer, train_file, feature_types)
-
- eval_examples = processor.get_dev_examples(data_dir)
- eval_file = os.path.join(output_dir, "eval.pkl")
- convert_examples_to_features_and_output_to_files(
- eval_examples, label_list,
- max_seq_length, tokenizer, eval_file, feature_types)
-
- test_examples = processor.get_test_examples(data_dir)
- test_file = os.path.join(output_dir, "predict.pkl")
- convert_examples_to_features_and_output_to_files(
- test_examples, label_list,
- max_seq_length, tokenizer, test_file, feature_types)
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This is the Data Loading Pipeline for Sentence Classifier Task from:
+ `https://github.com/google-research/bert/blob/master/run_classifier.py`
+"""
+
+import os
+import csv
+import logging
+
+import tensorflow as tf
+
+import texar.torch as tx
+
+
+class InputExample():
+ """A single training/test example for simple sequence classification."""
+
+ def __init__(self, guid, text_a, text_b=None, label=None):
+ """Constructs a InputExample.
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence.
+ For single sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second
+            sequence. Must only be specified for sequence pair tasks.
+ label: (Optional) string. The label of the example. This should be
+ specified for train and dev examples, but not for test examples.
+ """
+ self.guid = guid
+ self.text_a = text_a
+ self.text_b = text_b
+ self.label = label
+
+
+class InputFeatures:
+ """A single set of features of data."""
+
+ def __init__(self, input_ids, input_mask, segment_ids, label_id):
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.label_id = label_id
+
+
+class DataProcessor():
+ """Base class for data converters for sequence classification data sets."""
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for prediction."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with tf.gfile.Open(input_file, "r") as f:
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+ lines = []
+ for line in reader:
+ lines.append(line)
+ return lines
+
+
+def clean_web_text(st):
+    """clean text."""
+    st = st.replace("<br />", " ")
+    if "<a href=" in st:
+        start_pos = st.find("<a href=")
+        end_pos = st.find(">", start_pos)
+        if end_pos != -1:
+            st = st[:start_pos] + st[end_pos + 1:]
+        else:
+            print("incomplete href")
+            print("before", st)
+            st = st[:start_pos] + st[start_pos + len("<a href="):]
+        st = st.replace("</a>", "")
+    # print("after\n", st)
+    # print("")
+    st = st.replace("\\n", " ")
+    st = st.replace("\\", " ")
+    # while "  " in st:
+    #     st = st.replace("  ", " ")
+    return st
+
+
+class IMDbProcessor(DataProcessor):
+ """Processor for the CoLA data set (GLUE version)."""
+
+ def get_train_examples(self, raw_data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(raw_data_dir, "train.csv"),
+ quotechar='"'), "train")
+
+ def get_dev_examples(self, raw_data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(raw_data_dir, "test.csv"),
+ quotechar='"'), "test")
+
+ def get_unsup_examples(self, raw_data_dir, unsup_set):
+ """See base class."""
+ if unsup_set == "unsup_ext":
+ return self._create_examples(
+ self._read_tsv(os.path.join(raw_data_dir, "unsup_ext.csv"),
+ quotechar='"'), "unsup_ext", skip_unsup=False)
+ elif unsup_set == "unsup_in":
+ return self._create_examples(
+ self._read_tsv(os.path.join(raw_data_dir, "train.csv"),
+ quotechar='"'), "unsup_in", skip_unsup=False)
+
+ def get_labels(self):
+ """See base class."""
+ return ["pos", "neg"]
+
+ def _create_examples(self, lines, set_type, skip_unsup=True):
+ """Creates examples for the training and dev sets."""
+ examples = []
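+        # Expected CSV layout (inferred from the indexing below): column 0
+        # holds the review text, column 1 the label ("pos"/"neg"/"unsup"),
+        # and column 2 an example id; row 0 is a header.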
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ if skip_unsup and line[1] == "unsup":
+ continue
+ if line[1] == "unsup" and len(line[0]) < 500:
+ # tf.logging.info("skipping short samples:{:s}".format(line[0]))
+ continue
+ guid = "%s-%s" % (set_type, line[2])
+ text_a = line[0]
+ label = line[1]
+ text_a = clean_web_text(text_a)
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=None, label=label))
+ return examples
+
+ def get_train_size(self):
+ return 25000
+
+ def get_dev_size(self):
+ return 25000
+
+
+class SSTProcessor(DataProcessor):
+ """Processor for the MRPC data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def _create_examples(lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ if set_type in ('train', 'dev'):
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ text_a = tx.utils.compat_as_text(line[0])
+ # Single sentence classification, text_b doesn't exist
+ text_b = None
+ label = tx.utils.compat_as_text(line[1])
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ if set_type == 'test':
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ text_a = tx.utils.compat_as_text(line[1])
+ # Single sentence classification, text_b doesn't exist
+ text_b = None
+                label = '0'  # arbitrarily set to '0'; test examples are unlabeled
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ return examples
+
+
+class XnliProcessor(DataProcessor):
+ """Processor for the XNLI data set."""
+
+ def __init__(self):
+ self.language = "zh"
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(
+ os.path.join(data_dir, "multinli",
+ "multinli.train.%s.tsv" % self.language))
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "train-%d" % (i)
+ text_a = tx.utils.compat_as_text(line[0])
+ text_b = tx.utils.compat_as_text(line[1])
+ label = tx.utils.compat_as_text(line[2])
+ if label == tx.utils.compat_as_text("contradictory"):
+ label = tx.utils.compat_as_text("contradiction")
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "dev-%d" % (i)
+ language = tx.utils.compat_as_text(line[0])
+ if language != tx.utils.compat_as_text(self.language):
+ continue
+ text_a = tx.utils.compat_as_text(line[6])
+ text_b = tx.utils.compat_as_text(line[7])
+ label = tx.utils.compat_as_text(line[1])
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ return examples
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+
+class MnliProcessor(DataProcessor):
+ """Processor for the MultiNLI data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+ "dev_matched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
+ "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def _create_examples(lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type,
+ tx.utils.compat_as_text(line[0]))
+ text_a = tx.utils.compat_as_text(line[8])
+ text_b = tx.utils.compat_as_text(line[9])
+ if set_type == "test":
+ label = "contradiction"
+ else:
+ label = tx.utils.compat_as_text(line[-1])
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ return examples
+
+
+class MrpcProcessor(DataProcessor):
+ """Processor for the MRPC data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")),
+ "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")),
+ "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")),
+ "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def _create_examples(lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ text_a = tx.utils.compat_as_text(line[3])
+ text_b = tx.utils.compat_as_text(line[4])
+ if set_type == "test":
+ label = "0"
+ else:
+ label = tx.utils.compat_as_text(line[0])
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=text_b, label=label))
+ return examples
+
+
+class ColaProcessor(DataProcessor):
+ """Processor for the CoLA data set (GLUE version)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "train.tsv")),
+ "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")),
+ "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "test.tsv")),
+ "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def _create_examples(lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ # Only the test set has a header
+ if set_type == "test" and i == 0:
+ continue
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = tx.utils.compat_as_text(line[1])
+ label = "0"
+ else:
+ text_a = tx.utils.compat_as_text(line[3])
+ label = tx.utils.compat_as_text(line[1])
+ examples.append(InputExample(guid=guid, text_a=text_a,
+ text_b=None, label=label))
+ return examples
+
+
+def convert_single_example(ex_index, example, label_list, max_seq_length,
+ tokenizer):
+ r"""Converts a single `InputExample` into a single `InputFeatures`."""
+ label_map = {}
+ for (i, label) in enumerate(label_list):
+ label_map[label] = i
+
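+    # `encode_text` pads/truncates to `max_seq_length` and returns three
+    # lists: token ids, segment ids (0 for text_a, 1 for text_b), and the
+    # input mask (1 for real tokens, 0 for padding).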
+ input_ids, segment_ids, input_mask = \
+ tokenizer.encode_text(text_a=example.text_a,
+ text_b=example.text_b,
+ max_seq_length=max_seq_length)
+
+ label_id = label_map[example.label]
+
+    # Verbose example logging is disabled; change `ex_index < 0` to, e.g.,
+    # `ex_index < 5` below to log the first few converted examples.
+ if ex_index < 0:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", example.guid)
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_ids length: %d", len(input_ids))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %s (id = %d)", example.label, label_id)
+
+ feature = InputFeatures(input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id)
+ return feature
+
+
+def convert_examples_to_features_and_output_to_files(
+ examples, label_list, max_seq_length, tokenizer, output_file,
+ feature_types):
+ r"""Convert a set of `InputExample`s to a pickled file."""
+
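+    # `feature_types` (e.g. the dict defined in config_data.py) tells the
+    # RecordData writer how to serialize each field of every example.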
+ with tx.data.RecordData.writer(output_file, feature_types) as writer:
+ for (ex_index, example) in enumerate(examples):
+ feature = convert_single_example(ex_index, example, label_list,
+ max_seq_length, tokenizer)
+
+ features = {
+ "input_ids": feature.input_ids,
+ "input_mask": feature.input_mask,
+ "segment_ids": feature.segment_ids,
+ "label_ids": feature.label_id
+ }
+ writer.write(features)
+
+
+def prepare_record_data(processor, tokenizer,
+ data_dir, max_seq_length, output_dir,
+ feature_types):
+ r"""Prepare record data.
+ Args:
+        processor: Data preprocessor, which must define the `get_labels` and
+            `get_train/dev/test_examples` methods.
+ tokenizer: The Sentence Tokenizer. Generally should be
+ SentencePiece Model.
+ data_dir: The input data directory.
+ max_seq_length: Max sequence length.
+ output_dir: The directory to save the pickled file in.
+ feature_types: The original type of the feature.
+ """
+ label_list = processor.get_labels()
+
+ train_examples = processor.get_train_examples(data_dir)
+ train_file = os.path.join(output_dir, "train.pkl")
+ convert_examples_to_features_and_output_to_files(
+ train_examples, label_list, max_seq_length,
+ tokenizer, train_file, feature_types)
+
+ eval_examples = processor.get_dev_examples(data_dir)
+ eval_file = os.path.join(output_dir, "eval.pkl")
+ convert_examples_to_features_and_output_to_files(
+ eval_examples, label_list,
+ max_seq_length, tokenizer, eval_file, feature_types)
+
+ test_examples = processor.get_test_examples(data_dir)
+ test_file = os.path.join(output_dir, "predict.pkl")
+ convert_examples_to_features_and_output_to_files(
+ test_examples, label_list,
+ max_seq_length, tokenizer, test_file, feature_types)
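+
+
+# A minimal usage sketch (illustrative only: the tokenizer name and paths are
+# assumptions, and `feature_types` refers to the dict in config_data.py):
+#
+#     tokenizer = tx.data.BERTTokenizer(
+#         pretrained_model_name="bert-base-uncased")
+#     prepare_record_data(SSTProcessor(), tokenizer,
+#                         data_dir="data/SST-2", max_seq_length=64,
+#                         output_dir="data/SST-2",
+#                         feature_types=feature_types)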
diff --git a/forte/models/imdb_text_classifier/utils/model_utils.py b/forte/models/imdb_text_classifier/utils/model_utils.py
index 747c0dfa0..2e53492d8 100644
--- a/forte/models/imdb_text_classifier/utils/model_utils.py
+++ b/forte/models/imdb_text_classifier/utils/model_utils.py
@@ -1,19 +1,19 @@
-"""
-Model utility functions
-"""
-
-
-def get_lr_multiplier(step: int, total_steps: int, warmup_steps: int) -> float:
- r"""Calculate the learning rate multiplier given current step and the number
- of warm-up steps. The learning rate schedule follows a linear warm-up and
- linear decay.
- """
- step = min(step, total_steps)
-
- multiplier = (1 - (step - warmup_steps) / (total_steps - warmup_steps))
-
- if warmup_steps > 0 and step < warmup_steps:
- warmup_percent_done = step / warmup_steps
- multiplier = warmup_percent_done
-
- return multiplier
+"""
+Model utility functions
+"""
+
+
+def get_lr_multiplier(step: int, total_steps: int, warmup_steps: int) -> float:
+ r"""Calculate the learning rate multiplier given current step and the number
+ of warm-up steps. The learning rate schedule follows a linear warm-up and
+ linear decay.
+ """
+ step = min(step, total_steps)
+
+ multiplier = (1 - (step - warmup_steps) / (total_steps - warmup_steps))
+
+ if warmup_steps > 0 and step < warmup_steps:
+ warmup_percent_done = step / warmup_steps
+ multiplier = warmup_percent_done
+
+ return multiplier
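+
+
+# Example: with total_steps=1000 and warmup_steps=100, the multiplier rises
+# linearly from 0 to 1 over the first 100 steps, then decays linearly back to
+# 0 at step 1000 (e.g. get_lr_multiplier(550, 1000, 100) == 0.5).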
", " ")
+ st = st.replace(""", "\"")
+ st = st.replace("
", " ")
- st = st.replace(""", "\"")
- st = st.replace("
", " ")
+ st = st.replace(""", "\"")
+ st = st.replace("