Commit a884b82

ref: Refactoring to allow for one-shot prompts

carlosgjs committed Jan 19, 2024
1 parent 398c393 commit a884b82

Showing 8 changed files with 136 additions and 124 deletions.
8 changes: 3 additions & 5 deletions azureml/eval.yml
@@ -3,8 +3,7 @@ command: >
   python -m autora.doc.pipelines.main eval
   ${{inputs.data_dir}}/data.jsonl
   --model-path ${{inputs.model_path}}
-  --sys-id ${{inputs.sys_id}}
-  --instruc-id ${{inputs.instruc_id}}
+  --prompt-id ${{inputs.prompt_id}}
   --param do_sample=${{inputs.do_sample}}
   --param temperature=${{inputs.temperature}}
   --param top_k=${{inputs.top_k}}
@@ -23,8 +22,7 @@ inputs:
   do_sample: 0
   top_p: 0.95
   top_k: 1
-  sys_id: SYS_1
-  instruc_id: INSTR_SWEETP_1
+  prompt_id: SWEETP_1
 # using a curated environment doesn't work because we need additional packages
 environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
   image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
@@ -37,6 +35,6 @@ environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
   # image: nvcr.io/nvidia/pytorch:23.10-py3
   conda_file: conda.yml
 display_name: autodoc_prediction
-compute: azureml:t4cluster
+compute: azureml:v100cluster
 experiment_name: evaluation
 description: |
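With the separate system and instruction flags collapsed into a single --prompt-id, a local run of the eval pipeline reduces to one prompt argument. A sketch of the invocation, mirroring the AzureML command above; the data path and parameter values are illustrative, not taken from the commit:

    python -m autora.doc.pipelines.main eval path/to/data.jsonl \
        --model-path meta-llama/Llama-2-7b-chat-hf \
        --prompt-id SWEETP_1 \
        --param do_sample=0 --param temperature=0.01 --param top_k=1 --param top_p=0.95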
8 changes: 3 additions & 5 deletions azureml/generate.yml
@@ -3,9 +3,8 @@ command: >
   python -m autora.doc.pipelines.main generate
   --model-path ${{inputs.model_path}}
   --output ./outputs/output.txt
-  --sys-id ${{inputs.sys_id}}
-  --instruc-id ${{inputs.instruc_id}}
   --param do_sample=${{inputs.do_sample}}
+  --prompt-id ${{inputs.prompt_id}}
   --param temperature=${{inputs.temperature}}
   --param top_k=${{inputs.top_k}}
   --param top_p=${{inputs.top_p}}
@@ -21,12 +20,11 @@ inputs:
   do_sample: 0
   top_p: 0.95
   top_k: 40
-  sys_id: SYS_1
-  instruc_id: INSTR_SWEETP_1
+  prompt_id: SWEETP_1
 environment:
   image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
   conda_file: conda.yml
 display_name: autodoc_prediction
-compute: azureml:t4cluster
+compute: azureml:v100cluster
 experiment_name: prediction
 description: |
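The generate job follows the same pattern; per the updated pipeline code below, the Python file to document is a positional argument. A hypothetical run pairing the fixture added in this commit with the new one-shot prompt (prompt choice and parameter values are illustrative):

    python -m autora.doc.pipelines.main generate data/autora/code1_sm.txt \
        --model-path meta-llama/Llama-2-7b-chat-hf \
        --output ./outputs/output.txt \
        --prompt-id AUTORA_VARS_ONESHOT \
        --param temperature=0.05 --param top_k=10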
34 changes: 34 additions & 0 deletions data/autora/code1_sm.txt
@@ -0,0 +1,34 @@
iv = Variable(name="x", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))
dv = Variable(name="y", type=ValueType.REAL)
variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])

conditions = random_pool(variables, num_samples=10, random_state=0)

experimentalist = on_state(random_pool, output=["conditions"])

sin_experiment = equation_experiment(
sp.simplify("sin(x)"), variables.independent_variables, variables.dependent_variables[0]
)
sin_runner = sin_experiment.experiment_runner

experiment_runner = on_state(sin_runner, output=["experiment_data"])

theorist = estimator_on_state(BMSRegressor(epochs=100))

s = StandardState(
variables=variables, conditions=conditions, experiment_data=pd.DataFrame(columns=["x", "y"])
)

print("Pre-Defined State:")
print(f"Number of datapoints collected: {len(s['experiment_data'])}")
print(f"Derived models: {s['models']}")
print("\n")

for i in range(5):
s = experimentalist(s, num_samples=10, random_state=42)
s = experiment_runner(s, added_noise=1.0, random_state=42)
s = theorist(s)
print(f"\nCycle {i+1} Results:")
print(f"Number of datapoints collected: {len(s['experiment_data'])}")
print(f"Derived models: {s['models']}")
print("\n")
41 changes: 18 additions & 23 deletions notebooks/generate.ipynb
@@ -9,7 +9,7 @@
     "%load_ext autoreload\n",
     "%autoreload 2\n",
     "from autora.doc.runtime.predict_hf import Predictor\n",
-    "from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts"
+    "from autora.doc.runtime.prompts import PROMPTS, PromptIds"
    ]
   },
   {
@@ -29,12 +29,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "AUTORA_VARS = (\n",
-    "    \"Generate a brief description of the dependent and independent variables used in the experiment based on the following code. \"\n",
-    "    \"Use only one line per variable and do not include code or code-like syntax in your description. Use LaTeX to format mathematical expressions. \"\n",
-    ")\n",
-    "\n",
-    "VAR_CODE = \"\"\"\n",
+    "TEST_VAR_CODE = \"\"\"\n",
     "iv = Variable(name=\"x\", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))\n",
     "dv = Variable(name=\"y\", type=ValueType.REAL)\n",
     "variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])\n",
@@ -47,20 +42,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "output = pred.predict(\n",
-    "    SYS[SystemPrompts.SYS_1],\n",
-    "    # INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n",
-    "    AUTORA_VARS,\n",
-    "    # [TEST_CODE],\n",
-    "    [VAR_CODE],\n",
-    "    do_sample=0,\n",
-    "    max_length=500,\n",
-    "    temperature=0.05,\n",
-    "    top_k=10,\n",
-    "    num_ret_seq=1,\n",
-    ")[0]\n",
-    "for i, o in enumerate(output):\n",
-    "    print(f\"******** Output {i} ********\\n{o}*************\\n\")"
+    "def test(promptid, code):\n",
+    "    output = pred.predict(\n",
+    "        PROMPTS[promptid],\n",
+    "        [code],\n",
+    "        do_sample=0,\n",
+    "        max_length=800,\n",
+    "        temperature=0.05,\n",
+    "        top_k=10,\n",
+    "        num_ret_seq=1,\n",
+    "    )[0]\n",
+    "    for i, o in enumerate(output):\n",
+    "        print(f\"{promptid}\\n******* Output {i} ********\\n{o}\\n*************\\n\")"
    ]
   },
{
Expand All @@ -69,15 +62,17 @@
"metadata": {},
"outputs": [],
"source": [
"AUTORA_VARS"
"test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE)"
]
}
],
"metadata": {
22 changes: 7 additions & 15 deletions src/autora/doc/pipelines/main.py
@@ -7,7 +7,7 @@
 import typer
 
 from autora.doc.runtime.predict_hf import Predictor
-from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts
+from autora.doc.runtime.prompts import PROMPTS, PromptIds
 
 app = typer.Typer()
 logging.basicConfig(
@@ -21,10 +21,7 @@
 def eval(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
     model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
-    sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
-    instruc_id: InstructionPrompts = typer.Option(
-        InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
-    ),
+    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
@@ -37,8 +34,7 @@ def eval(
     param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
     run = mlflow.active_run()
 
-    sys_prompt = SYS[sys_id]
-    instr_prompt = INSTR[instruc_id]
+    prompt = PROMPTS[prompt_id]
     if run is None:
         run = mlflow.start_run()
     with run:

pred = Predictor(model_path)
timer_start = timer()
predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
predictions = pred.predict(prompt, inputs, **param_dict)
timer_end = timer()
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
@@ -78,10 +74,7 @@ def generate(
     python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
     model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
     output: str = typer.Option("output.txt", help="Output file"),
-    sys_id: SystemPrompts = typer.Option(SystemPrompts.SYS_1, help="System prompt ID"),
-    instruc_id: InstructionPrompts = typer.Option(
-        InstructionPrompts.INSTR_SWEETP_1, help="Instruction prompt ID"
-    ),
+    prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
@@ -92,11 +85,10 @@ def generate(
     """
     with open(python_file, "r") as f:
         input = f.read()
-    sys_prompt = SYS[sys_id]
-    instr_prompt = INSTR[instruc_id]
+    prompt = PROMPTS[prompt_id]
     pred = Predictor(model_path)
     # grab first result since we only passed one input
-    predictions = pred.predict(sys_prompt, instr_prompt, [input], **param_dict)[0]
+    predictions = pred.predict(prompt, [input], **param_dict)[0]
     assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
     logger.info(f"Writing output to {output}")
     with open(output, "w") as f:
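The new prompts module itself is not among the hunks shown here. Based only on the symbols used above (PROMPTS, PromptIds, LLAMA2_INST_CLOSE, and the {code} placeholder consumed by predict_hf.py below), a minimal sketch of its shape might look as follows; all prompt texts, and any enum values beyond those referenced in this commit, are guesses:

    # Hypothetical sketch of src/autora/doc/runtime/prompts.py (not shown in this diff).
    from enum import Enum

    LLAMA2_INST_CLOSE = "[/INST]"  # assumed value of the close marker


    class PromptIds(str, Enum):
        SWEETP_1 = "SWEETP_1"
        AUTORA_VARS_ZEROSHOT = "AUTORA_VARS_ZEROSHOT"
        AUTORA_VARS_ONESHOT = "AUTORA_VARS_ONESHOT"


    # Each entry is a fully assembled Llama-2 chat prompt with a single {code}
    # placeholder, so Predictor.predict() only needs prompt_template.format(code=...).
    # One-shot variants bake an example exchange into the template, which is why
    # they contain more than one [/INST] marker.
    PROMPTS = {
        PromptIds.AUTORA_VARS_ZEROSHOT: (
            "<s>[INST] <<SYS>>\nYou document AutoRA experiments.\n<</SYS>>\n\n"
            "Describe the variables in the following code:\n{code} [/INST]"
        ),
        PromptIds.AUTORA_VARS_ONESHOT: (
            "<s>[INST] <<SYS>>\nYou document AutoRA experiments.\n<</SYS>>\n\n"
            "Describe the variables in the following code:\n"
            "iv = Variable(name=\"x\") [/INST] "
            "The independent variable is $x$. </s>"
            "<s>[INST] Describe the variables in the following code:\n{code} [/INST]"
        ),
    }

This design keeps the runtime ignorant of prompt structure: one template string per ID, consumed with a single format call.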
9 changes: 4 additions & 5 deletions src/autora/doc/runtime/predict_hf.py
@@ -5,7 +5,7 @@
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE, TEMP_LLAMA2
+from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +29,7 @@ def __init__(self, model_path: str):
 
     def predict(
         self,
-        sys: str,
-        instr: str,
+        prompt_template: str,
         inputs: List[str],
         do_sample: float = 0.0,
         temperature: float = 0.01,
@@ -45,7 +44,7 @@ def predict(
             f"Generating {len(inputs)} predictions. do_sample: {do_sample}, temperature: {temperature}, top_p: {top_p},"
             f" top_k: {top_k}, max_length: {max_length}"
         )
-        prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
+        prompts = [prompt_template.format(code=input) for input in inputs]
         sequences = self.pipeline(
             prompts,
             do_sample=do_sample,

@staticmethod
def trim_prompt(output: str) -> str:
marker = output.find(LLAMA2_INST_CLOSE)
marker = output.rfind(LLAMA2_INST_CLOSE)
if marker == -1:
logger.warning(f"Could not find end of prompt marker '{LLAMA2_INST_CLOSE}' in '{output}'")
return output
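The switch from find to rfind is the subtle companion to one-shot prompts: a one-shot template already contains an embedded example answer terminated by its own [/INST] marker, so trimming at the first occurrence would return the example answer plus the second prompt instead of the model's completion. A minimal illustration, with the marker value assumed and the prompt text invented for the example:

    # Why rfind replaces find once prompts can be one-shot.
    LLAMA2_INST_CLOSE = "[/INST]"  # assumed value
    out = (
        "[INST] example code [/INST] example doc </s>"
        "[INST] real code [/INST] generated doc"
    )
    # find() stops at the close of the embedded one-shot example and would
    # leak the example doc and the second prompt into the output...
    assert out[out.find(LLAMA2_INST_CLOSE) + len(LLAMA2_INST_CLOSE):].startswith(" example doc")
    # ...while rfind() trims at the close of the actual instruction.
    assert out[out.rfind(LLAMA2_INST_CLOSE) + len(LLAMA2_INST_CLOSE):] == " generated doc"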