From 94d841773f8b721aef274cbaddf66e4d01e8a9a2 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sat, 20 May 2023 15:15:32 +0100 Subject: [PATCH] Resolve duplication issue with filtered prosocial --- .../custom_datasets/toxic_conversation.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/model/model_training/custom_datasets/toxic_conversation.py b/model/model_training/custom_datasets/toxic_conversation.py index a34fd94d60..61ddad9233 100644 --- a/model/model_training/custom_datasets/toxic_conversation.py +++ b/model/model_training/custom_datasets/toxic_conversation.py @@ -20,7 +20,12 @@ class ProsocialDialogueExplaination(Dataset): def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() - dataset = load_dataset("Englishman2022/prosocial-dialog-filtered", cache_dir=cache_dir)[split] + dataset = load_dataset( + "Englishman2022/prosocial-dialog-filtered", + data_files="train.json", + cache_dir=cache_dir, + revision="e121e4fd886fadc030d633274c053b71839f9c20", + )[split] self.pairs = [] for row in dataset: for safety_annotation, safe_answer in zip(row["safety_annotations"], row["safety_annotation_reasons"]): @@ -54,7 +59,12 @@ class ProsocialDialogue(Dataset): def __init__(self, split="train", cache_dir=".cache") -> None: super().__init__() - dataset = load_dataset("Englishman2022/prosocial-dialog-filtered", cache_dir=cache_dir)[split] + dataset = load_dataset( + "Englishman2022/prosocial-dialog-filtered", + data_files="train.json", + cache_dir=cache_dir, + revision="e121e4fd886fadc030d633274c053b71839f9c20", + )[split] self.pairs = [] for row in dataset: prompt = row["context"]