refactor eval code, create eval cli #55

Merged 5 commits from feat/eval-cli into main on Sep 24, 2023
Conversation

Ben-Epstein (Contributor):
There was massive overlap between the eval RAG and eval retriever scripts.

I broke the shared logic out into functions in eval/utils.py so both scripts can reuse it.

Cleaned up the functions a bit, and then built the eval CLI.
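For reference, the shared helper ends up with roughly this shape (a minimal sketch: only `get_passage_embeddings` is a real name from the diff below; its actual signature and internals in eval/utils.py may differ, e.g. the real callable likely consumes tokenized inputs):

```python
from typing import Callable, List, Tuple

import numpy as np
import torch
from datasets import Dataset


def get_passage_embeddings(
    dataset: Dataset,
    passage_column_name: str,
    embed_fn: Callable[[List[str]], torch.Tensor],
    batch_size: int = 32,
) -> Tuple[Dataset, np.ndarray]:
    """Deduplicate passages, then embed them with the caller-supplied forward pass.

    Sketch only: the repo's version may tokenize inside embed_fn and return
    different types.
    """
    unique_passages = list(dict.fromkeys(dataset[passage_column_name]))
    unique_dataset = Dataset.from_dict({passage_column_name: unique_passages})
    chunks = []
    with torch.no_grad():
        for start in range(0, len(unique_passages), batch_size):
            batch = unique_passages[start : start + batch_size]
            chunks.append(embed_fn(batch).cpu().numpy())
    return unique_dataset, np.concatenate(chunks)
```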

Tested with the following (note: I had to comment out this line to get it working with gpt2):

```bash
# install the dalm repo
pip install -e .

# train rag e2e
dalm train-rag-e2e \
  "./dalm/datasets/toy_data_train.csv" \
  "BAAI/bge-small-en" \
  "gpt2" \
  --output-dir "rag_e2e_checkpoints" \
  --per-device-train-batch-size 32
```

```bash
# eval retriever
dalm eval-retriever "./dalm/datasets/toy_data_train.csv" \
  --retriever-name-or-path "BAAI/bge-small-en" \
  --retriever-peft-model-path "rag_e2e_checkpoints/retriever" \
  --embed-dim 384
```

```text
Construct passage index
Evaluation start
Retriever results:
Recall: 0.10000000000000003
Precision: 1.0
Hit Rate: 1.0
*************
```
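For context, top-k retrieval metrics like the ones above can be computed along these lines (an illustrative sketch, not the repo's exact implementation; for example, one hit in the top-1 against ten gold passages per query gives recall 0.1 and precision 1.0):

```python
from typing import Dict, List, Set


def retrieval_metrics(
    retrieved: List[List[str]], relevant: List[Set[str]], k: int
) -> Dict[str, float]:
    """Average recall, precision, and hit rate of top-k results per query."""
    recalls, precisions, hits = [], [], []
    for ret, rel in zip(retrieved, relevant):
        found = sum(1 for passage in ret[:k] if passage in rel)
        recalls.append(found / len(rel) if rel else 0.0)
        precisions.append(found / k)
        hits.append(1.0 if found > 0 else 0.0)
    n = len(retrieved)
    return {
        "recall": sum(recalls) / n,
        "precision": sum(precisions) / n,
        "hit_rate": sum(hits) / n,
    }
```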

```bash
# eval rag
dalm eval-rag "./dalm/datasets/toy_data_train.csv" \
  --retriever-name-or-path "BAAI/bge-small-en" \
  --generator-name-or-path "gpt2" \
  --retriever-peft-model-path rag_e2e_checkpoints/retriever \
  --generator-peft-model-path rag_e2e_checkpoints/generator \
  --query-batch-size 5 \
  --embed-dim 384
```

```text
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Retriever results:
Recall: 0.10000000000000003
Precision: 1.0
Hit Rate: 1.0
*************
Generator evaluation:
Exact match: 0.0
```
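On the right-padding warning above: decoder-only models like gpt2 need left padding at generation time. The standard transformers fix is a one-liner at tokenizer init (a sketch of the setup, not necessarily how this repo wires it in):

```python
from transformers import AutoTokenizer

# Decoder-only generation expects left padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
# gpt2 ships without a pad token; reuse EOS so batched generation works.
tokenizer.pad_token = tokenizer.eos_token
```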

Review thread on this snippet from the diff (context truncated as shown):

```python
unique_passage_dataset, passage_embeddings_array = get_passage_embeddings(
    processed_datasets,
    passage_column_name,
    rag_model.retrieval_forward,
```
Member:
Is it okay to pass a function like this? I'm fine with it, but is there a better way?

Ben-Epstein (Author):

It's pretty normal to pass callables (typing.Callable) around :)
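For example, the parameter can be annotated so type checkers verify what gets passed in (hypothetical helper name; the actual signature in eval/utils.py may differ):

```python
from typing import Callable, List

import torch


def embed_batch(
    passages: List[str],
    retrieval_forward: Callable[[List[str]], torch.Tensor],
) -> torch.Tensor:
    # The forward pass is just another argument; mypy checks the signature.
    return retrieval_forward(passages)
```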

@shamanez (Member):

These are major structural changes. I'm OK with the logic, but please run the eval with the given dataset and confirm we get the same results.

@Ben-Epstein merged commit 2245860 into main on Sep 24, 2023. 1 check passed.
@Ben-Epstein deleted the feat/eval-cli branch on September 24, 2023 at 12:42.