From 40b5eb10c311a79b1590530dce5630eddcc792c9 Mon Sep 17 00:00:00 2001 From: Bartosz Kuncer Date: Thu, 7 Apr 2022 19:07:40 +0200 Subject: [PATCH 1/3] Add assert for doc_stride, max_seq_length and max_query_length --- scripts/question_answering/run_squad.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py index 521ee15a47..09b68e5334 100644 --- a/scripts/question_answering/run_squad.py +++ b/scripts/question_answering/run_squad.py @@ -41,6 +41,7 @@ CACHE_PATH = os.path.realpath(os.path.join(os.path.realpath(__file__), '..', 'cached')) if not os.path.exists(CACHE_PATH): os.makedirs(CACHE_PATH, exist_ok=True) +SEPARATORS = 3 def parse_args(): @@ -151,6 +152,7 @@ def parse_args(): 'use --dtype float16, amp will be turned on in the training phase and ' 'fp16 will be used in evaluation.') args = parser.parse_args() + assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, "possible loss of data when chunking" return args @@ -256,7 +258,7 @@ def process_sample(self, feature: SquadFeature): truncated_query_ids = feature.query_token_ids[:self._max_query_length] chunks = feature.get_chunks( doc_stride=self._doc_stride, - max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3) + max_chunk_length=self._max_seq_length - len(truncated_query_ids) - SEPARATORS) for chunk in chunks: data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] + feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] + From 1836ac2dd3e1449dd8088301f3c84e69143eb04c Mon Sep 17 00:00:00 2001 From: Bartosz Kuncer Date: Thu, 7 Apr 2022 21:44:22 +0200 Subject: [PATCH 2/3] Uniform assert message --- scripts/question_answering/run_squad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py index 09b68e5334..3d207eb9de 100644 --- a/scripts/question_answering/run_squad.py +++ b/scripts/question_answering/run_squad.py @@ -152,7 +152,7 @@ def parse_args(): 'use --dtype float16, amp will be turned on in the training phase and ' 'fp16 will be used in evaluation.') args = parser.parse_args() - assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, "possible loss of data when chunking" + assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, 'Possible loss of data while chunking input features' return args From da8d8d4126b1b7480597154d8745c37905377c98 Mon Sep 17 00:00:00 2001 From: bartekkuncer Date: Tue, 12 Jul 2022 16:57:07 +0200 Subject: [PATCH 3/3] Update scripts/question_answering/run_squad.py Co-authored-by: bgawrych --- scripts/question_answering/run_squad.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py index 3d207eb9de..1daf4689e4 100644 --- a/scripts/question_answering/run_squad.py +++ b/scripts/question_answering/run_squad.py @@ -152,7 +152,10 @@ def parse_args(): 'use --dtype float16, amp will be turned on in the training phase and ' 'fp16 will be used in evaluation.') args = parser.parse_args() - assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, 'Possible loss of data while chunking input features' + + assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, \ + 'Possible loss of data while chunking input features' + return args