Reproduction: Rethinking Negative Instances for Generative Named Entity Recognition

View our contributions and access the full reproducibility study report, which correspond to the README section "Using GNER and Reproduce Findings" below.



We introduce GNER, a Generative Named Entity Recognition framework, which demonstrates enhanced zero-shot capabilities across unseen entity domains. Experiments on two representative generative models, i.e., LLaMA and Flan-T5, show that the integration of negative instances into the training process yields substantial performance enhancements. The resulting models, GNER-LLaMA and GNER-T5, outperform state-of-the-art (SoTA) approaches by a large margin, achieving improvements of 8 and 11 points in $F_1$ score, respectively. Code and models are publicly available.

PreTrained Models

We release five GNER models based on LLaMA (7B) and Flan-T5 (base, large, xl and xxl).

| Model | # Params | Zero-shot Average $F_1$ | Supervised Average $F_1$ | 🤗 HuggingFace Download Link |
| --- | --- | --- | --- | --- |
| GNER-LLaMA | 7B | 66.1 | 86.09 | link |
| GNER-T5-base | 248M | 59.5 | 83.21 | link |
| GNER-T5-large | 783M | 63.5 | 85.45 | link |
| GNER-T5-xl | 3B | 66.1 | 85.94 | link |
| GNER-T5-xxl | 11B | 69.1 | 86.15 | link |

Demo usage

Please check out Example Jupyter Notebooks for guidance on utilizing GNER models.

A simple inference example is as follows:

GNER-LLaMA:

>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("dyyyyyyyy/GNER-LLaMA-7B")
>>> model = AutoModelForCausalLM.from_pretrained("dyyyyyyyy/GNER-LLaMA-7B", torch_dtype=torch.bfloat16).cuda()
>>> model = model.eval()
>>> instruction_template = "Please analyze the sentence provided, identifying the type of entity for each word on a token-by-token basis.\nOutput format is: word_1(label_1), word_2(label_2), ...\nWe'll use the BIO-format to label the entities, where:\n1. B- (Begin) indicates the start of a named entity.\n2. I- (Inside) is used for words within a named entity but are not the first word.\n3. O (Outside) denotes words that are not part of a named entity.\n"
>>> sentence = "did george clooney make a musical in the 1980s"
>>> entity_labels = ["genre", "rating", "review", "plot", "song", "average ratings", "director", "character", "trailer", "year", "actor", "title"]
>>> instruction = f"{instruction_template}\nUse the specific entity tags: {', '.join(entity_labels)} and O.\nSentence: {sentence}"
>>> instruction = f"[INST] {instruction} [/INST]"
>>> inputs = tokenizer(instruction, return_tensors="pt").to("cuda")
>>> outputs = model.generate(**inputs, max_new_tokens=640)
>>> response = tokenizer.decode(outputs[0], skip_special_tokens=True)
>>> response = response[response.find("[/INST]") + len("[/INST]"):].strip()
>>> print(response)
"did(O) george(B-actor) clooney(I-actor) make(O) a(O) musical(B-genre) in(O) the(O) 1980s(B-year)"

GNER-T5:

>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> tokenizer = AutoTokenizer.from_pretrained("dyyyyyyyy/GNER-T5-xxl")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("dyyyyyyyy/GNER-T5-xxl", torch_dtype=torch.bfloat16).cuda()
>>> model = model.eval()
>>> instruction_template = "Please analyze the sentence provided, identifying the type of entity for each word on a token-by-token basis.\nOutput format is: word_1(label_1), word_2(label_2), ...\nWe'll use the BIO-format to label the entities, where:\n1. B- (Begin) indicates the start of a named entity.\n2. I- (Inside) is used for words within a named entity but are not the first word.\n3. O (Outside) denotes words that are not part of a named entity.\n"
>>> sentence = "did george clooney make a musical in the 1980s"
>>> entity_labels = ["genre", "rating", "review", "plot", "song", "average ratings", "director", "character", "trailer", "year", "actor", "title"]
>>> instruction = f"{instruction_template}\nUse the specific entity tags: {', '.join(entity_labels)} and O.\nSentence: {sentence}"
>>> inputs = tokenizer(instruction, return_tensors="pt").to("cuda")
>>> outputs = model.generate(**inputs, max_new_tokens=640)
>>> response = tokenizer.decode(outputs[0], skip_special_tokens=True)
>>> print(response)
"did(O) george(B-actor) clooney(I-actor) make(O) a(O) musical(B-genre) in(O) the(O) 1980s(B-year)"

Task schema: Incorporating negative instances into training
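
As a rough illustration (our sketch, assuming training prompts mirror the inference demo above, not an excerpt from the released training data): a training instance pairs an instruction listing both present and absent ("negative") entity types with a target that labels every token, non-entity words included as O.

# Illustrative training pair (our sketch, reusing instruction_template from the demo
# above). entity_labels mixes types that occur in the sentence (actor, genre, year)
# with types that do not -- the negative instances.
sentence = "did george clooney make a musical in the 1980s"
entity_labels = ["genre", "rating", "review", "plot", "song", "average ratings",
                 "director", "character", "trailer", "year", "actor", "title"]
source = f"{instruction_template}\nUse the specific entity tags: {', '.join(entity_labels)} and O.\nSentence: {sentence}"
target = "did(O) george(B-actor) clooney(I-actor) make(O) a(O) musical(B-genre) in(O) the(O) 1980s(B-year)"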

Hierarchical Matching: A faster algorithm for structuring

We develop a Hierarchical Matching algorithm that provides a straightforward and effective solution to the omission, addition, and substitution problems in the structuring process.

Furthermore, we implement a fast version of the LCS algorithm that runs in $O(N\log N)$, exploiting the fact that the query sentence contains few duplicate words.

First, we transform the Longest Common Subsequence (LCS) problem into a Longest Increasing Subsequence (LIS) problem. Subsequently, we construct a Directed Acyclic Graph (DAG) to facilitate the traceback of the specific sequence.

from collections import defaultdict

# A fast version of LCS with a complexity of O(N log N),
# assuming there are few duplicate words in the sentence
# input: a = [word_1, word_2, ..., word_n], b = [word_1, word_2, ..., word_m]
# return: match_idx = [idx_1, idx_2, ..., idx_n] (corresponding matching index between a and b)
def lcs_solve_fast(a, b):
    n, m = len(a), len(b)
    match_idx = [-1] * n
    match_list_b = defaultdict(list)
  
    # First we can convert the LCS problem into a LIS problem,
    # i.e., LCS(a, b) <=> LIS(index_list)
    for idx, word in enumerate(reversed(b)):
        match_list_b[word].append(m - idx - 1)
    index_list = []
    elem_list = []
    for idx, word in enumerate(a):
        if word in match_list_b:
            index_list.extend(match_list_b[word])
            elem_list.extend([idx] * len(match_list_b[word]))

    # then we compute the longest increasing subsequence of index_list
    # we build a DAG; each entry of the father array stores a node's parent so the matched path can be traced back
    father, increasing_seq = [[(-1, -1, -1)]], [-1]
    for i in range(len(index_list)):
        if index_list[i] > increasing_seq[-1]:
            father.append([(len(father[-1]) - 1, i, index_list[i])])
            increasing_seq.append(index_list[i])
        else:
            # binary search
            l, r, query_idx = 0, len(increasing_seq) - 1, -1
            while l <= r:
                mid = (l + r) >> 1
                if increasing_seq[mid] >= index_list[i]:
                    query_idx = mid
                    r = mid - 1
                else:
                    l = mid + 1
            father[query_idx].append((len(father[query_idx - 1]) - 1, i, index_list[i]))
            increasing_seq[query_idx] = index_list[i]

    # finally, we trace back the path to get a solution of the original LCS problem
    i, j = len(father) - 1, len(father[-1]) - 1
    while i > 0:
        match_idx[elem_list[father[i][j][1]]] = father[i][j][2]
        j = father[i][j][0]
        i -= 1
    return match_idx
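
A quick usage example (hypothetical, not from the repo): the function aligns each word of a with a position in b and returns -1 for words without a match.

# "made" has no counterpart in b, so it maps to -1
a = ["george", "clooney", "made", "a", "musical"]
b = ["george", "clooney", "make", "a", "musical"]
print(lcs_solve_fast(a, b))  # [0, 1, -1, 3, 4]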

Citation

@misc{ding2024rethinking,
      title={Rethinking Negative Instances for Generative Named Entity Recognition}, 
      author={Yuyang Ding and Juntao Li and Pinzheng Wang and Zecheng Tang and Bowen Yan and Min Zhang},
      year={2024},
      eprint={2402.16602},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Using GNER and Reproduce Findings

(Updated as of November 2024)

Requirements

You should install the dependencies:

# CUDA 11.7 and above
# PyTorch 2.0 and above.
# transformers>=4.32.0,<4.38.0
python -m pip install -r requirements.txt

Quick Reproduction

We also provide all the generated outputs for quick reproduction of our results. The model_predictions folder contains the generated results of GNER-LLaMA-7B and GNER-T5-xxl (including the ground truth). You can execute the following commands to evaluate the generated results:

# 0shot performance of GNER-LLaMA
python evaluate.py --tokenizer-path yahma/llama-7b-hf --prediction-path model_predictions/llama-7b-task-adaptation-beam1.jsonl
# 0shot performance of GNER-T5-xxl
python evaluate.py --tokenizer-path google/flan-t5-xxl --prediction-path model_predictions/flan-t5-xxl-task-adaptation-beam1.jsonl

Other generated results can be found here; the evaluation process is the same as in the two examples above.

Reproduce and Test Paper Results in Table 9

To verify the results in Table 9 using unit tests, you can execute the following commands:

# Table 9 Testing
python test_GNER_table9.py

# Run single tests
python -m unittest test_GNER_table9.TestGNERModelPredictions.test_omission_case_1_llama
python -m unittest test_GNER_table9.TestGNERModelPredictions.test_omission_case_2_llama

python -m unittest test_GNER_table9.TestGNERModelPredictions.test_addition_case_3_llama

python -m unittest test_GNER_table9.TestGNERModelPredictions.test_substitution_case_4_llama
python -m unittest test_GNER_table9.TestGNERModelPredictions.test_substitution_case_5_llama

python -m unittest test_GNER_table9.TestGNERModelPredictions.test_omission_case_6_t5
python -m unittest test_GNER_table9.TestGNERModelPredictions.test_omission_case_7_t5

python -m unittest test_GNER_table9.TestGNERModelPredictions.test_addition_case_8_t5
python -m unittest test_GNER_table9.TestGNERModelPredictions.test_addition_case_9_t5

python -m unittest test_GNER_table9.TestGNERModelPredictions.test_substitution_case_10_t5


Reproduce and Test Paper Results in Figure 6 and Table 10

To optionally inspect the outputs of the following tests, see table1case1output.txt and figure6output.txt.

# Figure 6 Test
python figure6.py

# Table 10 Case 1 Test
python table10case1.py

Robustness Testing and Evaluation Script

Perform Robustness Testing

To optionally inspect the output of the following test, see robusttest1.txt.

python test_robust1.py

Run the provided evaluation script on an empty existing .jsonl file

This produces a ZeroDivisionError.

python evaluate.py --tokenizer-path yahma/llama-7b-hf --prediction-path model_predictions/test_gner_evaluation_empty.jsonl

Run the provided evaluation script on an existing .jsonl file with limited data and word edits

python evaluate.py --tokenizer-path yahma/llama-7b-hf --prediction-path model_predictions/test_gner_evaluation.jsonl

Training & Inference

First, download the training data from here, put it in the current directory, and rename it to data.

The training scripts are provided in the scripts folder; you can train and evaluate the models with the following commands:

# Train and evaluate LLaMA Model
bash scripts/train_llama_task_adaptation.sh
# Evaluate only
bash scripts/eval_llama_task_adaptation.sh

# Train T5 xxl Model
bash scripts/train_t5_xxl_task_adaptation.sh
# Evaluate only
bash scripts/eval_t5_task_adaptation.sh

Error Analysis

We performed a detailed error analysis on the GNER models to understand where they may struggle with entity recognition. The following tests and results highlight the models' limitations and areas for improvement.

Prerequisites: Python 3.x. Install all dependencies using:

pip install -r requirements.txt

1. Contextual Entity Recognition Test

This test evaluates the model's ability to recognize entities based on the context in which they appear. It checks whether the model can accurately identify and label entities in sentences with varying structures and complexities.

Implementation Details

  • Script: contextual_entity_recognition_test.py
  • Model Used: GNER-LLaMA
  • Process:
    1. Tokenize input sentences.
    2. Run the model to predict entity labels.
    3. Compare the predicted labels with expected labels.
    4. Determine if the predictions match expectations (a sketch of this check appears after the run command below).

How to Run

Execute the following command in your terminal:

python contextual_entity_recognition_test.py
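
The comparison step can be sketched as follows (a minimal, hypothetical snippet; the actual script may differ):

# Minimal sketch of the label-comparison step (hypothetical; not the actual
# contextual_entity_recognition_test.py). "response" is the "word(label)" string
# produced by the model, as in the demo usage section.
def labels_match(response, expected_labels):
    predicted = [w[w.find("(") + 1:-1] for w in response.split()]
    return predicted == expected_labels

print(labels_match("did(O) george(B-actor) clooney(I-actor)", ["O", "B-actor", "I-actor"]))  # True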

Tests Conducted: Contextual Entity Recognition Test

Objective: To check if the model accurately identifies entities in sentences where the context is essential.

Observation: The model generally recognized entities correctly in straightforward contexts. However, it struggled with ambiguous cases, often mislabeling or failing to recognize entities in complex sentence structures.

Error Found: In cases with multiple interpretations, the model occasionally applied incorrect labels, showing a limitation in contextual understanding.

2. Synonym Entity Test

This test assesses the model's ability to recognize entities expressed through synonyms or abbreviations. It evaluates whether the model can generalize entity recognition beyond exact matches to include equivalent terms.

Implementation Details

  • Script: synonym_entity_test.py
  • Model Used: GNER-LLaMA
  • Process:
    1. Provide sentences with entities represented by synonyms or abbreviations (example test cases are sketched after the run command below).
    2. Predict entity labels using the model.
    3. Compare predictions with expected labels.

How to Run

Execute the following command:

python synonym_entity_test.py
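
As an illustration, test cases of this kind might pair an abbreviated or synonymous mention with the labels expected for it (hypothetical examples, not the actual cases in the script):

# Hypothetical synonym/abbreviation test cases (not the actual ones in
# synonym_entity_test.py): each sentence is run through the model as in the demo
# usage section, and the predicted labels are compared with the expected ones.
synonym_cases = [
    ("NYC is crowded", ["B-location", "O", "O"]),
    ("New York City is crowded", ["B-location", "I-location", "I-location", "O", "O"]),
]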

Objective: To evaluate the model’s ability to recognize entities when synonyms or abbreviations are used instead of exact terms.

Observation: The model performed inconsistently with synonyms and abbreviations. For example, it recognized "NYC" as a location but failed with more unusual or varied synonyms for common terms.

Error Found: The model often failed to generalize across equivalent terms, indicating a need for improved synonym and abbreviation handling during training.

3. Tokenization Test

This test verifies the model's tokenization accuracy, focusing on complex words like hyphenated terms and multi-token phrases. Proper tokenization is essential for accurate entity recognition.

Implementation Details

  • Script: tokenization_test.py
  • Tokenizer Used: GNER-LLaMA tokenizer
  • Process:
    1. Define test cases with expected tokens.
    2. Tokenize input phrases using the tokenizer (a sketch appears after the run command below).
    3. Compare the tokenizer's output with expected tokens.
    4. Determine if tokenization matches expectations.

How to Run

Execute the following command:

python tokenization_test.py
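
A minimal sketch of this kind of check (hypothetical; the actual script hardcodes its own phrases and expected token lists):

# Hypothetical sketch (not the actual tokenization_test.py): inspect how the
# GNER-LLaMA tokenizer splits hyphenated and multi-part terms. Real test cases
# would compare each output against a hardcoded expected token list.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dyyyyyyyy/GNER-LLaMA-7B")
for phrase in ["state-of-the-art", "1980s musical", "george clooney"]:
    print(phrase, "->", tokenizer.tokenize(phrase))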

Objective: To assess how well the tokenizer processes complex words, like hyphenated terms or multi-token phrases, which can affect the entity recognition accuracy.

Observation: For most simple terms, tokenization was accurate. However, in cases of hyphenated words and multi-part terms, tokenization was sometimes incorrect, leading to mismatches with expected outputs.

Error Found: Incorrect tokenization on complex words led to misalignment in entity tagging, which could propagate errors in the recognition process.

4. Resource Constraint Test

This test monitors the model's memory usage during inference to evaluate its efficiency and suitability for deployment in resource-constrained environments.

Implementation Details

  • Script: resource_constraint_test.py
  • Process:
    1. Load a GNER-LLaMA model for sequence classification.
    2. Perform multiple inferences while recording memory usage (a sketch of this measurement appears after the run command below).
    3. Output memory usage data to a log file for analysis.

How to Run

Execute the following command:

python resource_constraint_test.py
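
One way such a measurement can be taken (a hypothetical sketch using standard PyTorch utilities; the function and log file name are ours, not the script's):

import torch

# Hypothetical sketch (not the actual resource_constraint_test.py): record the peak
# GPU memory used by one generation call and append it to a log file.
def log_peak_memory(model, inputs, log_path="memory_log.txt"):
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=640)
    peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
    with open(log_path, "a") as f:
        f.write(f"peak GPU memory: {peak_gib:.2f} GiB\n")
    return peak_gib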

Objective: To monitor the model's memory usage during inference, ensuring it can run efficiently in memory-limited environments.

Observation: Memory usage remained consistent across test cases, but larger sentences led to minor spikes.

Error Found: No critical errors were found, though resource use was high, suggesting potential optimization needs for deployment on devices with restricted memory.

5. Labeling Test

This test verifies the model’s ability to accurately label entities in a sentence, comparing the actual labels assigned by the model with the expected labels. It assesses the model’s consistency and accuracy in identifying and labeling entities in diverse sentence structures.

Implementation Details

  • Script: labeling_test.py
  • Process:
    1. Run several test sentences through the model, comparing its output labels with predefined, expected labels.
    2. Output each test sentence with its expected vs. actual labels and note whether they match.

How to Run

Execute the following command:

python labeling_test.py

Objective: To assess the model's ability to accurately assign entity labels in sentences, ensuring consistency in labeling diverse entities such as names, dates, and locations.

Observation: The model accurately labeled entities in most cases, correctly identifying entities like B-PERSON, B-DATE, and B-LOC. However, minor inconsistencies were observed with more complex sentence structures, which impacted accuracy slightly.

Error Found: No critical errors were found in basic entity labeling, though minor misclassifications occurred in complex sentence structures, indicating a need for fine-tuning on more varied datasets to improve labeling accuracy in edge cases.

Error Analysis Conclusion

Contextual Entity Recognition: The model failed to leverage contextual information, resulting in misclassification or omission of entities in nuanced sentences. Proposed Improvement: Fine-tune the model on diverse datasets emphasizing context-driven annotations.

Synonym Handling: A lack of generalization across synonyms and abbreviations restricted the model’s ability to recognize entities with varied representations. Proposed Improvement: Integrate synonym and abbreviation mapping into the training dataset.

Tokenization Accuracy: Tokenizer errors on hyphenated and multi-part terms highlighted a gap in preprocessing, affecting downstream tasks. Proposed Improvement: Enhance the tokenizer for complex terms and multi-token phrases.

Resource Efficiency: High memory usage indicates the need for optimization, particularly for deployment on memory-constrained devices. Proposed Improvement: Implement pruning, quantization, or model distillation techniques to reduce resource requirements.

Labeling Consistency: The model displayed robust performance in simple scenarios but faced challenges with complex sentence structures. Proposed Improvement: Expand the training set to include varied and complex entity structures.

Future Work

Based on our findings, we suggest the following steps for improving the proposed models:

  • Enhance Contextual Training: Incorporate real-world examples with varied sentence complexities

  • Expand Synonym Mapping: Include diverse synonyms and abbreviations across languages and domains

  • Optimize Tokenization Algorithms: Refine preprocessing steps to better handle edge cases

  • Implement Resource Optimization Techniques: Focus on memory and inference time improvements

  • Conduct Robust Evaluations: Test the model on larger datasets and across different languages to ensure scalability and accuracy
