diff --git a/c4200m_get_target_sentences.py b/c4200m_get_target_sentences.py index 7856e75..f893832 100644 --- a/c4200m_get_target_sentences.py +++ b/c4200m_get_target_sentences.py @@ -8,27 +8,36 @@ LOGGING_STEPS = 100000 + def main(argv): if len(argv) != 3 and len(argv) != 4: raise app.UsageError( - "python3 c4200m_get_target_sentences.py []") + "python3 c4200m_get_target_sentences.py " + " []" + ) edits_tsv_path = argv[1] output_tsv_path = argv[2] - tfds_name = "c4/en:2.2.1" if len(argv) == 4 and argv[3] != "en": - tfds_name = "c4/multilingual/" + argv[3] + tfds_name = "c4/multilingual:3.1.0" + split = argv[3] + else: + tfds_name = "c4/en:2.2.1" + split = "train" print("Loading C4_200M target sentence hashes from %r..." % edits_tsv_path) remaining_hashes = set() with open(edits_tsv_path) as edits_tsv_reader: for tsv_line in edits_tsv_reader: remaining_hashes.add(tsv_line.split("\t", 1)[0]) - print("Searching for %d target sentences in the dataset %r..." % - (len(remaining_hashes), tfds_name)) + print( + "Searching for %d target sentences in the dataset %r split %r..." + % (len(remaining_hashes), tfds_name, split) + ) target_sentences = [] for num_done_examples, example in enumerate( - tfds.load("c4/en:2.2.1", split="train")): + tfds.load(tfds_name, split=split) + ): for line in example["text"].numpy().decode("utf-8").split("\n"): line_md5 = hashlib.md5(line.encode("utf-8")).hexdigest() if line_md5 in remaining_hashes: @@ -37,10 +46,14 @@ def main(argv): if not remaining_hashes: break if num_done_examples % LOGGING_STEPS == 0: - print("-- %d C4 examples done, %d sentences still to be found" % - (num_done_examples, len(remaining_hashes))) - print("Found %d target sentences (%d not found)." % - (len(target_sentences), len(remaining_hashes))) + print( + "-- %d C4 examples done, %d sentences still to be found" + % (num_done_examples, len(remaining_hashes)) + ) + print( + "Found %d target sentences (%d not found)." + % (len(target_sentences), len(remaining_hashes)) + ) print("Writing C4_200M sentence pairs to %r..." % output_tsv_path) with open(output_tsv_path, "w") as output_tsv_writer: while target_sentences: @@ -49,3 +62,4 @@ def main(argv): if __name__ == "__main__": app.run(main) +