diff --git a/.vscode/settings.json b/.vscode/settings.json
index d5acfcab..2e41ac9d 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -25,5 +25,11 @@
     ],
     "isort.args": [
         "--settings-path=pyproject.toml"
-    ]
+    ],
+    "python.linting.banditEnabled": true,
+    "python.linting.banditArgs": [
+        "-c",
+        "pyproject.toml"
+    ],
+    "python.analysis.inlayHints.functionReturnTypes": true
 }
\ No newline at end of file
diff --git a/README.md b/README.md
index 8fb8d5c8..8d97af56 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,12 @@
 # gpt-review
+
+[![Actions Status](…)](…)
+[![Coverage Status](…)](…)
+[![License: MIT](…)](…)
+[![PyPI](…)](…)
+[![Downloads](…)](…)
+[![Code style: black](…)](…)
+
 
 A Python-based CLI and GitHub Action that uses OpenAI or Azure OpenAI models to review the contents of pull requests.
 
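The `_ask.py` diff below threads a new `files` option through the `ask` command: when files are given, the question is answered against an index built from those files instead of being sent straight to the chat endpoint. As orientation before the diff, here is a minimal sketch (not part of the patch) of the two call paths through the `_ask` helper this PR adds; note that the plain path sends only `question[0]`, exactly as the code below does:

```python
# Minimal sketch of the two paths through _ask, as wired up in this PR.
from gpt_review._ask import _ask

# Default path: no files, so the prompt goes straight to the chat completion.
chat = _ask(question=["how are you?"])
print(chat["response"])

# New path: the listed files are indexed, and the space-joined question
# is run against that document index instead.
doc = _ask(question=["what does this code do?"], files=["review.py"])
print(doc["response"])
```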
diff --git a/src/gpt_review/_ask.py b/src/gpt_review/_ask.py
index e9fa9d13..5c3e2ab8 100644
--- a/src/gpt_review/_ask.py
+++ b/src/gpt_review/_ask.py
@@ -2,7 +2,7 @@
 import logging
 import os
 import time
-from typing import Dict
+from typing import Dict, List
 from typing_extensions import override
 from knack import CLICommandsLoader
 from knack.arguments import ArgumentsContext
@@ -10,12 +10,19 @@
 from knack.util import CLIError
 
 import openai
-
 from azure.identity import DefaultAzureCredential
 from azure.keyvault.secrets import SecretClient
-
-
 from openai.error import RateLimitError
+from langchain.llms import AzureOpenAI
+from langchain.embeddings import OpenAIEmbeddings
+from llama_index import (
+    GPTSimpleVectorIndex,
+    LangchainEmbedding,
+    ServiceContext,
+    LLMPredictor,
+    SimpleDirectoryReader,
+)
+from llama_index.indices.base import BaseGPTIndex
 
 from gpt_review._command import GPTCommandGroup
 import gpt_review.constants as C
@@ -24,6 +31,89 @@
 DEFAULT_KEY_VAULT = "https://dciborow-openai.vault.azure.net/"
 
 
+def _ask_doc(question: List[str], files: List[str]) -> str:
+    """
+    Ask GPT a question about a set of local files.
+
+    Args:
+        question (List[str]): The question to ask.
+        files (List[str]): The files to search.
+
+    Returns:
+        str: The response.
+    """
+    documents = SimpleDirectoryReader(input_files=files).load_data()
+    index = _document_indexer(documents)
+
+    return index.query(" ".join(question)).response  # type: ignore
+
+
+def _document_indexer(documents) -> BaseGPTIndex:
+    """
+    Create a document indexer.
+    Uses Azure OpenAI when AZURE_OPENAI_API_KEY is set.
+
+    Args:
+        documents (List[Document]): The documents to index.
+
+    Returns:
+        GPTSimpleVectorIndex: The document indexer.
+    """
+    service_context = None
+    if os.getenv("AZURE_OPENAI_API_KEY"):
+        _load_azure_openai_context()
+
+        os.environ["OPENAI_API_KEY"] = openai.api_key  # type: ignore
+        llm = AzureGPT35Turbo(  # type: ignore
+            deployment_name="gpt-35-turbo",
+            model_kwargs={
+                "api_key": openai.api_key,
+                "api_base": openai.api_base,
+                "api_type": "azure",
+                "api_version": "2023-03-15-preview",
+            },
+            max_retries=10,
+        )
+        llm_predictor = LLMPredictor(llm=llm)
+
+        embedding_llm = LangchainEmbedding(
+            OpenAIEmbeddings(
+                document_model_name="text-embedding-ada-002",
+                query_model_name="text-embedding-ada-002",
+            ),  # type: ignore
+            embed_batch_size=1,
+        )
+
+        service_context = ServiceContext.from_defaults(
+            llm_predictor=llm_predictor,
+            embed_model=embedding_llm,
+        )
+    return GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
+
+
+class AzureGPT35Turbo(AzureOpenAI):
+    """Azure OpenAI Chat API."""
+
+    @property
+    @override
+    def _default_params(self):
+        """
+        Get the default parameters for calling the OpenAI API.
+        gpt-35-turbo does not support best_of, logprobs, or echo.
+        """
+        normal_params = {
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "n": self.n,
+            "request_timeout": self.request_timeout,
+            "logit_bias": self.logit_bias,
+        }
+        return {**normal_params, **self.model_kwargs}
+
+
 def validate_parameter_range(namespace) -> None:
     """Validate that max_tokens is in [1,4000], temperature and top_p are in [0,1], and frequency_penalty and presence_penalty are in [0,2]."""
     _range_validation(namespace.max_tokens, "max-tokens", C.MAX_TOKENS_MIN, C.MAX_TOKENS_MAX)
@@ -58,6 +148,7 @@ def _ask(
     top_p=C.TOP_P_DEFAULT,
     frequency_penalty=C.FREQUENCY_PENALTY_DEFAULT,
     presence_penalty=C.PRESENCE_PENALTY_DEFAULT,
+    files=None,
 ) -> Dict[str, str]:
     """Ask GPT a question.
 
@@ -68,18 +159,22 @@
         top_p (float): This value also determines the level of randomness.
         frequency_penalty (float): The chance of repeating a token based on its current frequency in the text.
         presence_penalty (float): The chance of repeating any token that has appeared in the text so far.
+        files (List[str]): The files to search.
 
     Yields:
         dict[str, str]: The response from GPT.
     """
-    response = _call_gpt(
-        prompt=question[0],
-        max_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        frequency_penalty=frequency_penalty,
-        presence_penalty=presence_penalty,
-    )
+    if files:
+        response = _ask_doc(question, files)
+    else:
+        response = _call_gpt(
+            prompt=question[0],
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+        )
 
     return {"response": response}
@@ -261,3 +356,10 @@
             help="Reduce the chance of repeating any token that has appeared in the text so far.",
             validator=validate_parameter_range,
         )
+        args.argument(
+            "files",
+            type=str,
+            help="Ask a question about a file. Can be used multiple times.",
+            default=None,
+            action="append",
+        )
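When `AZURE_OPENAI_API_KEY` is not set, `_document_indexer` leaves `service_context` as `None`, so `GPTSimpleVectorIndex` falls back to llama_index's stock OpenAI defaults. A self-contained sketch of that non-Azure path, using only the llama_index calls this diff itself imports (the file path and key below are placeholders):

```python
# Sketch of the non-Azure path through _ask_doc/_document_indexer:
# load the files, build a vector index with default settings, query it.
import os
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, not a real key

documents = SimpleDirectoryReader(input_files=["review.py"]).load_data()
index = GPTSimpleVectorIndex.from_documents(documents)  # service_context=None
answer = index.query("what programming language is this code written in?")
print(answer.response)
```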
+ """ + normal_params = { + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "n": self.n, + "request_timeout": self.request_timeout, + "logit_bias": self.logit_bias, + } + return {**normal_params, **self.model_kwargs} + + def validate_parameter_range(namespace) -> None: """Validate that max_tokens is in [1,4000], temperature and top_p are in [0,1], and frequency_penalty and presence_penalty are in [0,2]""" _range_validation(namespace.max_tokens, "max-tokens", C.MAX_TOKENS_MIN, C.MAX_TOKENS_MAX) @@ -58,6 +148,7 @@ def _ask( top_p=C.TOP_P_DEFAULT, frequency_penalty=C.FREQUENCY_PENALTY_DEFAULT, presence_penalty=C.PRESENCE_PENALTY_DEFAULT, + files=None, ) -> Dict[str, str]: """Ask GPT a question. @@ -68,18 +159,22 @@ def _ask( top_p (float): This value also determines the level or randomness. frequency_penalty (float): The chance of repeating a token based on current frequency in the text. presence_penalty (float): The chance of repeating any token that has appeared in the text so far. + files (List[str]): The files to search. Yields: dict[str, str]: The response from GPT. """ - response = _call_gpt( - prompt=question[0], - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - ) + if files: + response = _ask_doc(question, files) + else: + response = _call_gpt( + prompt=question[0], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + ) return {"response": response} @@ -261,3 +356,10 @@ def load_arguments(loader: CLICommandsLoader) -> None: help="Reduce the chance of repeating any token that has appeared in the text so far.", validator=validate_parameter_range, ) + args.argument( + "files", + type=str, + help="Ask question about a file. Can be used multiple times.", + default=None, + action="append", + ) diff --git a/tests/conftest.py b/tests/conftest.py index 6eee14de..77fc4d7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,15 @@ class MockResponse: def __init__(self) -> None: self.choices = [namedtuple("mockMessage", "message")(*[namedtuple("mockContent", "content")(*[["test"]])])] + class MockQueryResponse: + def __init__(self) -> None: + self.response = "test" + + class MockIndex: + def query(self, question: str) -> MockQueryResponse: + assert isinstance(question, str) + return MockQueryResponse() + def mock_create( engine, messages, @@ -27,4 +36,8 @@ def mock_create( ) -> MockResponse: return MockResponse() + def from_documents(documents, service_context=None) -> MockIndex: + return MockIndex() + monkeypatch.setattr("openai.ChatCompletion.create", mock_create) + monkeypatch.setattr("llama_index.GPTSimpleVectorIndex.from_documents", from_documents) diff --git a/tests/test_gpt_cli.py b/tests/test_gpt_cli.py index 771180d0..0d227f06 100644 --- a/tests/test_gpt_cli.py +++ b/tests/test_gpt_cli.py @@ -101,6 +101,7 @@ class CLICase: CLICase( f"""ask how are you --max-tokens {C.MAX_TOKENS_DEFAULT} --top-p {C.TOP_P_DEFAULT} --frequency-penalty {C.FREQUENCY_PENALTY_DEFAULT} --presence-penalty {C.FREQUENCY_PENALTY_MAX}""" ), + CLICase("ask --files review.py --files review.py what programming language is this code written in?"), ] ARGS = ROOT_COMMANDS + ASK_COMMANDS