
Commit

Fix return all score issue (#70)
* Use self.tokenizer_name in text_classification

* Fix distilbert classification input for return_all_scores

* Fix post_process output for return_all_scores in text_classification.py

* Fix roberta input for return_all_scores in roberta_text_classification.py

* Fix bert input for return_all_scores in bert_text_classification.py
pooya-mohammadi authored Aug 6, 2023
1 parent 8acc686 commit a54460b
Showing 4 changed files with 34 additions and 7 deletions.
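Taken together, the changes stop the text-classification heads from forwarding post-processing flags (notably return_all_scores) into the underlying language models, and make post_process return one score per label. A hedged usage sketch of the intended behavior follows; the import path, Hub model id, and predict() signature are assumptions, not taken from this commit:

```python
# Assumed usage only: the import path, model id, and predict() signature are not
# shown in this commit and may differ in your hezar version.
from hezar import Model

model = Model.load("hezarai/bert-fa-sentiment-dksf")  # hypothetical text-classification model
outputs = model.predict(["This library keeps getting better!"], return_all_scores=True)
# Expected: one list per input text, each holding a {"label": ..., "score": ...}
# dict for every class, instead of a TypeError inside the model's forward().
print(outputs)
```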
bert_text_classification.py
@@ -42,19 +42,32 @@ def _build_inner_config(self):
     def forward(self, inputs, **kwargs) -> Dict:
         input_ids = inputs.get("token_ids")
         attention_mask = inputs.get("attention_mask", None)
+        token_types_ids = inputs.get("token_type_ids", None) or kwargs.get("token_type_ids", None)
+        position_ids = inputs.get("position_ids", None) or kwargs.get("position_ids", None)
         head_mask = inputs.get("head_mask", None)
         inputs_embeds = inputs.get("inputs_embeds", None)
+        encoder_hidden_states = inputs.get("encoder_hidden_states", None) or kwargs.get("encoder_hidden_states", None)
+        encoder_attention_mask = inputs.get("encoder_attention_mask", None) or kwargs.get("encoder_attention_mask", None)
+        past_key_values = inputs.get("past_key_values", None) or kwargs.get("past_key_values", None)
+        use_cache = inputs.get("use_cache", None) or kwargs.get("use_cache", None)
         output_attentions = inputs.get("output_attentions", None)
         output_hidden_states = inputs.get("output_hidden_states", None)
+        return_dict = inputs.get("return_dict", None) or kwargs.get("return_dict", None)

         lm_outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
+            token_type_ids=token_types_ids,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            **kwargs,
+            return_dict=return_dict,
         )
         pooled_output = lm_outputs[1]
         pooled_output = self.dropout(pooled_output)
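The change above (and the matching DistilBERT and RoBERTa changes below) replaces blind **kwargs forwarding with explicit arguments pulled from inputs, falling back to kwargs. A minimal, self-contained sketch of the failure mode this avoids, using a toy stand-in for the inner model rather than the real BertModel:

```python
# Toy illustration only: `inner_model` stands in for self.bert and, like the real
# Hugging Face models, rejects keyword arguments it does not know about.
def inner_model(input_ids, attention_mask=None, return_dict=None):
    return {"logits": [0.1, 0.9], "return_dict": return_dict}

def broken_forward(inputs, **kwargs):
    # Forwarding **kwargs passes post-processing flags (e.g. return_all_scores)
    # straight into the language model -> TypeError: unexpected keyword argument.
    return inner_model(inputs["token_ids"], **kwargs)

def fixed_forward(inputs, **kwargs):
    # Pick out only the arguments the language model understands.
    return_dict = inputs.get("return_dict", None) or kwargs.get("return_dict", None)
    return inner_model(inputs["token_ids"],
                       attention_mask=inputs.get("attention_mask"),
                       return_dict=return_dict)

inputs = {"token_ids": [101, 7592, 102]}
try:
    broken_forward(inputs, return_all_scores=True)
except TypeError as e:
    print("broken:", e)
print("fixed:", fixed_forward(inputs, return_all_scores=True))
```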
distilbert_text_classification.py
@@ -43,6 +43,7 @@ def forward(self, inputs, **kwargs) -> Dict:
         inputs_embeds = inputs.get("inputs_embeds", None)
         output_attentions = inputs.get("output_attentions", None)
         output_hidden_states = inputs.get("output_hidden_states", None)
+        return_dict = inputs.get("return_dict", None) or kwargs.get("return_dict", None)

         lm_outputs = self.distilbert(
             input_ids,
@@ -51,7 +52,7 @@ def forward(self, inputs, **kwargs) -> Dict:
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            **kwargs,
+            return_dict=return_dict,
         )
         hidden_state = lm_outputs[0]
         pooled_output = hidden_state[:, 0]
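Unlike the BERT head above, which reads the pooler output (lm_outputs[1]), the DistilBERT head pools by taking the first token's hidden state, since DistilBERT's base model does not provide a pooled output. A toy tensor sketch of that indexing, with made-up shapes:

```python
# Made-up shapes; only the indexing mirrors pooled_output = hidden_state[:, 0] above.
import torch

batch_size, seq_len, hidden_size = 2, 8, 4
hidden_state = torch.randn(batch_size, seq_len, hidden_size)  # last hidden states
pooled_output = hidden_state[:, 0]  # first ([CLS]) token per sample
print(pooled_output.shape)  # torch.Size([2, 4])
```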
roberta_text_classification.py
@@ -36,19 +36,32 @@ def _build_inner_config(self):
     def forward(self, inputs, **kwargs):
         input_ids = inputs.get("token_ids")
         attention_mask = inputs.get("attention_mask", None)
+        token_type_ids = inputs.get("token_type_ids", None) or kwargs.get("token_type_ids", None)
+        position_ids = inputs.get("position_ids", None) or kwargs.get("position_ids", None)
         head_mask = inputs.get("head_mask", None)
         inputs_embeds = inputs.get("inputs_embeds", None)
+        encoder_hidden_states = inputs.get("encoder_hidden_states", None) or kwargs.get("encoder_hidden_states", None)
+        encoder_attention_mask = inputs.get("encoder_attention_mask", None) or kwargs.get("encoder_attention_mask", None)
+        past_key_values = inputs.get("past_key_values", None) or kwargs.get("past_key_values", None)
+        use_cache = inputs.get("use_cache", None) or kwargs.get("use_cache", None)
         output_attentions = inputs.get("output_attentions", None)
         output_hidden_states = inputs.get("output_hidden_states", None)
+        return_dict = inputs.get("return_dict", None) or kwargs.get("return_dict", None)

         lm_outputs = self.roberta(
             input_ids,
             attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            **kwargs,
+            return_dict=return_dict,
        )
         sequence_output = lm_outputs[0]
         logits = self.classifier(sequence_output)
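The RoBERTa change mirrors the BERT one. The lookup pattern shared by all three heads, inputs.get(name, None) or kwargs.get(name, None), prefers the value produced by preprocessing and only then falls back to call-time keyword arguments. A small hypothetical helper (not part of hezar) makes the behavior explicit:

```python
# Hypothetical helper that mirrors the lookup pattern used in the heads above.
def pick(name, inputs, kwargs):
    # Prefer the preprocessed inputs dict, fall back to call-time kwargs.
    return inputs.get(name, None) or kwargs.get(name, None)

inputs = {"token_ids": [101, 2023, 102], "attention_mask": [1, 1, 1]}
kwargs = {"return_dict": True, "return_all_scores": True}

print(pick("return_dict", inputs, kwargs))  # True  -> forwarded explicitly
print(pick("head_mask", inputs, kwargs))    # None  -> the model's default is used
# "return_all_scores" is never looked up, so it can no longer leak into the model.
```

One consequence of using `or` is that falsy values (an explicit False or 0) in inputs fall through to kwargs or to None, which is worth keeping in mind for flags like use_cache.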
8 changes: 4 additions & 4 deletions hezar/models/text_classification/text_classification.py
@@ -14,7 +14,7 @@ def preprocess(self, inputs: Union[str, List[str]], **kwargs):
         if "text_normalizer" in self.preprocessor:
             normalizer = self.preprocessor["text_normalizer"]
             inputs = normalizer(inputs)
-        tokenizer = self.preprocessor[self._tokenizer_name]
+        tokenizer = self.preprocessor[self.tokenizer_name]
         inputs = tokenizer(inputs, return_tensors="pt", device=self.device)
         return inputs


@@ -27,9 +27,9 @@ def post_process(self, inputs, **kwargs) -> Dict:
             outputs = []
             for sample_index in range(predictions.shape[0]):
                 sample_outputs = []
-                for prediction, prob in zip(predictions[sample_index], predictions_probs[sample_index]):
-                    label = self.config.id2label[prediction.item()]
-                    sample_outputs.append({"label": label, "score": prob.item()})
+                for label_index, score in enumerate(predictions_probs[sample_index]):
+                    label = self.config.id2label[label_index]
+                    sample_outputs.append({"label": label, "score": score.item()})
                 outputs.append(sample_outputs)
         else:
             predictions = logits.argmax(1)
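With return_all_scores, post_process now enumerates every class probability for each sample instead of zipping the top predictions against the probability matrix. A self-contained numeric sketch of the fixed loop, using made-up logits and an invented id2label mapping:

```python
# Made-up logits and labels; only the loop structure mirrors post_process above.
import torch

id2label = {0: "negative", 1: "neutral", 2: "positive"}  # hypothetical mapping
logits = torch.tensor([[1.2, -0.3, 0.4]])                # one sample, three classes
predictions_probs = logits.softmax(1)

outputs = []
for sample_index in range(predictions_probs.shape[0]):
    sample_outputs = []
    for label_index, score in enumerate(predictions_probs[sample_index]):
        sample_outputs.append({"label": id2label[label_index], "score": score.item()})
    outputs.append(sample_outputs)

print(outputs)
# One entry per input text, one {"label", "score"} dict per class, scores summing to ~1.
```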
