diff --git a/slither/__main__.py b/slither/__main__.py index 8e32a1c433..9426a6de9d 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -305,7 +305,35 @@ def parse_args( "--codex", help="Enable codex (require an OpenAI API Key)", action="store_true", - default=False, + default=defaults_flag_in_config["codex"], + ) + + parser.add_argument( + "--codex-contracts", + help="Comma separated list of contracts to submit to OpenAI Codex", + action="store", + default=defaults_flag_in_config["codex_contracts"], + ) + + parser.add_argument( + "--codex-model", + help="Name of the Codex model to use (affects pricing). Defaults to 'text-davinci-003'", + action="store", + default=defaults_flag_in_config["codex_model"], + ) + + parser.add_argument( + "--codex-temperature", + help="Temperature to use with Codex. Lower number indicates a more precise answer while higher numbers return more creative answers. Defaults to 0", + action="store", + default=defaults_flag_in_config["codex_temperature"], + ) + + parser.add_argument( + "--codex-max-tokens", + help="Maximum amount of tokens to use on the response. This number plus the size of the prompt can be no larger than the limit (4097 for text-davinci-003)", + action="store", + default=defaults_flag_in_config["codex_max_tokens"], ) cryticparser.init(parser) diff --git a/slither/detectors/functions/codex.py b/slither/detectors/functions/codex.py index daed064252..4dca447756 100644 --- a/slither/detectors/functions/codex.py +++ b/slither/detectors/functions/codex.py @@ -7,6 +7,7 @@ logger = logging.getLogger("Slither") +VULN_FOUND = "VULN_FOUND" class Codex(AbstractDetector): """ @@ -52,29 +53,63 @@ def _detect(self) -> List[Output]: openai.api_key = api_key for contract in self.compilation_unit.contracts: - prompt = "Is there a vulnerability in this solidity contracts?\n" + if self.slither.codex_contracts != "all" and contract.name not in self.slither.codex_contracts.split(","): + continue + prompt = "Analyze this Solidity contract and find the vulnerabilities. If you find any vulnerabilities, begin the response with {}".format(VULN_FOUND) src_mapping = contract.source_mapping content = contract.compilation_unit.core.source_code[src_mapping.filename.absolute] start = src_mapping.start end = src_mapping.start + src_mapping.length prompt += content[start:end] - answer = openai.Completion.create( # type: ignore - model="text-davinci-003", prompt=prompt, temperature=0, max_tokens=200 - ) - - if "choices" in answer: - if answer["choices"]: - if "text" in answer["choices"][0]: - if "Yes," in answer["choices"][0]["text"]: - info = [ - "Codex detected a potential bug in ", - contract, - "\n", - answer["choices"][0]["text"], - "\n", - ] - - res = self.generate_result(info) - results.append(res) - + logging.info("Querying OpenAI") + print("Querying OpenAI") + answer = "" + res = {} + try: + res = openai.Completion.create( # type: ignore + prompt=prompt, + model=self.slither.codex_model, + temperature=self.slither.codex_temperature, + max_tokens=self.slither.codex_max_tokens, + ) + except Exception as e: + print("OpenAI request failed: " + str(e)) + logging.info("OpenAI request failed: " + str(e)) + + """ OpenAI completion response shape example: + { + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "text": "VULNERABILITIES:. The withdraw() function does not check..." + } + ], + "created": 1670357537, + "id": "cmpl-6KYaXdA6QIisHlTMM7RCJ1nR5wTKx", + "model": "text-davinci-003", + "object": "text_completion", + "usage": { + "completion_tokens": 80, + "prompt_tokens": 249, + "total_tokens": 329 + } + } """ + + if len(res.get("choices", [])) and VULN_FOUND in res["choices"][0].get("text", ""): + # remove VULN_FOUND keyword and cleanup + answer = res["choices"][0]["text"].replace(VULN_FOUND, "").replace("\n", "").replace(": ", "") + + if len(answer): + info = [ + "Codex detected a potential bug in ", + contract, + "\n", + answer, + "\n", + ] + + res = self.generate_result(info) + results.append(res) return results diff --git a/slither/slither.py b/slither/slither.py index a61e8255ff..0b1f57a37e 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -83,8 +83,12 @@ def __init__(self, target: Union[str, CryticCompile], **kwargs): self.line_prefix = kwargs.get("change_line_prefix", "#") - # Indicate if codex-related features should be used + # Indicate if Codex related features should be used self.codex_enabled = kwargs.get("codex", False) + self.codex_contracts = kwargs.get("codex_contracts") + self.codex_model = kwargs.get("codex_model") + self.codex_temperature = kwargs.get("codex_temperature") + self.codex_max_tokens = kwargs.get("codex_max_tokens") self._parsers: List[SlitherCompilationUnitSolc] = [] try: diff --git a/slither/utils/command_line.py b/slither/utils/command_line.py index c2fef5eca0..f774437a1a 100644 --- a/slither/utils/command_line.py +++ b/slither/utils/command_line.py @@ -29,6 +29,11 @@ # Those are the flags shared by the command line and the config file defaults_flag_in_config = { + "codex": False, + "codex_contracts": "all", + "codex_model": "text-davinci-003", + "codex_temperature": 0, + "codex_max_tokens": 300, "detectors_to_run": "all", "printers_to_run": None, "detectors_to_exclude": None,