Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Add assert for doc_stride, max_seq_length and max_query_length #1587

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion scripts/question_answering/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
# Directory for cached preprocessed features, created next to this script.
CACHE_PATH = os.path.realpath(os.path.join(os.path.realpath(__file__), '..', 'cached'))
if not os.path.exists(CACHE_PATH):
    os.makedirs(CACHE_PATH, exist_ok=True)
# Number of special tokens consumed per chunked input sequence.
# presumably [CLS] + 2x [SEP], matching the `[self.cls_id] + query + [self.sep_id] + context`
# assembly in process_sample — TODO confirm against the full tokenization code.
SEPARATORS = 3


def parse_args():
Expand Down Expand Up @@ -151,6 +152,7 @@ def parse_args():
'use --dtype float16, amp will be turned on in the training phase and '
'fp16 will be used in evaluation.')
args = parser.parse_args()
assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, 'Possible loss of data while chunking input features'
bartekkuncer marked this conversation as resolved.
Show resolved Hide resolved
return args


Expand Down Expand Up @@ -256,7 +258,7 @@ def process_sample(self, feature: SquadFeature):
truncated_query_ids = feature.query_token_ids[:self._max_query_length]
chunks = feature.get_chunks(
doc_stride=self._doc_stride,
max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
max_chunk_length=self._max_seq_length - len(truncated_query_ids) - SEPARATORS)
for chunk in chunks:
data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] +
feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
Expand Down