Skip to content

Commit

Permalink
Add arguments for model parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Dec 8, 2023
1 parent 1d80dfe commit aa1b494
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 8 deletions.
165 changes: 165 additions & 0 deletions notebooks/generate.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"from autora.doc.runtime.predict_hf import Predictor\n",
"from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model = \"../../models\" # if model has been previously downloaded via huggingface-cli\n",
"model = \"meta-llama/Llama-2-7b-chat-hf\"\n",
"pred = Predictor(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The following prompt uses an example (code, doc) to specify the desired behavior\n",
"EX_CODE=\"\"\"\n",
"from sweetpea import *\n",
"\n",
"color = Factor('color', ['red', 'green', 'blue', 'yellow'])\n",
"word = Factor('word', ['red', 'green', 'blue', 'yellow'])\n",
"\n",
"def is_congruent(word, color):\n",
" return (word == color)\n",
"\n",
"def is_not_congruent(word, color):\n",
" return not is_congruent(word, color)\n",
"\n",
"congruent = DerivedLevel('congruent', WithinTrial(is_congruent, [word, color]))\n",
"incongruent = DerivedLevel('incongruent', WithinTrial(is_not_congruent, [word, color]))\n",
"\n",
"congruency = Factor('congruency', [congruent, incongruent])\n",
"\n",
"constraints = [MinimumTrials(48)]\n",
"design = [word, color, congruency]\n",
"crossing = [word, congruency]\n",
"\n",
"block = CrossBlock(design, crossing, constraints)\n",
"\n",
"experiment = synthesize_trials(block, 1)\n",
"\n",
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
"\"\"\"\n",
"\n",
"EX_DOC=\"\"\"There are two regular factors: color and word. The color factor consists of four levels: \"red\", \"green\", \"blue\", and \"yellow\". \n",
"The word factor also consists of the four levels: \"red\", \"green\", \"blue\", and \"yellow\". There is another derived factor referred to as congruency. \n",
"The congruency factor depends on the regular factors word and color and has two levels: \"congruent\" and \"incongruent\".\n",
"A trial is considered \"congruent\" if the word matches the color, otherwise, it is considered \"incongruent\". We counterbalanced the word factor with the congruency factor. \n",
"All experiment sequences contained at least 48 trials.\"\"\"\n",
"\n",
"TEST_CODE=\"\"\"\n",
"from sweetpea import *\n",
"from sweetpea.primitives import *\n",
"\n",
"number_list = [125, 132, 139, 146, 160, 167, 174, 181]\n",
"letter_list = ['b', 'd', 'f', 'h', 's', 'u', 'w', 'y']\n",
"\n",
"number = Factor(\"number\", number_list)\n",
"letter = Factor(\"letter\", letter_list)\n",
"task = Factor(\"task\", [\"number task\", \"letter task\", \"free choice task\"])\n",
"\n",
"\n",
"def is_forced_trial_switch(task):\n",
" return (task[-1] == \"number task\" and task[0] == \"letter task\") or \\\n",
" (task[-1] == \"letter task\" and task[0] == \"number task\")\n",
"\n",
"\n",
"def is_forced_trial_repeat(task):\n",
" return (task[-1] == \"number task\" and task[0] == \"number task\") or \\\n",
" (task[-1] == \"letter task\" and task[0] == \"letter task\")\n",
"\n",
"\n",
"def is_free_trial_transition(task):\n",
" return task[-1] != \"free choice task\" and task[0] == \"free choice task\"\n",
"\n",
"\n",
"def is_free_trial_repeat(task):\n",
" return task[-1] == \"free choice task\" and task[0] == \"free choice task\"\n",
"\n",
"\n",
"def is_not_relevant_transition(task):\n",
" return not (is_forced_trial_repeat(task) or is_forced_trial_switch(task) or is_free_trial_repeat(\n",
" task) or is_free_trial_transition(task))\n",
"\n",
"\n",
"transit = Factor(\"task transition\", [\n",
" DerivedLevel(\"forced switch\", transition(is_forced_trial_switch, [task]), 3),\n",
" DerivedLevel(\"forced repeat\", transition(is_forced_trial_repeat, [task])),\n",
" DerivedLevel(\"free transition\", transition(is_free_trial_transition, [task]), 4),\n",
" DerivedLevel(\"free repeat\", transition(is_free_trial_repeat, [task]), 4),\n",
" DerivedLevel(\"forced first\", transition(is_not_relevant_transition, [task]), 4)\n",
"])\n",
"design = [letter, number, task, transit]\n",
"crossing = [[letter], [number], [transit]]\n",
"constraints = [MinimumTrials(256)]\n",
"\n",
"block = MultiCrossBlock(design, crossing, constraints)\n",
"\n",
"experiment = synthesize_trials(block, 1)\n",
"\n",
"save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n",
"\"\"\"\n",
"\n",
"PROMPT=f\"\"\"Consider the following experiment code:\n",
"---\n",
"{EX_CODE}\n",
"---\n",
"Here's a good English description:\n",
"---\n",
"{EX_DOC}\n",
"---\n",
"Using the same style, please generate a high-level one paragraph description for the following experiment code:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = pred.predict(SYS[SystemPrompts.SYS_1], PROMPT, [TEST_CODE], temperature=0.05, top_k=10, num_ret_seq=3)[0]\n",
"for i,o in enumerate(output):\n",
" print(f\"******** Output {i} ********\\n{o}*************\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autodoc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
15 changes: 7 additions & 8 deletions src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,21 @@ def __init__(self, model_path: str):
tokenizer=self.tokenizer,
)

def predict(
    self,
    sys: str,
    instr: str,
    inputs: List[str],
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 40,
    max_length: int = 2048,
    num_ret_seq: int = 1,
) -> List[List[str]]:
    """Generate text completions for each input using the LLaMA-2 chat template.

    Args:
        sys: System prompt text inserted into the LLaMA-2 prompt template.
        instr: Instruction text inserted into the template.
        inputs: Input snippets; one prompt is built per element.
        temperature: Sampling temperature; lower values are more deterministic.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Number of highest-probability tokens considered per step.
        max_length: Maximum total token length (prompt + generation)
            passed to the pipeline.
        num_ret_seq: Number of generated sequences returned per input.

    Returns:
        One entry per input; each entry is a list of ``num_ret_seq``
        generated strings with the echoed prompt trimmed off.
    """
    logger.info(f"Generating {len(inputs)} predictions")
    # `inp` (not `input`) to avoid shadowing the builtin.
    prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=inp) for inp in inputs]
    sequences = self.pipeline(
        prompts,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_return_sequences=num_ret_seq,
        eos_token_id=self.tokenizer.eos_token_id,
        max_length=max_length,
    )

    # The pipeline yields a list of candidate dicts per prompt when
    # num_return_sequences > 1; trim the prompt echo from each candidate.
    results = [
        [Predictor.trim_prompt(seq["generated_text"]) for seq in sequence]
        for sequence in sequences
    ]
    logger.info(f"Generated {len(results)} results")
    return results

Expand Down

0 comments on commit aa1b494

Please sign in to comment.