rui/humaneval #51

Open: wants to merge 20 commits into base: main

Changes from all commits
7 changes: 7 additions & 0 deletions .gitignore
@@ -135,3 +135,10 @@ debug-tmp/
wandb/
results/
.vscode/

# defined by rui
DS_1000/
.DS_Store
NLP4Code_humaneval_outputs
NLP4Code_ds1000_outputs
raw_output_evaluation.py
60 changes: 5 additions & 55 deletions README.md
@@ -1,61 +1,11 @@
# NLP4Code
Repository for the NLP4Code project at the LILY lab.
Installing Human Eval

## Installation
*[Recommended]* Create a virtualenv or conda environment
```bash
conda create -n nlp4code python=3.8
conda activate nlp4code
```
Then, install the dependencies:
```bash
pip install -r requirements.txt
```
*(Optional)* At any point, if you run into a Python import problem (e.g., `ModuleNotFoundError`), try doing this in the main (`NLP4Code`) directory:
```bash
export PYTHONPATH=`pwd`
```

## Wandb
We use Wandb for experiment tracking. Please register and ask Ansong for an invitation to the Wandb Yale-LILY team before
running experiments. When you are ready to run experiments and log them to the cloud, do the following:
```
wandb login
```
Paste your API key and the login is complete. When you start running experiments, you should see something like
```
wandb: Tracking run with wandb version 0.12.11
wandb: Run data is saved locally in /home/ansongni/Code/NLP4Code/wandb/run-20220309_150158-1ebacxm4
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run mathqa-gpt-finetuning
wandb: ⭐️ View project at https://wandb.ai/yale-lily/unified-codegen
wandb: 🚀 View run at https://wandb.ai/yale-lily/unified-codegen/runs/1ebacxm4
git clone https://github.com/openai/human-eval
pip install -e human-eval
```
Creating JSONL files

If you want to do some test runs without logging to the cloud, run `wandb offline` first as suggested above.

## Naming of the experiments
In the `*.yaml` configuration file, you should see a line like
```
default_root_dir: &exp_name results/mathqa-gpt_neo_1.3B-finetuning
```
We automatically derive the experiment name from the string after `/`; the tags for the experiment are automatically
generated by splitting that string on `-`. In this case, the experiment will be named `mathqa-gpt_neo_1.3B-finetuning`
and the tags will be `["mathqa", "gpt_neo_1.3B", "finetuning"]`. Please follow this convention so that we can keep all
of this in one place.
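As a minimal illustration of this convention (plain Python, not part of the training code):
```python
# Sketch of the naming convention described above.
default_root_dir = "results/mathqa-gpt_neo_1.3B-finetuning"

exp_name = default_root_dir.split("/")[-1]   # "mathqa-gpt_neo_1.3B-finetuning"
tags = exp_name.split("-")                   # ["mathqa", "gpt_neo_1.3B", "finetuning"]
```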

## Fine-tuning
(Read the previous sections first if you are ready to run experiments)
For fine-tuning, in the main directory, do:
```
python finetuning/trainer.py fit --config finetuning/training_configs/*.yaml
```

## Testing
There are some basic tests in the `tests` folder. To run all of them (follow [this link](https://docs.python.org/3/library/unittest.html#command-line-interface) for more), do
```bash
python -m unittest discover <test_directory>
# or
python -m unittest discover -s <directory> -p '*_test.py'
python preprocess/preprocess_humaneval.py
```
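The `preprocess/preprocess_humaneval.py` script itself is not part of this diff; a minimal sketch of what such a preprocessing step could look like, assuming it simply dumps the HumanEval problems to JSONL with the helpers from the `human-eval` package (the output path is hypothetical):
```python
# Hypothetical sketch only: dump HumanEval problems to a JSONL file.
from human_eval.data import read_problems, write_jsonl

# dict keyed by task_id; each record carries prompt/canonical_solution/test/entry_point
problems = read_problems()

# "data/humaneval.jsonl" is an assumed output path, not taken from the repository.
write_jsonl("data/humaneval.jsonl", list(problems.values()))
```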
74 changes: 71 additions & 3 deletions execution/executors.py
@@ -14,6 +14,9 @@
from execution.safe_execution_util import execute
from execution.program_tracing import get_function_final_state

from human_eval.execution import check_correctness


"""
From the models' perspective, the model would only want two things:
1) if the execution result is right;
@@ -168,7 +171,11 @@ def gold_program_len(self, example: Dict[str, Any]) -> int:

@overrides
def process_output(self, output: str, tokenizer_eos_token: str) -> str:
return output.lstrip().split(tokenizer_eos_token)[0].split("\n\n")[0].split(";")[0].strip()
if not tokenizer_eos_token:
# for llama-based model
return output.lstrip().split("\n\n")[0].split(";")[0].strip()
else:
return output.lstrip().split(tokenizer_eos_token)[0].split("\n\n")[0].split(";")[0].strip()
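For illustration only (not part of the diff), how the two branches above truncate a raw completion; the sample string is made up:
```python
raw = " SELECT name FROM users ; SELECT 1 <|endoftext|>\n\nleftover text"

# With an EOS token (GPT-style models): cut at EOS, then at the first blank line and ";".
print(raw.lstrip().split("<|endoftext|>")[0].split("\n\n")[0].split(";")[0].strip())
# -> SELECT name FROM users

# With an empty EOS token (llama-based models): skip the EOS split.
print(raw.lstrip().split("\n\n")[0].split(";")[0].strip())
# -> SELECT name FROM users
```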

@overrides
def exec_result_eq(self, program_dict_1: Dict[str, Any], program_dict_2: Dict[str, Any]) -> bool:
@@ -314,7 +321,11 @@ def gold_program_len(self, example: Dict[str, Any]) -> int:

@overrides
def process_output(self, output: str, tokenizer_eos_token: str) -> str:
return output.lstrip().split(tokenizer_eos_token)[0].split("\n\n")[0].strip()
if not tokenizer_eos_token:
# for llama-based model
return output.lstrip().split("\n\n")[0].strip()
else:
return output.lstrip().split(tokenizer_eos_token)[0].split("\n\n")[0].strip()

@overrides
def exec_result_eq(self, program_dict_1: Dict[str, Any], program_dict_2: Dict[str, Any]) -> bool:
@@ -355,4 +366,61 @@ def real_exec_program(cls, program: str, example: Dict[str, Any]) -> Tuple[int,
executed_answer = "ERROR: program failed to execute"
exec_match = -1

return exec_match, executed_answer
return exec_match, executed_answer


class HumanEvalExecutor(BaseExecutor):
def __init__(self, **kwargs):
super().__init__(**kwargs)

@overrides
def cache_key_func(self, program: str, example: Dict[str, Any]) -> str:
return example["prompt"] + " | " + program

@overrides
def program_len(self, program: str) -> int:
return python_program_len(program)

@overrides
def gold_program_len(self, example: Dict[str, Any]) -> int:
return self.program_len(example["canonical_solution"])

# TODO: modify this later based on generated programs
@overrides
def process_output(self, output: str, tokenizer_eos_token: str) -> str:
stop_sequence = [ '\nclass', '\ndef', '\n#', '\nif', '\nprint']
min_index = len(output) # Initialize with a large value
for substring in stop_sequence:
index = output.find(substring)
if index != -1 and index < min_index:
min_index = index

if min_index < len(output):
processed_output = output[:min_index]
else:
processed_output = output

# for llama, gpt4_alpaca_lora, alpaca_lora_7b, the model output may be missing a leading space
if not processed_output.startswith(" "):
processed_output = " " + processed_output

return processed_output
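A small illustration (made-up strings, not from the diff) of the stop-sequence truncation implemented above:
```python
raw_completion = (
    "    return sorted(numbers)\n"
    "\n"
    "def another_function():\n"
    "    pass\n"
)
stop_sequence = ['\nclass', '\ndef', '\n#', '\nif', '\nprint']

# Find the earliest stop sequence and cut there, as process_output does.
indices = [raw_completion.find(s) for s in stop_sequence]
min_index = min([i for i in indices if i != -1], default=len(raw_completion))
print(repr(raw_completion[:min_index]))  # '    return sorted(numbers)\n'
```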

@overrides
def exec_result_eq(self, program_dict_1: Dict[str, Any], program_dict_2: Dict[str, Any]) -> bool:
return (program_dict_1['exec_result'] and (program_dict_1['exec_result'] == program_dict_2['exec_result']))

@classmethod
def real_exec_program(cls, program: str, example: Dict[str, Any]) -> Tuple[int, Union[str, List, Dict]]:
eval_dict = example
metadata = eval_dict.pop('metadata')
eval_dict.update(metadata)

result_dict = check_correctness(eval_dict, program, timeout=5)
exec_match = result_dict['passed']
exec_result = result_dict['result']

if exec_match < 1 and exec_result.strip() != "failed:":
exec_match = -1

return exec_match, exec_result
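For reference, a minimal sketch of the `check_correctness` call used above; the field values are made up (real ones come from the HumanEval dataset), but the keys match what `human_eval.execution.check_correctness` expects:
```python
from human_eval.execution import check_correctness

problem = {
    "task_id": "HumanEval/xx",  # made-up id
    "prompt": 'def add(a, b):\n    """Return the sum of a and b."""\n',
    "entry_point": "add",
    "test": "def check(candidate):\n    assert candidate(1, 2) == 3\n",
}
completion = "    return a + b\n"

result = check_correctness(problem, completion, timeout=5)
print(result["passed"], result["result"])  # True 'passed' when the test succeeds
```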
1 change: 1 addition & 0 deletions finetuning/lightning_modules/datasets/base_datamodule.py
@@ -10,6 +10,7 @@
from finetuning.lightning_modules.datasets.spider_reader import FewShotSpiderDataset, SpiderDataset
from finetuning.lightning_modules.datasets.mathqa_reader import FewShotMathQADataset, MathQADataset
from finetuning.lightning_modules.datasets.mbpp_reader import FewShotMBPPDataset
from finetuning.lightning_modules.datasets.humaneval_reader import FewShotHumanEvalDataset

from finetuning.lightning_modules.models.seq2seq_model_util import is_model_gpt_style
from finetuning.lightning_modules.models.seq2seq_model_util import left_pad_sequences, right_pad_sequences
35 changes: 35 additions & 0 deletions finetuning/lightning_modules/datasets/humaneval_reader.py
@@ -0,0 +1,35 @@
import re
import os
import pandas as pd

from overrides import overrides

from typing import Dict, Iterable, List, Any, Optional, Union, Tuple

from finetuning.lightning_modules.datasets.base_reader import NL2CodeDataset, FewShotNL2CodeDataset
from execution.program_tracing import assertion_to_test

from human_eval.data import write_jsonl, read_problems


class FewShotHumanEvalDataset(FewShotNL2CodeDataset):

instruction: str = ""
example_io_sep: str = "\n"

@overrides
def get_test_instance(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
context = self.get_prompt_for_example(example)

return [self.get_example_dict(example, context, train_mode=False)]

# @overrides
def promptify_example(self, example: Dict[str, Any], add_code: bool = True,
add_assertion_n: int = 0, test_input_only: bool = False) -> Tuple[str, str]:

header = example["prompt"]

if add_code:
return header, f'{example["canonical_solution"]}\n\n'
else:
return header, ''
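To illustrate what this reader produces per example (made-up values; the real records come from the HumanEval dataset):
```python
example = {
    "prompt": 'def add(a, b):\n    """Return the sum of a and b."""\n',
    "canonical_solution": "    return a + b\n",
}

# Few-shot demonstration (add_code=True): the HumanEval prompt is the context,
# and the gold solution (plus a blank line) is the completion to imitate.
context, completion = example["prompt"], f'{example["canonical_solution"]}\n\n'

# Test instance (add_code=False): only the prompt is given; the model writes the body.
context_only, empty = example["prompt"], ""
```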
8 changes: 4 additions & 4 deletions finetuning/lightning_modules/models/openai_model.py
@@ -80,10 +80,10 @@ def prune_none_args(**kwargs):
temperature=temperature, top_p=top_p, n=n, best_of=best_of,
stop=stop, **kwargs)

if engine.startswith("gpt-3.5-turbo"):
if engine.startswith("gpt-3.5-turbo") or engine.startswith("gpt-4"):
non_none_args.pop("prompt")
non_none_args.pop("engine")
assert len(prompts) == 1, "gpt-3.5-turbo only supports one prompt at a time"
assert len(prompts) == 1, "gpt-3.5-turbo and gpt-4 only support one prompt at a time"
if use_chat_format:
non_none_args["messages"] = prompt_to_chatgpt_format(prompts[0])
else:
@@ -115,7 +115,7 @@ def prune_none_args(**kwargs):
time.sleep(60 * 5)

# get the text from the returned results and slice the completions to input_n * completion_n
if engine.startswith("gpt-3.5-turbo"):
if engine.startswith("gpt-3.5-turbo") or engine.startswith("gpt-4"):
completion_texts = [x['message']['content'] for x in completion.choices]
else:
completion_texts = [x.text for x in completion.choices]
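For context, a sketch of the chat-style call path taken in the gpt-3.5-turbo/gpt-4 branch above. `prompt_to_chatgpt_format` is not shown in this diff, so the single-user-message form and the legacy `openai.ChatCompletion` call are assumptions:
```python
import openai  # legacy (<1.0) client style, matching the choices[...]["message"] access above


def prompt_to_chatgpt_format_sketch(prompt: str):
    # Assumption: the real helper may build a richer message list (e.g., with a system message).
    return [{"role": "user", "content": prompt}]


response = openai.ChatCompletion.create(
    model="gpt-4",
    messages=prompt_to_chatgpt_format_sketch("Write a function that adds two numbers."),
    temperature=0.0,
    n=1,
)
completion_texts = [choice["message"]["content"] for choice in response["choices"]]
```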
@@ -141,7 +141,7 @@ def __init__(self,
) -> None:
SUPPORTED_OPENAI_MODELS = ["code-davinci-002", "code-cushman-002",
"code-cushman-001", "code-davinci-001",
"gpt-3.5-turbo"]
"gpt-3.5-turbo", "text-davinci-003", "text-davinci-002","gpt-4"]
assert engine in SUPPORTED_OPENAI_MODELS, f"OpenAIModel only supports {SUPPORTED_OPENAI_MODELS}"

self.engine = engine
9 changes: 7 additions & 2 deletions finetuning/lightning_modules/models/seq2seq_model.py
@@ -75,9 +75,9 @@ def __init__(self,
# We only instantiate this when we need it.
self.transformer_model_name = transformer_model_name
if "openai" in self.transformer_model_name:
if self.transformer_model_name.startswith("openai/gpt-3.5-turbo"):
if self.transformer_model_name.startswith("openai/gpt-3.5-turbo") or self.transformer_model_name.startswith("openai/gpt-4"):
if self.save_raw_generation_results:
print("get_raw_generation_results is not supported for gpt-3.5-turbo, set to False instead")
print("get_raw_generation_results is not supported for gpt-3.5-turbo and gpt-4, set to False instead")
self.save_raw_generation_results = False
transformer_model_init_args["save_raw_generation_results"] = self.save_raw_generation_results
transformer_model_init_args["use_chat_format"] = self.use_chat_format
@@ -155,6 +155,10 @@ def generate_and_post_process(self,
num_beam = 1
temp = temperature

# https://github.com/THUDM/ChatGLM-6B/issues/31
if "santacoder" in self.transformer_model_name or "gpt-neox-20b" in self.transformer_model_name or "replit" in self.transformer_model_name:
use_sample = False

generation_results = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=use_sample,
max_new_tokens=self.max_gen_len, num_beams=num_beam,
temperature=temp, num_return_sequences=num_return_sequences,
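A condensed sketch of the sampling guard added above (model-name substrings as in the diff; everything else omitted):
```python
# These checkpoints reportedly misbehave with sampling (see the issue linked above),
# so fall back to greedy decoding for them.
NO_SAMPLING_SUBSTRINGS = ("santacoder", "gpt-neox-20b", "replit")


def resolve_do_sample(transformer_model_name: str, use_sample: bool) -> bool:
    if any(s in transformer_model_name for s in NO_SAMPLING_SUBSTRINGS):
        return False
    return use_sample
```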
@@ -387,6 +391,7 @@ def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
# save the predictions
save_pred_file_path = os.path.join(self.trainer.log_dir,
f'predictions_step_{self.trainer.global_step}_rank_{self.trainer.global_rank}.jsonl')
os.makedirs(os.path.dirname(save_pred_file_path), exist_ok=True)
with open(save_pred_file_path, 'w+') as f:
for prediction in self.predictions:
f.write(json.dumps(prediction)+'\n')
63 changes: 60 additions & 3 deletions finetuning/lightning_modules/models/seq2seq_model_util.py
@@ -13,6 +13,7 @@
from transformers import CodeGenTokenizer, CodeGenForCausalLM, T5Tokenizer
from transformers import BartTokenizer, BartModel, BartForConditionalGeneration
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification
from transformers import LlamaTokenizer, LlamaForCausalLM

from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerFast

@@ -55,7 +56,7 @@ def get_model(model_name: str,
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name == "EleutherAI/gpt-j-6B":
elif model_name == "EleutherAI/gpt-j-6b":
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

@@ -64,6 +65,14 @@
gradient_checkpointing=gradient_ckpt, use_cache=not gradient_ckpt)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name in ["EleutherAI/gpt-neox-20b", "EleutherAI/pythia-1.4b-deduped", "EleutherAI/pythia-6.9b-deduped", "EleutherAI/pythia-12b-deduped", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

if not tokenizer_only:
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name in ["EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-2.7B"]:
tokenizer = GPT2Tokenizer.from_pretrained(model_name, additional_special_tokens=additional_special_tokens)
tokenizer.pad_token = tokenizer.eos_token
@@ -136,15 +145,63 @@ def get_model(model_name: str,

if not tokenizer_only:
model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2)

elif "llama" in model_name.lower() or "alpaca" in model_name.lower():
tokenizer = LlamaTokenizer.from_pretrained(model_name,
additional_special_tokens=additional_special_tokens)
tokenizer.pad_token = tokenizer.eos_token

if not tokenizer_only:
model = LlamaForCausalLM.from_pretrained(model_name,
pad_token_id=tokenizer.eos_token_id,
torch_dtype=torch.float16)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name == "bigcode/santacoder":
tokenizer = AutoTokenizer.from_pretrained(model_name,
additional_special_tokens=additional_special_tokens)
tokenizer.pad_token = tokenizer.eos_token

if not tokenizer_only:
model = AutoModelForCausalLM.from_pretrained(model_name,
pad_token_id=tokenizer.eos_token_id,
torch_dtype=torch.float32,
trust_remote_code=True,
)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name in ["bigcode/starcoder", "HuggingFaceH4/starchat-alpha"]:
tokenizer = AutoTokenizer.from_pretrained(model_name,
additional_special_tokens=additional_special_tokens)
tokenizer.pad_token = tokenizer.eos_token

if not tokenizer_only:
model = AutoModelForCausalLM.from_pretrained(model_name,
pad_token_id=tokenizer.eos_token_id,
torch_dtype=torch.float16,
trust_remote_code=True)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name == "replit/replit-code-v1-3b":
tokenizer = AutoTokenizer.from_pretrained(model_name,
additional_special_tokens=additional_special_tokens,
trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

if not tokenizer_only:
model = AutoModelForCausalLM.from_pretrained(model_name,
pad_token_id=tokenizer.eos_token_id,
torch_dtype=torch.float16,
trust_remote_code=True)
if len(additional_special_tokens) > 0:
model.resize_token_embeddings(len(tokenizer))
elif model_name.startswith("openai/"):
engine = model_name.split("/")[-1]

tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# to accommodate the length of openai models and the prompt
if engine in ["code-davinci-002"]:
if engine in ["code-davinci-002", "gpt-4"]:
model_length = 8001
elif engine in ["code-cushman-001", "code-cushman-002"]:
model_length = 1024