diff --git a/test_all.py b/test_all.py index 2193cfd26564..bf47a15fa9d5 100644 --- a/test_all.py +++ b/test_all.py @@ -62,7 +62,8 @@ def test_return_correct_commit_hash(self): @patch('truffleHog.truffleHog.clone_git_repo') @patch('truffleHog.truffleHog.Repo') - def test_branch(self, repo_const_mock, clone_git_repo): + @patch('shutil.rmtree') + def test_branch(self, rmtree_mock, repo_const_mock, clone_git_repo): repo = MagicMock() repo_const_mock.return_value = repo truffleHog.find_strings("test_repo", branch="testbranch") @@ -130,5 +131,13 @@ def test_path_included(self): + @patch('truffleHog.truffleHog.clone_git_repo') + @patch('truffleHog.truffleHog.Repo') + @patch('shutil.rmtree') + def test_repo_path(self, rmtree_mock, repo_const_mock, clone_git_repo): + truffleHog.find_strings("test_repo", repo_path="test/path/") + rmtree_mock.assert_not_called() + clone_git_repo.assert_not_called() + if __name__ == '__main__': unittest.main() diff --git a/truffleHog/truffleHog.py b/truffleHog/truffleHog.py index c1921b78a28d..a0a6060032f4 100644 --- a/truffleHog/truffleHog.py +++ b/truffleHog/truffleHog.py @@ -39,6 +39,8 @@ def main(): 'in order for it to be scanned; lines starting with "#" are treated as comments and are ' 'ignored. If empty or not provided (default), no Git object paths are excluded unless ' 'effectively excluded via the --include_paths option.') + parser.add_argument("--repo_path", type=str, dest="repo_path", help="Path to the cloned repo. If provided, git_url will not be used") + parser.add_argument("--cleanup", dest="cleanup", action="store_true", help="Clean up all temporary result files") parser.add_argument('git_url', type=str, help='URL for secret searching') parser.set_defaults(regex=False) parser.set_defaults(rules={}) @@ -46,6 +48,8 @@ def main(): parser.set_defaults(since_commit=None) parser.set_defaults(entropy=True) parser.set_defaults(branch=None) + parser.set_defaults(repo_path=None) + parser.set_defaults(cleanup=False) args = parser.parse_args() rules = {} if args.rules: @@ -75,9 +79,10 @@ def main(): path_exclusions.append(re.compile(pattern)) output = find_strings(args.git_url, args.since_commit, args.max_depth, args.output_json, args.do_regex, do_entropy, - surpress_output=False, branch=args.branch, path_inclusions=path_inclusions, path_exclusions=path_exclusions) + surpress_output=False, branch=args.branch, repo_path=args.repo_path, path_inclusions=path_inclusions, path_exclusions=path_exclusions) project_path = output["project_path"] - shutil.rmtree(project_path, onerror=del_rw) + if args.cleanup: + clean_up(output) if output["foundIssues"]: sys.exit(1) else: @@ -296,9 +301,12 @@ def path_included(blob, include_patterns=None, exclude_patterns=None): def find_strings(git_url, since_commit=None, max_depth=1000000, printJson=False, do_regex=False, do_entropy=True, surpress_output=True, - custom_regexes={}, branch=None, path_inclusions=None, path_exclusions=None): + custom_regexes={}, branch=None, repo_path=None, path_inclusions=None, path_exclusions=None): output = {"foundIssues": []} - project_path = clone_git_repo(git_url) + if repo_path: + project_path = repo_path + else: + project_path = clone_git_repo(git_url) repo = Repo(project_path) already_searched = set() output_dir = tempfile.mkdtemp() @@ -343,12 +351,12 @@ def find_strings(git_url, since_commit=None, max_depth=1000000, printJson=False, output["project_path"] = project_path output["clone_uri"] = git_url output["issues_path"] = output_dir + if not repo_path: + shutil.rmtree(project_path, onerror=del_rw) return output def clean_up(output): - project_path = output.get("project_path", None) - if project_path and os.path.isdir(project_path): - shutil.rmtree(output["project_path"]) + print("Whhaat") issues_path = output.get("issues_path", None) if issues_path and os.path.isdir(issues_path): shutil.rmtree(output["issues_path"])