Skip to content

Commit

Permalink
Fix get_scores_dict for duplicate tokens (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti authored Jun 19, 2023
1 parent b5d3610 commit ed5a4c2
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,22 +324,22 @@ def get_scores_dicts(
for tgt_idx in range(aggr.attr_pos_start, aggr.attr_pos_end):
tgt_tok = aggr.target[tgt_idx]
if aggr.source_attributions is not None:
return_dict["source_attributions"][tgt_tok.token] = {}
return_dict["source_attributions"][(tgt_idx, tgt_tok.token)] = {}
for src_idx, src_tok in enumerate(aggr.source):
return_dict["source_attributions"][tgt_tok.token][src_tok.token] = aggr.source_attributions[
src_idx, tgt_idx - aggr.attr_pos_start
].item()
return_dict["source_attributions"][(tgt_idx, tgt_tok.token)][
(src_idx, src_tok.token)
] = aggr.source_attributions[src_idx, tgt_idx - aggr.attr_pos_start].item()
if aggr.target_attributions is not None:
return_dict["target_attributions"][tgt_tok.token] = {}
return_dict["target_attributions"][(tgt_idx, tgt_tok.token)] = {}
for tgt_idx_attr in range(aggr.attr_pos_end):
tgt_tok_attr = aggr.target[tgt_idx_attr]
return_dict["target_attributions"][tgt_tok.token][tgt_tok_attr.token] = aggr.target_attributions[
tgt_idx_attr, tgt_idx - aggr.attr_pos_start
].item()
return_dict["target_attributions"][(tgt_idx, tgt_tok.token)][
(tgt_idx_attr, tgt_tok_attr.token)
] = aggr.target_attributions[tgt_idx_attr, tgt_idx - aggr.attr_pos_start].item()
if aggr.step_scores is not None:
return_dict["step_scores"][tgt_tok.token] = {}
return_dict["step_scores"][(tgt_idx, tgt_tok.token)] = {}
for step_score_id, step_score in aggr.step_scores.items():
return_dict["step_scores"][tgt_tok.token][step_score_id] = step_score[
return_dict["step_scores"][(tgt_idx, tgt_tok.token)][step_score_id] = step_score[
tgt_idx - aggr.attr_pos_start
].item()
return return_dict
Expand Down

0 comments on commit ed5a4c2

Please sign in to comment.