From 698b5846f94cffdc8460c775988f51db947b7063 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Parent Date: Sat, 7 Oct 2023 17:14:55 -0400 Subject: [PATCH] Refactoring was too aggressive; send the model change out --- easycompletion/model.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/easycompletion/model.py b/easycompletion/model.py index 59d1056..41be67d 100644 --- a/easycompletion/model.py +++ b/easycompletion/model.py @@ -157,7 +157,7 @@ def validate_functions(response, functions, function_call, debug=DEBUG): def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=EASYCOMPLETION_API_KEY, debug=DEBUG): # Validate the API key if not api_key.strip(): - return {"error": "Invalid OpenAI API key"} + return model, {"error": "Invalid OpenAI API key"} openai.api_key = api_key @@ -175,7 +175,7 @@ def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key= # If text is too long even for long text model, return None if total_tokens > (16384 - chunk_length): print("Error: Message too long") - return { + return model, { "text": None, "usage": None, "finish_reason": None, @@ -189,6 +189,7 @@ def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key= else: log(f"Prompt ({total_tokens} tokens):\n{str(prompt)}", type="prompt", log=debug) + return model, None def do_chat_completion( messages, model=TEXT_MODEL, temperature=0.8, functions=None, function_call=None, model_failure_retries=5, debug=DEBUG): @@ -254,7 +255,7 @@ def chat_completion( # Use the default model if no model is specified model = model or TEXT_MODEL - error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) + model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) if error: return error @@ -306,7 +307,7 @@ async def chat_completion_async( # Use the default model if no model is specified model = model or TEXT_MODEL - error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) + model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) if error: return error @@ -358,7 +359,7 @@ def text_completion( # Use the default model if no model is specified model = model or TEXT_MODEL - error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) + model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) if error: return error @@ -411,7 +412,7 @@ async def text_completion_async( # Use the default model if no model is specified model = model or TEXT_MODEL - error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) + model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug) if error: return error @@ -533,7 +534,7 @@ def function_completion( "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property" } - error = sanity_check(dict( + model, error = sanity_check(dict( text=text, functions=functions, messages=messages, system_message=system_message ), model=model, chunk_length=chunk_length, api_key=api_key) if error: @@ -702,7 +703,7 @@ async def function_completion_async( "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property" } - error = sanity_check(dict( + model, error = sanity_check(dict( text=text, functions=functions, messages=messages, system_message=system_message ), model=model, chunk_length=chunk_length, api_key=api_key) if error: