diff --git a/textattack/loggers/csv_logger.py b/textattack/loggers/csv_logger.py index d248b063..68b8d3a3 100644 --- a/textattack/loggers/csv_logger.py +++ b/textattack/loggers/csv_logger.py @@ -46,7 +46,8 @@ def flush(self): self._flushed = True def close(self): - self.fout.close() + # self.fout.close() + super().close() def __del__(self): if not self._flushed: diff --git a/textattack/models/helpers/word_cnn_for_classification.py b/textattack/models/helpers/word_cnn_for_classification.py index e2d1e2b3..a3c8bd9a 100644 --- a/textattack/models/helpers/word_cnn_for_classification.py +++ b/textattack/models/helpers/word_cnn_for_classification.py @@ -84,7 +84,7 @@ def save_pretrained(self, output_path): @classmethod def from_pretrained(cls, name_or_path): - """Load trained LSTM model by name or from path. + """Load trained Word CNN model by name or from path. Args: name_or_path (:obj:`str`): Name of the model (e.g. "cnn-imdb") or model saved via :meth:`save_pretrained`. diff --git a/textattack/shared/attacked_text.py b/textattack/shared/attacked_text.py index b8c0687e..200384aa 100644 --- a/textattack/shared/attacked_text.py +++ b/textattack/shared/attacked_text.py @@ -80,6 +80,8 @@ def __eq__(self, other): """ if not (self.text == other.text): return False + if len(self.attack_attrs) != len(other.attack_attrs): + return False for key in self.attack_attrs: if key not in other.attack_attrs: return False @@ -193,7 +195,10 @@ def _text_index_of_word_index(self, i): # Find all words until `i` in string. look_after_index = 0 for word in pre_words: - look_after_index = lower_text.find(word.lower(), look_after_index) + look_after_index = lower_text.find(word.lower(), look_after_index) + len( + word + ) + look_after_index -= len(self.words[i]) return look_after_index def text_until_word_index(self, i): @@ -217,7 +222,7 @@ def first_word_diff(self, other_attacked_text): w2 = other_attacked_text.words for i in range(min(len(w1), len(w2))): if w1[i] != w2[i]: - return w1 + return w1[i] return None def first_word_diff_index(self, other_attacked_text):