-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
275 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ requires = ["flit"] | |
build-backend = "flit.buildapi" | ||
|
||
[project] | ||
name = "ask-gpt" | ||
name = "gpt-review" | ||
authors = [ | ||
{name = "Daniel Ciborowski", email = "[email protected]"}, | ||
] | ||
|
@@ -55,15 +55,15 @@ test = [ | |
] | ||
|
||
[project.scripts] | ||
gpt = "ask_gpt.main:__main__" | ||
gpt = "gpt_review.main:__main__" | ||
|
||
[project.urls] | ||
Documentation = "https://github.com/dciborow/action-gpt/tree/main#readme" | ||
Source = "https://github.com/dciborow/action-gpt" | ||
Tracker = "https://github.com/dciborow/action-gpt/issues" | ||
|
||
[tool.flit.module] | ||
name = "ask_gpt" | ||
name = "gpt_review" | ||
|
||
[tool.bandit] | ||
exclude_dirs = ["build","dist","tests","scripts"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# ------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE in project root for information. | ||
# ------------------------------------------------------------- | ||
"""Easy GPT CLI""" | ||
from __future__ import annotations | ||
|
||
__version__ = "0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
"""Ask GPT a question.""" | ||
import logging | ||
import os | ||
import time | ||
from knack import CLICommandsLoader | ||
from knack.arguments import ArgumentsContext | ||
from knack.commands import CommandGroup | ||
|
||
import openai | ||
|
||
from azure.identity import DefaultAzureCredential | ||
from azure.keyvault.secrets import SecretClient | ||
|
||
|
||
from openai.error import RateLimitError | ||
|
||
|
||
from gpt_review._command import GPTCommandGroup | ||
|
||
DEFAULT_KEY_VAULT = "https://dciborow-openai.vault.azure.net/" | ||
|
||
|
||
def _ask(question, max_tokens=100): | ||
"""Ask GPT a question.""" | ||
response = _call_gpt(prompt=question[0], max_tokens=max_tokens) | ||
return {"response": response} | ||
|
||
|
||
def _load_azure_openai_context(): | ||
""" | ||
Load the Azure OpenAI context. | ||
If the environment variables are not set, retrieve the values from Azure Key Vault. | ||
""" | ||
openai.api_type = "azure" | ||
openai.api_version = "2023-03-15-preview" | ||
if os.getenv("AZURE_OPENAI_API"): | ||
openai.api_base = os.getenv("AZURE_OPENAI_API") | ||
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY") | ||
else: | ||
secret_client = SecretClient( | ||
vault_url=os.getenv("AZURE_KEY_VAULT_URL", DEFAULT_KEY_VAULT), credential=DefaultAzureCredential() | ||
) | ||
|
||
openai.api_base = secret_client.get_secret("azure-open-ai").value | ||
openai.api_key = secret_client.get_secret("azure-openai-key").value | ||
|
||
|
||
def _call_gpt( | ||
prompt: str, | ||
temperature=0.10, | ||
max_tokens=500, | ||
top_p=1, | ||
frequency_penalty=0.5, | ||
presence_penalty=0.0, | ||
retry=0, | ||
messages=None, | ||
) -> str: | ||
""" | ||
Call GPT-4 with the given prompt. | ||
Args: | ||
prompt (str): The prompt to send to GPT-4. | ||
temperature (float, optional): The temperature to use. Defaults to 0.10. | ||
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 500. | ||
top_p (float, optional): The top_p to use. Defaults to 1. | ||
frequency_penalty (float, optional): The frequency penalty to use. Defaults to 0.5. | ||
presence_penalty (float, optional): The presence penalty to use. Defaults to 0.0. | ||
retry (int, optional): The number of times to retry the request. Defaults to 0. | ||
Returns: | ||
str: The response from GPT-4. | ||
""" | ||
_load_azure_openai_context() | ||
|
||
if len(prompt) > 32767: | ||
return _batch_large_changes( | ||
prompt, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, retry, messages | ||
) | ||
|
||
messages = messages or [{"role": "user", "content": prompt}] | ||
try: | ||
engine = _get_engine(prompt) | ||
logging.info("Model Selected based on prompt size: %s", engine) | ||
|
||
logging.info("Prompt sent to GPT: %s\n", prompt) | ||
completion = openai.ChatCompletion.create( | ||
engine=engine, | ||
messages=messages, | ||
max_tokens=max_tokens, | ||
temperature=temperature, | ||
top_p=top_p, | ||
frequency_penalty=frequency_penalty, | ||
presence_penalty=presence_penalty, | ||
) | ||
return completion.choices[0].message.content # type: ignore | ||
except RateLimitError as error: | ||
if retry < 5: | ||
logging.warning("Call to GPT failed due to rate limit, retry attempt: %s", retry) | ||
time.sleep(retry * 5) | ||
return _call_gpt(prompt, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, retry + 1) | ||
raise RateLimitError("Retry limit exceeded") from error | ||
|
||
|
||
def _batch_large_changes( | ||
prompt: str, | ||
temperature=0.10, | ||
max_tokens=500, | ||
top_p=1, | ||
frequency_penalty=0.5, | ||
presence_penalty=0.0, | ||
retry=0, | ||
messages=None, | ||
) -> str: | ||
"""Placeholder for batching large changes to GPT-4.""" | ||
try: | ||
logging.warning("Prompt too long, batching") | ||
output = "" | ||
for i in range(0, len(prompt), 32767): | ||
logging.debug("Batching %s to %s", i, i + 32767) | ||
batch = prompt[i : i + 32767] | ||
output += _call_gpt( | ||
batch, | ||
temperature=temperature, | ||
max_tokens=max_tokens, | ||
top_p=top_p, | ||
frequency_penalty=frequency_penalty, | ||
presence_penalty=presence_penalty, | ||
retry=retry, | ||
messages=messages, | ||
) | ||
prompt = f""" | ||
"Summarize the large file batches" | ||
{output} | ||
""" | ||
return _call_gpt(prompt, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, retry, messages) | ||
except RateLimitError: | ||
logging.warning("Prompt too long, truncating") | ||
prompt = prompt[:32767] | ||
return _call_gpt( | ||
prompt, | ||
temperature=temperature, | ||
max_tokens=max_tokens, | ||
top_p=top_p, | ||
frequency_penalty=frequency_penalty, | ||
presence_penalty=presence_penalty, | ||
retry=retry, | ||
messages=messages, | ||
) | ||
|
||
|
||
def _get_engine(prompt: str) -> str: | ||
""" | ||
Get the Engine based on the prompt length. | ||
- when greater then 8k use gpt-4-32k | ||
- when greater then 4k use gpt-4 | ||
- use gpt-35-turbo for all small prompts | ||
""" | ||
if len(prompt) > 8000: | ||
return "gpt-4-32k" | ||
return "gpt-4" if len(prompt) > 4000 else "gpt-35-turbo" | ||
|
||
|
||
class AskCommandGroup(GPTCommandGroup): | ||
"""Ask Command Group.""" | ||
|
||
@staticmethod | ||
def load_command_table(loader: CLICommandsLoader): | ||
with CommandGroup(loader, "", "gpt_review._ask#{}") as group: | ||
group.command("ask", "_ask", is_preview=True) | ||
|
||
@staticmethod | ||
def load_arguments(loader: CLICommandsLoader): | ||
with ArgumentsContext(loader, "ask") as args: | ||
args.positional("question", type=str, nargs="+", help="Provide a question to ask GPT.") | ||
args.argument("max_tokens", type=int, help="The maximum number of tokens to generate.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Interface for GPT CLI command groups.""" | ||
from knack import CLICommandsLoader | ||
|
||
|
||
class GPTCommandGroup: | ||
"""Command Group Interface.""" | ||
|
||
@staticmethod | ||
def load_command_table(loader: CLICommandsLoader) -> None: | ||
"""Load the command table.""" | ||
|
||
@staticmethod | ||
def load_arguments(loader: CLICommandsLoader) -> None: | ||
"""Load the arguments for the command group.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""The GPT CLI configuration and utilities.""" | ||
from collections import OrderedDict | ||
import os | ||
import sys | ||
|
||
from knack import CLI, CLICommandsLoader | ||
|
||
from gpt_review import __version__ | ||
from gpt_review._ask import AskCommandGroup | ||
|
||
CLI_NAME = "gpt" | ||
|
||
|
||
class GPTCLI(CLI): | ||
"""Custom CLI implemntation to set version for the GPT CLI.""" | ||
|
||
def get_cli_version(self) -> str: | ||
return __version__ | ||
|
||
|
||
class GPTCommandsLoader(CLICommandsLoader): | ||
"""The GPT CLI Commands Loader.""" | ||
|
||
_CommandGroups = [ | ||
AskCommandGroup, | ||
] | ||
|
||
_ArgumentGroups = [ | ||
AskCommandGroup, | ||
] | ||
|
||
def load_command_table(self, args) -> OrderedDict: | ||
for command_group in self._CommandGroups: | ||
command_group.load_command_table(self) | ||
return OrderedDict(self.command_table) | ||
|
||
def load_arguments(self, command) -> None: | ||
for argument_group in self._ArgumentGroups: | ||
argument_group.load_arguments(self) | ||
super(GPTCommandsLoader, self).load_arguments(command) | ||
|
||
|
||
def cli() -> int: | ||
"""The GPT CLI entry point.""" | ||
gpt = GPTCLI( | ||
cli_name=CLI_NAME, | ||
config_dir=os.path.expanduser(os.path.join("~", f".{CLI_NAME}")), | ||
config_env_var_prefix=CLI_NAME, | ||
commands_loader_cls=GPTCommandsLoader, | ||
) | ||
return gpt.invoke(sys.argv[1:]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""The GPT CLI entry point.""" | ||
import sys | ||
|
||
from knack.help_files import helps | ||
|
||
from gpt_review._gpt_cli import cli | ||
|
||
|
||
def _help_text(help_type, short_summary) -> str: | ||
return f""" | ||
type: {help_type} | ||
short-summary: {short_summary} | ||
""" | ||
|
||
|
||
helps[""] = _help_text("group", "Easily interact with GPT APIs.") | ||
helps["git"] = _help_text("group", "Use GPT enchanced git commands.") | ||
|
||
|
||
exit_code = cli() | ||
sys.exit(exit_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters