refactor: Do one prediction per input sequence, easier experimentation
carlosgjs authored Jan 25, 2024
1 parent 10294bc commit 3c7e0a0
Showing 5 changed files with 180 additions and 91 deletions.
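The key interface change is that Predictor.predict now returns a flat List[str] with exactly one generated string per input, instead of the previous List[List[str]] of candidate sequences per input, and the max_length parameter is replaced by max_new_tokens. A minimal sketch of how calling code changes (the model path and input code below are illustrative placeholders, not part of the commit):

# Sketch only; model path and input are placeholders.
from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import PROMPTS, PromptIds

pred = Predictor("meta-llama/Llama-2-7b-chat-hf")
inputs = ["iv = Variable(name='x')\ndv = Variable(name='y')"]

# Before this commit: List[List[str]] -- one list of candidate sequences per input.
# first_doc = pred.predict(PROMPTS[PromptIds.AUTORA_VARS_ZEROSHOT], inputs, max_length=800)[0][0]

# After this commit: List[str] -- one prediction per input sequence.
docs = pred.predict(PROMPTS[PromptIds.AUTORA_VARS_ZEROSHOT], inputs, max_new_tokens=100)
first_doc = docs[0]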
83 changes: 71 additions & 12 deletions notebooks/generate.ipynb
@@ -8,8 +8,19 @@
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"from autora.doc.runtime.predict_hf import Predictor\n",
"from autora.doc.runtime.prompts import PROMPTS, PromptIds"
"from autora.doc.runtime.predict_hf import Predictor, preprocess_code\n",
"from autora.doc.runtime.prompts import PROMPTS, PromptIds, PromptBuilder, SYS_GUIDES\n",
"from autora.doc.pipelines.main import evaluate_documentation\n",
"from autora.doc.pipelines.main import eval_prompt, load_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = \"meta-llama/Llama-2-7b-chat-hf\""
]
},
{
@@ -18,11 +29,16 @@
"metadata": {},
"outputs": [],
"source": [
"# model = \"../../models\" # if model has been previously downloaded via huggingface-cli\n",
"model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
"pred = Predictor(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test generation for the variable declararion only"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -33,7 +49,8 @@
"iv = Variable(name=\"x\", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))\n",
"dv = Variable(name=\"y\", type=ValueType.REAL)\n",
"variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])\n",
"\"\"\""
"\"\"\"\n",
"LABEL = \"The discovery problem is defined by a single independent variable $x \\in [0, 2 \\pi]$ and dependent variable $y$.\""
]
},
{
@@ -42,18 +59,46 @@
"metadata": {},
"outputs": [],
"source": [
"def test(promptid, code):\n",
"def test(promptid, code, label):\n",
" output = pred.predict(\n",
" PROMPTS[promptid],\n",
" [code],\n",
" do_sample=0,\n",
" max_length=800,\n",
" max_new_tokens=100,\n",
" temperature=0.05,\n",
" top_k=10,\n",
" num_ret_seq=1,\n",
" )[0]\n",
" for i, o in enumerate(output):\n",
" print(f\"{promptid}\\n******* Output {i} ********\\n{o}\\n*************\\n\")"
" )\n",
" bleu, meteor = evaluate_documentation(output, [label])\n",
" for i, o in enumerate(output[0]):\n",
" print(f\"{promptid}\\n******* Output {i} ********. bleu={bleu}, meteor={meteor}\\n{o}\\n*************\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Zero shot test\n",
"test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE, LABEL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One shot test\n",
"test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE, LABEL)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## One-shot generation for the complete code sample"
]
},
{
@@ -62,7 +107,13 @@
"metadata": {},
"outputs": [],
"source": [
"test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE)"
"data_file = \"../data/autora/data.jsonl\"\n",
"inputs, labels = load_data(data_file)\n",
"# preprocessing removes comments, import statements and empty lines\n",
"inputs = [preprocess_code(i) for i in inputs]\n",
"INSTR = \"Generate high-level, one or two paragraph documentation for the following experiment.\"\n",
"prompt = PromptBuilder(SYS_GUIDES, INSTR).add_example(f\"{inputs[0]}\", labels[0]).build()\n",
"print(prompt)"
]
},
{
@@ -71,8 +122,16 @@
"metadata": {},
"outputs": [],
"source": [
"test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE)"
"out, bleu, meteor = eval_prompt(data_file, pred, prompt, {\"max_new_tokens\": 800.0})\n",
"print(f\"bleu={bleu}, meteor={meteor}\\n{out[0][0]}\\n*************\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
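For context on the notebook cells above: load_data (added to src/autora/doc/pipelines/main.py below) reads a JSON Lines file in which each record has an "instruction" field holding the input code and an "output" field holding the reference documentation. A hypothetical record and call, with illustrative values:

# One line of a hypothetical data.jsonl record (field values are made up):
# {"instruction": "iv = Variable(name=\"x\")", "output": "The experiment has one independent variable x."}

from autora.doc.pipelines.main import load_data

inputs, labels = load_data("../data/autora/data.jsonl")  # path used in the notebook
assert len(inputs) == len(labels)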
88 changes: 50 additions & 38 deletions src/autora/doc/pipelines/main.py
@@ -1,7 +1,7 @@
import itertools
import logging
from timeit import default_timer as timer
from typing import List, Tuple
from typing import Dict, List, Tuple

import nltk
import torch
@@ -20,13 +20,13 @@
logger = logging.getLogger(__name__)


def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
def evaluate_documentation(predictions: List[str], references: List[str]) -> Tuple[float, float]:
nltk.download("wordnet")

# Tokenize references
tokenized_references = [ref.split() for ref in references]
# Currently there is only 1 prediction for 1 reference, need to avg in future
tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
tokenized_predictions = [pred.split() if pred else [] for pred in predictions]

# Calculate BLEU score with smoothing function
# SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
@@ -55,16 +55,13 @@ def eval(
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> List[List[str]]:
import jsonlines
) -> Tuple[List[str], float, float]:
import mlflow

mlflow.autolog()

param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
run = mlflow.active_run()
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}

prompt = PROMPTS[prompt_id]
if run is None:
run = mlflow.start_run()
with run:
@@ -75,36 +72,51 @@ def eval(
mlflow.log_param("prompt_id", prompt_id)
mlflow.log_param("model_path", model_path)
mlflow.log_param("data_file", data_file)
prompt = PROMPTS[prompt_id]
pred = Predictor(model_path)
return eval_prompt(data_file, pred, prompt, param_dict)


def load_data(data_file: str) -> Tuple[List[str], List[str]]:
import jsonlines

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [f"{item['instruction']}" for item in items]
labels = [item["output"] for item in items]
return inputs, labels


def eval_prompt(
data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float]
) -> Tuple[List[str], float, float]:
import mlflow

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [item["instruction"] for item in items]
labels = [item["output"] for item in items]

pred = Predictor(model_path)
timer_start = timer()
predictions = pred.predict(prompt, inputs, **param_dict)
timer_end = timer()
bleu, meteor = evaluate_documentation(predictions, labels)
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
for i in range(len(inputs)):
mlflow.log_text(labels[i], f"label_{i}.txt")
mlflow.log_text(inputs[i], f"input_{i}.py")
for j in range(len(predictions[i])):
mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
mlflow.log_text("bleu_score is ", str(bleu))
mlflow.log_text("meteor_score is ", str(meteor))

# flatten predictions for counting tokens
predictions_flat = list(itertools.chain.from_iterable(predictions))
tokens = pred.tokenize(predictions_flat)["input_ids"]
total_tokens = sum([len(token) for token in tokens])
mlflow.log_metric("total_tokens", total_tokens)
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
mlflow.log_metric("bleu_score", round(bleu, 5))
mlflow.log_metric("meteor_score", round(meteor, 5))
return predictions
inputs, labels = load_data(data_file)

timer_start = timer()
predictions = pred.predict(prompt, inputs, **param_dict)
timer_end = timer()
bleu, meteor = evaluate_documentation(predictions, labels)
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
for i in range(len(inputs)):
mlflow.log_text(labels[i], f"label_{i}.txt")
mlflow.log_text(inputs[i], f"input_{i}.py")
for j in range(len(predictions[i])):
mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
mlflow.log_text("bleu_score is ", str(bleu))
mlflow.log_text("meteor_score is ", str(meteor))

# flatten predictions for counting tokens
predictions_flat = list(itertools.chain.from_iterable(predictions))
tokens = pred.tokenize(predictions_flat)["input_ids"]
total_tokens = sum([len(token) for token in tokens])
mlflow.log_metric("total_tokens", total_tokens)
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
mlflow.log_metric("bleu_score", round(bleu, 5))
mlflow.log_metric("meteor_score", round(meteor, 5))
return predictions, bleu, meteor


@app.command()
@@ -126,7 +138,7 @@ def generate(
prompt = PROMPTS[prompt_id]
pred = Predictor(model_path)
# grab first result since we only passed one input
predictions = pred.predict(prompt, [input], **param_dict)[0]
predictions = pred.predict(prompt, [input], **param_dict)
assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
logger.info(f"Writing output to {output}")
with open(output, "w") as f:
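Because the refactor exposes load_data and eval_prompt separately, a custom prompt can be scored without going through the CLI, which is the "easier experimentation" part of this commit. A sketch mirroring the notebook above (paths, model name, and instruction text are taken from the notebook; treat this as an example, not a definitive recipe):

from autora.doc.pipelines.main import eval_prompt, load_data
from autora.doc.runtime.predict_hf import Predictor, preprocess_code
from autora.doc.runtime.prompts import SYS_GUIDES, PromptBuilder

data_file = "../data/autora/data.jsonl"
inputs, labels = load_data(data_file)
inputs = [preprocess_code(i) for i in inputs]  # strip imports, comments, blank lines

instr = "Generate high-level, one or two paragraph documentation for the following experiment."
prompt = PromptBuilder(SYS_GUIDES, instr).add_example(inputs[0], labels[0]).build()

pred = Predictor("meta-llama/Llama-2-7b-chat-hf")
predictions, bleu, meteor = eval_prompt(data_file, pred, prompt, {"max_new_tokens": 800.0})
print(f"bleu={bleu}, meteor={meteor}")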
30 changes: 20 additions & 10 deletions src/autora/doc/runtime/predict_hf.py
@@ -1,15 +1,25 @@
import logging
from typing import Dict, List
from typing import Dict, Iterable, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE
from autora.doc.runtime.prompts import CODE_PLACEHOLDER, LLAMA2_INST_CLOSE

logger = logging.getLogger(__name__)


def preprocess_code(code: str) -> str:
lines: Iterable[str] = code.splitlines()
skip_starts = {"import", "from", "#"}
lines = filter(
lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
lines,
)
return "\n".join(lines)


class Predictor:
def __init__(self, model_path: str):
config = self.get_config()
@@ -35,16 +45,18 @@ def predict(
temperature: float = 0.01,
top_p: float = 0.95,
top_k: float = 1,
max_length: float = 2048,
max_new_tokens: float = 2048,
num_ret_seq: float = 1,
) -> List[List[str]]:
) -> List[str]:
# convert to bool in case it came in as a generate float param from the CLI
do_sample = bool(do_sample)
logger.info(
f"Generating {len(inputs)} predictions. do_sample: {do_sample}, temperature: {temperature}, top_p: {top_p},"
f" top_k: {top_k}, max_length: {max_length}"
f" top_k: {top_k}, max_new_tokens: {max_new_tokens}"
)
prompts = [prompt_template.format(code=input) for input in inputs]
prompts = [
prompt_template.replace(CODE_PLACEHOLDER, preprocess_code(input).strip("\n")) for input in inputs
]
sequences = self.pipeline(
prompts,
do_sample=do_sample,
@@ -53,12 +65,10 @@ def predict(
top_k=int(top_k),
num_return_sequences=int(num_ret_seq),
eos_token_id=self.tokenizer.eos_token_id,
max_length=int(max_length),
max_new_tokens=int(max_new_tokens),
)

results = [
[Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences
]
results = [Predictor.trim_prompt(seq["generated_text"]) for sequence in sequences for seq in sequence]
logger.info(f"Generated {len(results)} results")
return results

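The new preprocess_code helper drops import lines, comment lines, and blank lines before the code is substituted for CODE_PLACEHOLDER in the prompt. A small illustration of the expected behavior:

from autora.doc.runtime.predict_hf import preprocess_code

raw = (
    "import numpy as np\n"
    "# define the variables\n"
    "\n"
    "iv = Variable(name='x')\n"
    "dv = Variable(name='y')\n"
)
print(preprocess_code(raw))
# Expected output:
# iv = Variable(name='x')
# dv = Variable(name='y')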