From 16e5a1c2f7db2635f11f5e2ada3b551061bc9bb0 Mon Sep 17 00:00:00 2001 From: jhills20 <70035505+jhills20@users.noreply.github.com> Date: Mon, 27 Nov 2023 16:46:21 -0500 Subject: [PATCH] Update FT'ing for function calling notebook to match new python SDK (#866) --- .../Fine_tuning_for_function_calling.ipynb | 90 ++++++++++++++----- 1 file changed, 68 insertions(+), 22 deletions(-) diff --git a/examples/Fine_tuning_for_function_calling.ipynb b/examples/Fine_tuning_for_function_calling.ipynb index 85a7d32a5e..8a9236aec0 100644 --- a/examples/Fine_tuning_for_function_calling.ipynb +++ b/examples/Fine_tuning_for_function_calling.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -90,12 +90,12 @@ "import numpy as np\n", "import json\n", "import os\n", - "import openai\n", + "from openai import OpenAI\n", "import itertools\n", "from tenacity import retry, wait_random_exponential, stop_after_attempt\n", "from typing import Any, Dict, List, Generator\n", "import ast\n", - "openai.api_key = os.getenv('OPENAI_API_KEY')\n" + "client = OpenAI()\n" ] }, { @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -137,7 +137,7 @@ " if functions:\n", " params['functions'] = functions\n", "\n", - " completion = openai.ChatCompletion.create(**params)\n", + " completion = client.chat.completions.create(**params)\n", " return completion.choices[0].message\n" ] }, @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -430,7 +430,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -442,9 +442,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Land the drone at the home base\n", + "FunctionCall(arguments='{\\n \"location\": \"home_base\"\\n}', name='land_drone') \n", + "\n", + "Take off the drone to 50 meters\n", + "FunctionCall(arguments='{\\n \"altitude\": 50\\n}', name='takeoff_drone') \n", + "\n", + "change speed to 15 kilometers per hour\n", + "FunctionCall(arguments='{\\n \"speed\": 15\\n}', name='set_drone_speed') \n", + "\n", + "turn into an elephant!\n", + "FunctionCall(arguments='{}', name='reject_request') \n", + "\n" + ] + } + ], "source": [ "for prompt in straightforward_prompts:\n", " messages = []\n", @@ -464,7 +483,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -477,9 +496,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Play pre-recorded audio message\n", + "FunctionCall(arguments='{}', name='reject_request')\n", + "\n", + "\n", + "Initiate live-streaming on social media\n", + "FunctionCall(arguments='{\\n\"mode\": \"video\",\\n\"duration\": 0\\n}', name='control_camera')\n", + "\n", + "\n", + "Scan environment for heat signatures\n", + "None\n", + "\n", + "\n", + "Enable stealth mode\n", + "FunctionCall(arguments='{\\n \"mode\": \"off\"\\n}', name='set_drone_lighting')\n", + "\n", + "\n", + "Change drone's paint job color\n", + "FunctionCall(arguments='{\\n \"pattern\": \"solid\",\\n \"color\": \"blue\"\\n}', name='configure_led_display')\n", + "\n", + "\n" + ] + } + ], "source": [ "for prompt in challenging_prompts:\n", " messages = []\n", @@ -537,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -557,7 +603,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -664,7 +710,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -734,7 +780,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -784,7 +830,7 @@ " request_prompt = COMMAND_GENERATION_PROMPT.format(invocation=invocation)\n", "\n", " messages = [{\"role\": \"user\", \"content\": f\"{request_prompt}\"}]\n", - " completion = get_chat_completion(messages,temperature=0.8)\n", + " completion = get_chat_completion(messages,temperature=0.8).content\n", " command_dict = {\n", " \"Input\": invocation,\n", " \"Prompt\": completion\n", @@ -926,13 +972,13 @@ "outputs": [], "source": [ "if __name__ == \"__main__\":\n", - " file = openai.File.create(\n", + " file = client.files.create(\n", " file=open(training_file, \"rb\"),\n", " purpose=\"fine-tune\",\n", " )\n", " file_id = file.id\n", " print(file_id)\n", - " ft = openai.FineTuningJob.create(\n", + " ft = client.fine_tuning.jobs.create(\n", " # model=\"gpt-4-0613\",\n", " model=\"gpt-3.5-turbo\",\n", " training_file=file_id,\n", @@ -980,7 +1026,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Conclustion" + "### Conclusion" ] }, {