Skip to content

Commit

Permalink
Surface inference parameters to the CLI and jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Dec 8, 2023
1 parent aa1b494 commit e3c004a
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 78 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# AutoDoc

[![ssec](https://img.shields.io/badge/SSEC-Project-purple?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAAOCAQAAABedl5ZAAAACXBIWXMAAAHKAAABygHMtnUxAAAAGXRFWHRTb2Z0d2FyZQB3d3cuaW5rc2NhcGUub3Jnm+48GgAAAMNJREFUGBltwcEqwwEcAOAfc1F2sNsOTqSlNUopSv5jW1YzHHYY/6YtLa1Jy4mbl3Bz8QIeyKM4fMaUxr4vZnEpjWnmLMSYCysxTcddhF25+EvJia5hhCudULAePyRalvUteXIfBgYxJufRuaKuprKsbDjVUrUj40FNQ11PTzEmrCmrevPhRcVQai8m1PRVvOPZgX2JttWYsGhD3atbHWcyUqX4oqDtJkJiJHUYv+R1JbaNHJmP/+Q1HLu2GbNoSm3Ft0+Y1YMdPSTSwQAAAABJRU5ErkJggg==&style=plastic)](https://escience.washington.edu/software-engineering/ssec/)

[![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/)

[![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/)
Expand Down
2 changes: 1 addition & 1 deletion azureml/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ dependencies:
- xformers
- scipy
# This works, while installing from pytorch and cuda from conda does not
- torch==2.0.1
- torch==2.1.0
16 changes: 12 additions & 4 deletions azureml/eval.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
python -m autora.doc.pipelines.main eval
${{inputs.data_dir}}/data.jsonl
${{inputs.model_dir}}/llama-2-7b-chat-hf
SYS_1
INSTR_SWEETP_1
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
--sys-id ${{inputs.sys_id}}
--instruc-id ${{inputs.instruc_id}}
--param temperature=${{inputs.temperature}}
--param top_k=${{inputs.top_k}}
--param top_p=${{inputs.top_p}}
code: ../src
inputs:
data_dir:
Expand All @@ -13,6 +16,11 @@ inputs:
model_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/base_models
temperature: 0.7
top_p: 0.95
top_k: 40
sys_id: SYS_1
instruc_id: INSTR_SWEETP_1
# using a curated environment doesn't work because we need additional packages
environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
Expand All @@ -26,5 +34,5 @@ environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11
conda_file: conda.yml
display_name: autodoc_prediction
compute: azureml:v100cluster
experiment_name: autodoc_prediction
experiment_name: evaluation
description: |
12 changes: 11 additions & 1 deletion azureml/generate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,26 @@ command: >
python -m autora.doc.pipelines.main generate
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
--output ./outputs/output.txt
--sys-id ${{inputs.sys_id}}
--instruc-id ${{inputs.instruc_id}}
--param temperature=${{inputs.temperature}}
--param top_k=${{inputs.top_k}}
--param top_p=${{inputs.top_p}}
autora/doc/pipelines/main.py
code: ../src
inputs:
model_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/base_models
temperature: 0.7
top_p: 0.95
top_k: 40
sys_id: SYS_1
instruc_id: INSTR_SWEETP_1
environment:
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
conda_file: conda.yml
display_name: autodoc_prediction
compute: azureml:v100cluster
experiment_name: autodoc_prediction
experiment_name: prediction
description: |
59 changes: 10 additions & 49 deletions notebooks/generate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,7 @@
"metadata": {},
"outputs": [],
"source": [
"# The following prompt uses an example (code, doc) to specify the desired behavior\n",
"EX_CODE=\"\"\"\n",
"from sweetpea import *\n",
"\n",
"color = Factor('color', ['red', 'green', 'blue', 'yellow'])\n",
"word = Factor('word', ['red', 'green', 'blue', 'yellow'])\n",
"\n",
"def is_congruent(word, color):\n",
" return (word == color)\n",
"\n",
"def is_not_congruent(word, color):\n",
" return not is_congruent(word, color)\n",
"\n",
"congruent = DerivedLevel('congruent', WithinTrial(is_congruent, [word, color]))\n",
"incongruent = DerivedLevel('incongruent', WithinTrial(is_not_congruent, [word, color]))\n",
"\n",
"congruency = Factor('congruency', [congruent, incongruent])\n",
"\n",
"constraints = [MinimumTrials(48)]\n",
"design = [word, color, congruency]\n",
"crossing = [word, congruency]\n",
"\n",
"block = CrossBlock(design, crossing, constraints)\n",
"\n",
"experiment = synthesize_trials(block, 1)\n",
"\n",
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
"\"\"\"\n",
"\n",
"EX_DOC=\"\"\"There are two regular factors: color and word. The color factor consists of four levels: \"red\", \"green\", \"blue\", and \"yellow\". \n",
"The word factor also consists of the four levels: \"red\", \"green\", \"blue\", and \"yellow\". There is another derived factor referred to as congruency. \n",
"The congruency factor depends on the regular factors word and color and has two levels: \"congruent\" and \"incongruent\".\n",
"A trial is considered \"congruent\" if the word matches the color, otherwise, it is considered \"incongruent\". We counterbalanced the word factor with the congruency factor. \n",
"All experiment sequences contained at least 48 trials.\"\"\"\n",
"\n",
"TEST_CODE=\"\"\"\n",
"TEST_CODE = \"\"\"\n",
"from sweetpea import *\n",
"from sweetpea.primitives import *\n",
"\n",
Expand Down Expand Up @@ -115,17 +80,6 @@
"experiment = synthesize_trials(block, 1)\n",
"\n",
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
"\"\"\"\n",
"\n",
"PROMPT=f\"\"\"Consider the following experiment code:\n",
"---\n",
"{EX_CODE}\n",
"---\n",
"Here's a a good English description:\n",
"---\n",
"{EX_DOC}\n",
"---\n",
"Using the same style, please generate a high-level one paragraph description for the following experiment code:\n",
"\"\"\""
]
},
Expand All @@ -135,8 +89,15 @@
"metadata": {},
"outputs": [],
"source": [
"output = pred.predict(SYS[SystemPrompts.SYS_1], PROMPT, [TEST_CODE], temperature=0.05, top_k=10, num_ret_seq=3)[0]\n",
"for i,o in enumerate(output):\n",
"output = pred.predict(\n",
" SYS[SystemPrompts.SYS_1],\n",
" INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n",
" [TEST_CODE],\n",
" temperature=0.05,\n",
" top_k=10,\n",
" num_ret_seq=3,\n",
")[0]\n",
"for i, o in enumerate(output):\n",
" print(f\"******** Output {i} ********\\n{o}*************\\n\")"
]
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
"typer",
"scipy",
# This works, while installing from pytorch and cuda from conda does not",
"torch==2.0.1",
"torch==2.1.0",
"transformers>=4.35.2",
]

Expand Down
49 changes: 37 additions & 12 deletions src/autora/doc/pipelines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@
logger = logging.getLogger(__name__)


@app.command()
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
@app.command(help="Evaluate model on a data file")
def eval(
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
instruc_id: InstructionPrompts = typer.Option(
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> List[List[str]]:
import jsonlines
import mlflow

mlflow.autolog()

param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
run = mlflow.active_run()

sys_prompt = SYS[sys_id]
Expand All @@ -33,6 +44,7 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins
logger.info(f"Active run_id: {run.info.run_id}")
logger.info(f"running predict with {data_file}")
logger.info(f"model path: {model_path}")
mlflow.log_params(param_dict)

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
Expand All @@ -41,16 +53,19 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins

pred = Predictor(model_path)
timer_start = timer()
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
timer_end = timer()
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
for i in range(len(inputs)):
mlflow.log_text(labels[i], f"label_{i}.txt")
mlflow.log_text(inputs[i], f"input_{i}.py")
mlflow.log_text(predictions[i], f"prediction_{i}.txt")
for j in range(len(predictions[i])):
mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")

tokens = pred.tokenize(predictions)["input_ids"]
# flatten predictions for counting tokens
predictions_flat = [pred for pred_list in predictions for pred in pred_list]
tokens = pred.tokenize(predictions_flat)["input_ids"]
total_tokens = sum([len(token) for token in tokens])
mlflow.log_metric("total_tokens", total_tokens)
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
Expand All @@ -59,18 +74,28 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins

@app.command()
def generate(
python_file: str,
model_path: str = "meta-llama/llama-2-7b-chat-hf",
output: str = "output.txt",
sys_id: SystemPrompts = SystemPrompts.SYS_1,
instruc_id: InstructionPrompts = InstructionPrompts.INSTR_SWEETP_1,
python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
output: str = typer.Option("output.txt", help="Output file"),
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
instruc_id: InstructionPrompts = typer.Option(
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> None:
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
"""
Generate documentation from python file
"""
with open(python_file, "r") as f:
inputs = [f.read()]
input = f.read()
sys_prompt = SYS[sys_id]
instr_prompt = INSTR[instruc_id]
pred = Predictor(model_path)
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
# grab first result since we only passed one input
predictions = pred.predict(sys_prompt, instr_prompt, [input], **param_dict)[0]
assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
logger.info(f"Writing output to {output}")
with open(output, "w") as f:
Expand Down
27 changes: 21 additions & 6 deletions src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,36 @@ def __init__(self, model_path: str):
tokenizer=self.tokenizer,
)

def predict(self, sys: str, instr: str, inputs: List[str], temperature=0.6, top_p=0.95, top_k=40, max_length=2048, num_ret_seq=1) -> List[List[str]]:
logger.info(f"Generating {len(inputs)} predictions")
def predict(
self,
sys: str,
instr: str,
inputs: List[str],
temperature: float = 0.6,
top_p: float = 0.95,
top_k: float = 40,
max_length: float = 2048,
num_ret_seq: float = 1,
) -> List[List[str]]:
logger.info(
f"Generating {len(inputs)} predictions. Temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, "
f"max_length: {max_length}"
)
prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
sequences = self.pipeline(
prompts,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_return_sequences=num_ret_seq,
top_k=int(top_k),
num_return_sequences=int(num_ret_seq),
eos_token_id=self.tokenizer.eos_token_id,
max_length=max_length,
max_length=int(max_length),
)

results = [[Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences]
results = [
[Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences
]
logger.info(f"Generated {len(results)} results")
return results

Expand Down
53 changes: 52 additions & 1 deletion src/autora/doc/runtime/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,65 @@
INSTR_SWEETP_1 = """Please generate high-level two paragraph documentation for the following experiment. The first
paragraph should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""

# The following prompt uses an example (code, doc) to specify the desired behavior
EX_CODE = """
from sweetpea import *
color = Factor('color', ['red', 'green', 'blue', 'yellow'])
word = Factor('word', ['red', 'green', 'blue', 'yellow'])
def is_congruent(word, color):
return (word == color)
def is_not_congruent(word, color):
return not is_congruent(word, color)
congruent = DerivedLevel('congruent', WithinTrial(is_congruent, [word, color]))
incongruent = DerivedLevel('incongruent', WithinTrial(is_not_congruent, [word, color]))
congruency = Factor('congruency', [congruent, incongruent])
constraints = [MinimumTrials(48)]
design = [word, color, congruency]
crossing = [word, congruency]
block = CrossBlock(design, crossing, constraints)
experiment = synthesize_trials(block, 1)
save_experiments_csv(block, experiment, 'code_1_sequences/seq')
"""

EX_DOC = """There are two regular factors: color and word. The color factor consists of four levels: "red", "green",
"blue", and "yellow". The word factor also consists of the four levels: "red", "green", "blue", and "yellow".
There is another derived factor referred to as congruency. The congruency factor depends on the regular factors word
and color and has two levels: "congruent" and "incongruent". A trial is considered "congruent" if the word matches
the color, otherwise, it is considered "incongruent". We counterbalanced the word factor with the congruency factor.
All experiment sequences contained at least 48 trials."""

INSTR_SWEETP_EXAMPLE = f"""Consider the following experiment code:
---
{EX_CODE}
---
Here's a a good English description:
---
{EX_DOC}
---
Using the same style, please generate a high-level one paragraph description for the following experiment code:
"""


class SystemPrompts(str, Enum):
SYS_1 = "SYS_1"


class InstructionPrompts(str, Enum):
INSTR_SWEETP_1 = "INSTR_SWEETP_1"
INSTR_SWEETP_EXAMPLE = "INSTR_SWEETP_EXAMPLE"


SYS = {SystemPrompts.SYS_1: SYS_1}
INSTR = {InstructionPrompts.INSTR_SWEETP_1: INSTR_SWEETP_1}
INSTR = {
InstructionPrompts.INSTR_SWEETP_1: INSTR_SWEETP_1,
InstructionPrompts.INSTR_SWEETP_EXAMPLE: INSTR_SWEETP_EXAMPLE,
}
8 changes: 5 additions & 3 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@

def test_predict() -> None:
data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
outputs = eval(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
outputs = eval(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1, [])
assert len(outputs) == 3, "Expected 3 outputs"
for output in outputs:
assert len(output) > 0, "Expected non-empty output"
assert len(output[0]) > 0, "Expected non-empty output"


def test_generate() -> None:
python_file = __file__
output = Path("output.txt")
output.unlink(missing_ok=True)
generate(python_file, TEST_HF_MODEL, str(output), SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
generate(
python_file, TEST_HF_MODEL, str(output), SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1, []
)
assert output.exists(), f"Expected output file {output} to exist"
with open(str(output), "r") as f:
assert len(f.read()) > 0, f"Expected non-empty output file {output}"

0 comments on commit e3c004a

Please sign in to comment.