diff --git a/tests/python/unittest/test_ci.py b/tests/python/unittest/test_ci.py index 1e6192349d5c..7552c9cb6af3 100644 --- a/tests/python/unittest/test_ci.py +++ b/tests/python/unittest/test_ci.py @@ -15,14 +15,10 @@ # specific language governing permissions and limitations # under the License. -import os import pathlib -import pytest -import shutil import subprocess -import tempfile -import json import sys +import tempfile import pytest @@ -30,7 +26,7 @@ def test_skip_ci(): - skip_ci_script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci.sh" + skip_ci_script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci.py" class TempGit: def __init__(self, cwd): @@ -47,14 +43,17 @@ def test(commands, should_skip, why): # Jenkins git is too old and doesn't have 'git init --initial-branch' git.run("init") git.run("checkout", "-b", "main") + git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") git.run("config", "user.name", "ci") git.run("config", "user.email", "email@example.com") git.run("commit", "--allow-empty", "--message", "base commit") for command in commands: git.run(*command) pr_number = "1234" - proc = subprocess.run([str(skip_ci_script), pr_number], cwd=dir) - expected = 1 if should_skip else 0 + proc = subprocess.run( + [str(skip_ci_script), "--pr", pr_number, "--pr-title", "[skip ci] test"], cwd=dir + ) + expected = 0 if should_skip else 1 assert proc.returncode == expected, why test( diff --git a/tests/scripts/git_skip_ci.py b/tests/scripts/git_skip_ci.py index 7946ea33872e..73fcc6490ab8 100755 --- a/tests/scripts/git_skip_ci.py +++ b/tests/scripts/git_skip_ci.py @@ -74,6 +74,9 @@ def git(command): parser = argparse.ArgumentParser(description=help) parser.add_argument("--pr", required=True) parser.add_argument("--remote", default="origin", help="ssh remote to parse") + parser.add_argument( + "--pr-title", help="(testing) PR title to use instead of fetching from GitHub" + ) args = parser.parse_args() branch = git(["rev-parse", "--abbrev-ref", "HEAD"]) @@ -83,10 +86,15 @@ def git(command): def check_pr_title(): remote = git(["config", "--get", f"remote.{args.remote}.url"]) user, repo = parse_remote(remote) - github = GitHubRepo(token=os.environ["TOKEN"], user=user, repo=repo) - pr = github.get(f"pulls/{args.pr}") - print("pr title:", pr["title"]) - return pr["title"].startswith("[skip ci]") + + if args.pr_title: + title = args.pr_title + else: + github = GitHubRepo(token=os.environ["TOKEN"], user=user, repo=repo) + pr = github.get(f"pulls/{args.pr}") + title = pr["title"] + print("pr title:", title) + return title.startswith("[skip ci]") if ( args.pr != "null"