Update FT'ing for function calling notebook to match new python SDK #866

Merged 1 commit on Nov 27, 2023
90 changes: 68 additions & 22 deletions examples/Fine_tuning_for_function_calling.ipynb
@@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -82,20 +82,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import numpy as np\n",
"import json\n",
"import os\n",
"import openai\n",
"from openai import OpenAI\n",
"import itertools\n",
"from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
"from typing import Any, Dict, List, Generator\n",
"import ast\n",
"openai.api_key = os.getenv('OPENAI_API_KEY')\n"
"client = OpenAI()\n"
]
},
{
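For reference, a minimal sketch of the client setup this hunk moves to, assuming openai>=1.0; passing the key explicitly is optional, since OpenAI() reads OPENAI_API_KEY from the environment on its own:

import os
from openai import OpenAI

# Equivalent to the old `openai.api_key = os.getenv('OPENAI_API_KEY')` line;
# OpenAI() with no arguments picks up the same environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))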
@@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -137,7 +137,7 @@
" if functions:\n",
" params['functions'] = functions\n",
"\n",
" completion = openai.ChatCompletion.create(**params)\n",
" completion = client.chat.completions.create(**params)\n",
" return completion.choices[0].message\n"
]
},
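This hunk swaps the deprecated `openai.ChatCompletion.create` for the v1 client method inside the notebook's wrapper. A minimal sketch of the resulting helper, with illustrative defaults for the parameters the hunk does not show:

def get_chat_completion(messages, model="gpt-3.5-turbo", temperature=0.0, functions=None):
    # Assemble the request, attaching the function schemas only when provided.
    params = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if functions:
        params['functions'] = functions
    # openai>=1.0: client.chat.completions.create replaces openai.ChatCompletion.create
    completion = client.chat.completions.create(**params)
    return completion.choices[0].message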
@@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -177,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -430,7 +430,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -442,9 +442,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Land the drone at the home base\n",
"FunctionCall(arguments='{\\n \"location\": \"home_base\"\\n}', name='land_drone') \n",
"\n",
"Take off the drone to 50 meters\n",
"FunctionCall(arguments='{\\n \"altitude\": 50\\n}', name='takeoff_drone') \n",
"\n",
"change speed to 15 kilometers per hour\n",
"FunctionCall(arguments='{\\n \"speed\": 15\\n}', name='set_drone_speed') \n",
"\n",
"turn into an elephant!\n",
"FunctionCall(arguments='{}', name='reject_request') \n",
"\n"
]
}
],
"source": [
"for prompt in straightforward_prompts:\n",
" messages = []\n",
@@ -464,7 +483,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -477,9 +496,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Play pre-recorded audio message\n",
"FunctionCall(arguments='{}', name='reject_request')\n",
"\n",
"\n",
"Initiate live-streaming on social media\n",
"FunctionCall(arguments='{\\n\"mode\": \"video\",\\n\"duration\": 0\\n}', name='control_camera')\n",
"\n",
"\n",
"Scan environment for heat signatures\n",
"None\n",
"\n",
"\n",
"Enable stealth mode\n",
"FunctionCall(arguments='{\\n \"mode\": \"off\"\\n}', name='set_drone_lighting')\n",
"\n",
"\n",
"Change drone's paint job color\n",
"FunctionCall(arguments='{\\n \"pattern\": \"solid\",\\n \"color\": \"blue\"\\n}', name='configure_led_display')\n",
"\n",
"\n"
]
}
],
"source": [
"for prompt in challenging_prompts:\n",
" messages = []\n",
@@ -537,7 +583,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -557,7 +603,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -664,7 +710,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -734,7 +780,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -784,7 +830,7 @@
" request_prompt = COMMAND_GENERATION_PROMPT.format(invocation=invocation)\n",
"\n",
" messages = [{\"role\": \"user\", \"content\": f\"{request_prompt}\"}]\n",
" completion = get_chat_completion(messages,temperature=0.8)\n",
" completion = get_chat_completion(messages,temperature=0.8).content\n",
" command_dict = {\n",
" \"Input\": invocation,\n",
" \"Prompt\": completion\n",
@@ -926,13 +972,13 @@
"outputs": [],
"source": [
"if __name__ == \"__main__\":\n",
" file = openai.File.create(\n",
" file = client.files.create(\n",
" file=open(training_file, \"rb\"),\n",
" purpose=\"fine-tune\",\n",
" )\n",
" file_id = file.id\n",
" print(file_id)\n",
" ft = openai.FineTuningJob.create(\n",
" ft = client.fine_tuning.jobs.create(\n",
" # model=\"gpt-4-0613\",\n",
" model=\"gpt-3.5-turbo\",\n",
" training_file=file_id,\n",
@@ -980,7 +1026,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Conclustion"
"### Conclusion"
]
},
{