Skip to content

Commit

Permalink
Refactoring was too aggressive; send the model change out
Browse files Browse the repository at this point in the history
  • Loading branch information
maparent committed Oct 7, 2023
1 parent 18e911b commit 698b584
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions easycompletion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def validate_functions(response, functions, function_call, debug=DEBUG):
def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=EASYCOMPLETION_API_KEY, debug=DEBUG):
# Validate the API key
if not api_key.strip():
return {"error": "Invalid OpenAI API key"}
return model, {"error": "Invalid OpenAI API key"}

openai.api_key = api_key

Expand All @@ -175,7 +175,7 @@ def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=
# If text is too long even for long text model, return None
if total_tokens > (16384 - chunk_length):
print("Error: Message too long")
return {
return model, {
"text": None,
"usage": None,
"finish_reason": None,
Expand All @@ -189,6 +189,7 @@ def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=
else:
log(f"Prompt ({total_tokens} tokens):\n{str(prompt)}", type="prompt", log=debug)

return model, None

def do_chat_completion(
messages, model=TEXT_MODEL, temperature=0.8, functions=None, function_call=None, model_failure_retries=5, debug=DEBUG):
Expand Down Expand Up @@ -254,7 +255,7 @@ def chat_completion(

# Use the default model if no model is specified
model = model or TEXT_MODEL
error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
if error:
return error

Expand Down Expand Up @@ -306,7 +307,7 @@ async def chat_completion_async(

# Use the default model if no model is specified
model = model or TEXT_MODEL
error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
if error:
return error

Expand Down Expand Up @@ -358,7 +359,7 @@ def text_completion(

# Use the default model if no model is specified
model = model or TEXT_MODEL
error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
if error:
return error

Expand Down Expand Up @@ -411,7 +412,7 @@ async def text_completion_async(

# Use the default model if no model is specified
model = model or TEXT_MODEL
error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
if error:
return error

Expand Down Expand Up @@ -533,7 +534,7 @@ def function_completion(
"error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
}

error = sanity_check(dict(
model, error = sanity_check(dict(
text=text, functions=functions, messages=messages, system_message=system_message
), model=model, chunk_length=chunk_length, api_key=api_key)
if error:
Expand Down Expand Up @@ -702,7 +703,7 @@ async def function_completion_async(
"error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
}

error = sanity_check(dict(
model, error = sanity_check(dict(
text=text, functions=functions, messages=messages, system_message=system_message
), model=model, chunk_length=chunk_length, api_key=api_key)
if error:
Expand Down

0 comments on commit 698b584

Please sign in to comment.