# R-BERT

(Unofficial) PyTorch implementation of R-BERT: [Enriching Pre-trained Language Model with Entity Information for Relation Classification](https://arxiv.org/abs/1905.08284)
## Method

1. Get three vectors from BERT:
   - the [CLS] token vector
   - the averaged entity_1 vector
   - the averaged entity_2 vector
2. Pass each vector through a fully-connected layer: dropout -> tanh -> fc-layer.
3. Concatenate the three vectors.
4. Pass the concatenated vector through a final fully-connected layer: dropout -> fc-layer. (See the code sketch below.)
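A minimal PyTorch sketch of steps 1-4. This is illustrative, not the repository's exact code: the helper names `FCLayer`, `average_entity_vectors`, and `classify`, the dropout rate, and the sizes are assumptions.

```python
import torch
import torch.nn as nn

class FCLayer(nn.Module):
    """dropout -> tanh -> fc-layer, as in step 2 above (hypothetical helper)."""

    def __init__(self, in_dim, out_dim, dropout_rate=0.1, use_tanh=True):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.tanh = nn.Tanh()
        self.linear = nn.Linear(in_dim, out_dim)
        self.use_tanh = use_tanh

    def forward(self, x):
        x = self.dropout(x)
        if self.use_tanh:
            x = self.tanh(x)
        return self.linear(x)

def average_entity_vectors(hidden_states, entity_mask):
    """Average the hidden states of the tokens flagged in entity_mask."""
    mask = entity_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)    # (batch, hidden)
    count = mask.sum(dim=1).clamp(min=1.0)        # guard against empty spans
    return summed / count

hidden_size, num_labels = 768, 19  # 19 = 9 relations x 2 directions + Other

cls_fc = FCLayer(hidden_size, hidden_size)
e1_fc = FCLayer(hidden_size, hidden_size)
e2_fc = FCLayer(hidden_size, hidden_size)
classifier = FCLayer(3 * hidden_size, num_labels, use_tanh=False)  # dropout -> fc-layer

def classify(hidden_states, e1_mask, e2_mask):
    # hidden_states: BERT's last hidden layer, shape (batch, seq_len, hidden)
    # e1_mask / e2_mask: 1 for every token inside the entity span, markers included
    cls_vec = cls_fc(hidden_states[:, 0])                           # step 1: [CLS]
    e1_vec = e1_fc(average_entity_vectors(hidden_states, e1_mask))  # steps 1-2
    e2_vec = e2_fc(average_entity_vectors(hidden_states, e2_mask))  # steps 1-2
    concat = torch.cat([cls_vec, e1_vec, e2_vec], dim=-1)           # step 3
    return classifier(concat)                                       # step 4: logits
```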
- Exactly the SAME conditions as written in the paper:
  - Averaging over the `entity_1` and `entity_2` hidden state vectors, respectively, including the $ and # marker tokens (see the example after this list).
  - Dropout and Tanh before the fully-connected layer.
  - No [SEP] token at the end of the sequence. (If you want to add a [SEP] token, give the `--add_sep_token` option.)
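To illustrate what "including the $ and # marker tokens" means, here is a hypothetical example, assuming the R-BERT paper's convention of wrapping entity_1 in `$` markers and entity_2 in `#` markers:

```python
# Original sentence: "The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
tokens = ["[CLS]", "the", "$", "company", "$", "fabricates",
          "plastic", "#", "chairs", "#", "."]

# entity_1 averaging covers indices 2..4 -- the '$' markers are included:
e1_mask = [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0]
# entity_2 averaging covers indices 7..9 -- the '#' markers are included:
e2_mask = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0]
```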
## Dependencies

- perl (for evaluating the official F1 score)
- python>=3.6
- torch==1.6.0
- transformers==3.3.1
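For example, the pinned Python packages can be installed with pip (perl usually ships with Linux/macOS):

```bash
$ pip3 install torch==1.6.0 transformers==3.3.1
```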
## How to run

```bash
$ python3 main.py --do_train --do_eval
```
- Predictions will be written to `proposed_answers.txt` in the `eval` directory.
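Presumably this file follows the SemEval-2010 Task 8 answer-key format that the official scorer reads, i.e. one `<id><TAB><relation>` pair per line; the values below are hypothetical:

```
8001	Other
8002	Message-Topic(e1,e2)
8003	Entity-Destination(e1,e2)
```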
## Evaluation

```bash
$ python3 official_eval.py
# macro-averaged F1 = 88.29%
```
- Evaluation is based on the official evaluation perl script.
  - MACRO-averaged F1 score (excluding the `Other` relation); see the sketch below this list.
- You can see the detailed result in `result.txt` in the `eval` directory.
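A rough Python approximation of the reported metric, assuming directional string labels and excluding `Other` from the macro average (the perl script remains the authoritative scorer; it additionally merges the two directions of each relation):

```python
from sklearn.metrics import f1_score

def macro_f1_without_other(y_true, y_pred):
    """Macro-averaged F1 over all relation labels except 'Other'."""
    labels = sorted({label for label in y_true if label != "Other"})
    return f1_score(y_true, y_pred, labels=labels, average="macro")

# Example: a perfect prediction scores 1.0 regardless of the 'Other' examples.
print(macro_f1_without_other(
    ["Other", "Cause-Effect(e1,e2)", "Message-Topic(e2,e1)"],
    ["Other", "Cause-Effect(e1,e2)", "Message-Topic(e2,e1)"],
))
```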
## Prediction

```bash
$ python3 predict.py --input_file {INPUT_FILE_PATH} --output_file {OUTPUT_FILE_PATH} --model_dir {SAVED_CKPT_PATH}
```
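A hypothetical `{INPUT_FILE_PATH}` would contain one raw sentence per line with `<e1>`/`<e2>` entity tags, for example:

```
The <e1>author</e1> of the <e2>book</e2> won an award.
The <e1>company</e1> fabricates plastic <e2>chairs</e2>.
```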