Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add arguments for model parameters #10

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
lsetiawan marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# AutoDoc

[![ssec](https://img.shields.io/badge/SSEC-Project-purple?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAAOCAQAAABedl5ZAAAACXBIWXMAAAHKAAABygHMtnUxAAAAGXRFWHRTb2Z0d2FyZQB3d3cuaW5rc2NhcGUub3Jnm+48GgAAAMNJREFUGBltwcEqwwEcAOAfc1F2sNsOTqSlNUopSv5jW1YzHHYY/6YtLa1Jy4mbl3Bz8QIeyKM4fMaUxr4vZnEpjWnmLMSYCysxTcddhF25+EvJia5hhCudULAePyRalvUteXIfBgYxJufRuaKuprKsbDjVUrUj40FNQ11PTzEmrCmrevPhRcVQai8m1PRVvOPZgX2JttWYsGhD3atbHWcyUqX4oqDtJkJiJHUYv+R1JbaNHJmP/+Q1HLu2GbNoSm3Ft0+Y1YMdPSTSwQAAAABJRU5ErkJggg==&style=plastic)](https://escience.washington.edu/software-engineering/ssec/)

[![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/)

[![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/)
<!-- [![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/) -->


[![GitHub Workflow Status](https://github.com/autoresearch/autodoc/actions/workflows/smoke-test.yml/badge.svg)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
[![codecov](https://codecov.io/gh/AutoResearch/autodoc/branch/main/graph/badge.svg)](https://codecov.io/gh/AutoResearch/autodoc)
[![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/)
<!-- [![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/) -->

This project was automatically generated using the LINCC-Frameworks
[python-project-template](https://github.com/lincc-frameworks/python-project-template). For more information about the project template see the
Expand Down
18 changes: 13 additions & 5 deletions azureml/eval.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
python -m autora.doc.pipelines.main eval
${{inputs.data_dir}}/data.jsonl
${{inputs.model_dir}}/llama-2-7b-chat-hf
SYS_1
INSTR_SWEETP_1
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
--sys-id ${{inputs.sys_id}}
--instruc-id ${{inputs.instruc_id}}
--param temperature=${{inputs.temperature}}
--param top_k=${{inputs.top_k}}
--param top_p=${{inputs.top_p}}
code: ../src
inputs:
data_dir:
Expand All @@ -13,6 +16,11 @@ inputs:
model_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/base_models
temperature: 0.7
top_p: 0.95
top_k: 40
sys_id: SYS_1
instruc_id: INSTR_SWEETP_1
# using a curated environment doesn't work because we need additional packages
environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
Expand All @@ -25,6 +33,6 @@ environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11
# image: nvcr.io/nvidia/pytorch:23.10-py3
conda_file: conda.yml
display_name: autodoc_prediction
compute: azureml:v100cluster
experiment_name: autodoc_prediction
compute: azureml:t4cluster
experiment_name: evaluation
description: |
14 changes: 12 additions & 2 deletions azureml/generate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,26 @@ command: >
python -m autora.doc.pipelines.main generate
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
--output ./outputs/output.txt
--sys-id ${{inputs.sys_id}}
--instruc-id ${{inputs.instruc_id}}
--param temperature=${{inputs.temperature}}
--param top_k=${{inputs.top_k}}
--param top_p=${{inputs.top_p}}
autora/doc/pipelines/main.py
code: ../src
inputs:
model_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/base_models
temperature: 0.7
top_p: 0.95
top_k: 40
sys_id: SYS_1
instruc_id: INSTR_SWEETP_1
environment:
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
conda_file: conda.yml
display_name: autodoc_prediction
compute: azureml:v100cluster
experiment_name: autodoc_prediction
compute: azureml:t4cluster
experiment_name: prediction
description: |
126 changes: 126 additions & 0 deletions notebooks/generate.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"from autora.doc.runtime.predict_hf import Predictor\n",
"from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model = \"../../models\" # if model has been previously downloaded via huggingface-cli\n",
"model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
"pred = Predictor(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"TEST_CODE = \"\"\"\n",
"from sweetpea import *\n",
"from sweetpea.primitives import *\n",
"\n",
"number_list = [125, 132, 139, 146, 160, 167, 174, 181]\n",
"letter_list = ['b', 'd', 'f', 'h', 's', 'u', 'w', 'y']\n",
"\n",
"number = Factor(\"number\", number_list)\n",
"letter = Factor(\"letter\", letter_list)\n",
"task = Factor(\"task\", [\"number task\", \"letter task\", \"free choice task\"])\n",
"\n",
"\n",
"def is_forced_trial_switch(task):\n",
" return (task[-1] == \"number task\" and task[0] == \"letter task\") or \\\n",
" (task[-1] == \"letter task\" and task[0] == \"number task\")\n",
"\n",
"\n",
"def is_forced_trial_repeat(task):\n",
" return (task[-1] == \"number task\" and task[0] == \"number task\") or \\\n",
" (task[-1] == \"letter task\" and task[0] == \"letter task\")\n",
"\n",
"\n",
"def is_free_trial_transition(task):\n",
" return task[-1] != \"free choice task\" and task[0] == \"free choice task\"\n",
"\n",
"\n",
"def is_free_trial_repeat(task):\n",
" return task[-1] == \"free choice task\" and task[0] == \"free choice task\"\n",
"\n",
"\n",
"def is_not_relevant_transition(task):\n",
" return not (is_forced_trial_repeat(task) or is_forced_trial_switch(task) or is_free_trial_repeat(\n",
" task) or is_free_trial_transition(task))\n",
"\n",
"\n",
"transit = Factor(\"task transition\", [\n",
" DerivedLevel(\"forced switch\", transition(is_forced_trial_switch, [task]), 3),\n",
" DerivedLevel(\"forced repeat\", transition(is_forced_trial_repeat, [task])),\n",
" DerivedLevel(\"free transition\", transition(is_free_trial_transition, [task]), 4),\n",
" DerivedLevel(\"free repeat\", transition(is_free_trial_repeat, [task]), 4),\n",
" DerivedLevel(\"forced first\", transition(is_not_relevant_transition, [task]), 4)\n",
"])\n",
"design = [letter, number, task, transit]\n",
"crossing = [[letter], [number], [transit]]\n",
"constraints = [MinimumTrials(256)]\n",
"\n",
"block = MultiCrossBlock(design, crossing, constraints)\n",
"\n",
"experiment = synthesize_trials(block, 1)\n",
"\n",
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = pred.predict(\n",
" SYS[SystemPrompts.SYS_1],\n",
" INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n",
" [TEST_CODE],\n",
" temperature=0.05,\n",
" top_k=10,\n",
" num_ret_seq=3,\n",
")[0]\n",
"for i, o in enumerate(output):\n",
" print(f\"******** Output {i} ********\\n{o}*************\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autodoc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
49 changes: 37 additions & 12 deletions src/autora/doc/pipelines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@
logger = logging.getLogger(__name__)


@app.command()
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
@app.command(help="Evaluate model on a data file")
def eval(
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
lsetiawan marked this conversation as resolved.
Show resolved Hide resolved
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
instruc_id: InstructionPrompts = typer.Option(
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> List[List[str]]:
import jsonlines
import mlflow

mlflow.autolog()

param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
run = mlflow.active_run()

sys_prompt = SYS[sys_id]
Expand All @@ -33,6 +44,7 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins
logger.info(f"Active run_id: {run.info.run_id}")
logger.info(f"running predict with {data_file}")
logger.info(f"model path: {model_path}")
mlflow.log_params(param_dict)

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
Expand All @@ -41,16 +53,19 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins

pred = Predictor(model_path)
timer_start = timer()
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
timer_end = timer()
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
for i in range(len(inputs)):
mlflow.log_text(labels[i], f"label_{i}.txt")
mlflow.log_text(inputs[i], f"input_{i}.py")
mlflow.log_text(predictions[i], f"prediction_{i}.txt")
for j in range(len(predictions[i])):
mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")

tokens = pred.tokenize(predictions)["input_ids"]
# flatten predictions for counting tokens
predictions_flat = [pred for pred_list in predictions for pred in pred_list]
lsetiawan marked this conversation as resolved.
Show resolved Hide resolved
tokens = pred.tokenize(predictions_flat)["input_ids"]
total_tokens = sum([len(token) for token in tokens])
mlflow.log_metric("total_tokens", total_tokens)
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
Expand All @@ -59,18 +74,28 @@ def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: Ins

@app.command()
def generate(
python_file: str,
model_path: str = "meta-llama/llama-2-7b-chat-hf",
output: str = "output.txt",
sys_id: SystemPrompts = SystemPrompts.SYS_1,
instruc_id: InstructionPrompts = InstructionPrompts.INSTR_SWEETP_1,
python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
output: str = typer.Option("output.txt", help="Output file"),
sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
instruc_id: InstructionPrompts = typer.Option(
InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> None:
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
"""
Generate documentation from python file
"""
with open(python_file, "r") as f:
inputs = [f.read()]
input = f.read()
sys_prompt = SYS[sys_id]
instr_prompt = INSTR[instruc_id]
pred = Predictor(model_path)
predictions = pred.predict(sys_prompt, instr_prompt, inputs)
# grab first result since we only passed one input
predictions = pred.predict(sys_prompt, instr_prompt, [input], **param_dict)[0]
assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
logger.info(f"Writing output to {output}")
with open(output, "w") as f:
Expand Down
32 changes: 23 additions & 9 deletions src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,36 @@ def __init__(self, model_path: str):
tokenizer=self.tokenizer,
)

def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
logger.info(f"Generating {len(inputs)} predictions")
def predict(
self,
sys: str,
instr: str,
inputs: List[str],
temperature: float = 0.6,
top_p: float = 0.95,
top_k: float = 40,
max_length: float = 2048,
num_ret_seq: float = 1,
) -> List[List[str]]:
logger.info(
f"Generating {len(inputs)} predictions. Temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, "
f"max_length: {max_length}"
)
prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
# TODO: Make these parameters configurable
sequences = self.pipeline(
prompts,
do_sample=True,
temperature=0.6,
top_p=0.95,
top_k=40,
num_return_sequences=1,
temperature=temperature,
top_p=top_p,
top_k=int(top_k),
num_return_sequences=int(num_ret_seq),
eos_token_id=self.tokenizer.eos_token_id,
max_length=2048,
max_length=int(max_length),
)

results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]
results = [
[Predictor.trim_prompt(seq["generated_text"]) for seq in sequence] for sequence in sequences
]
logger.info(f"Generated {len(results)} results")
return results

Expand Down
Loading