diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index fda6ded3a19f..8e2b8584c442 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -266,6 +266,7 @@ ALL_TESTS = [ "//pkg/sql/doctor:doctor_test", "//pkg/sql/enum:enum_test", "//pkg/sql/execinfra:execinfra_test", + "//pkg/sql/execinfrapb:execinfrapb_disallowed_imports_test", "//pkg/sql/execinfrapb:execinfrapb_test", "//pkg/sql/execstats:execstats_test", "//pkg/sql/flowinfra:flowinfra_test", diff --git a/pkg/cmd/generate-test-suites/main.go b/pkg/cmd/generate-test-suites/main.go index 3a2a62b867dc..272f44d8cd40 100644 --- a/pkg/cmd/generate-test-suites/main.go +++ b/pkg/cmd/generate-test-suites/main.go @@ -39,15 +39,19 @@ func main() { packagesToQuery = append(packagesToQuery, fmt.Sprintf("//pkg/%s/...", info.Name())) } allPackages := strings.Join(packagesToQuery, "+") - queryArgs := []string{"query", fmt.Sprintf("kind(go_test, %s)", allPackages), "--output=label"} - buf, err := exec.Command("bazel", queryArgs...).Output() + cmd := exec.Command( + "bazel", "query", + fmt.Sprintf(`kind("(go|sh)_test", %s)`, allPackages), + "--output=label", + ) + buf, err := cmd.Output() if err != nil { log.Printf("Could not query Bazel tests: got error %v", err) var cmderr *exec.ExitError if errors.As(err, &cmderr) { log.Printf("Got error output: %s", string(cmderr.Stderr)) } else { - log.Printf("Run `bazel %s` to reproduce the failure", shellescape.QuoteCommand(queryArgs)) + log.Printf("Run `%s` to reproduce the failure", shellescape.QuoteCommand(cmd.Args)) } os.Exit(1) } diff --git a/pkg/sql/execinfrapb/BUILD.bazel b/pkg/sql/execinfrapb/BUILD.bazel index 157f28c94fa7..6b9e427febe1 100644 --- a/pkg/sql/execinfrapb/BUILD.bazel +++ b/pkg/sql/execinfrapb/BUILD.bazel @@ -1,6 +1,7 @@ load("@rules_proto//proto:defs.bzl", "proto_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//pkg/testutils/buildutil:buildutil.bzl", "disallowed_imports_test") go_library( name = "execinfrapb", @@ -142,3 +143,8 @@ go_proto_library( "@com_github_gogo_protobuf//gogoproto", ], ) + +disallowed_imports_test( + "execinfrapb", + ["//pkg/sql/sem/builtins"], +) diff --git a/pkg/testutils/buildutil/buildutil.bzl b/pkg/testutils/buildutil/buildutil.bzl new file mode 100644 index 000000000000..a64a30568daf --- /dev/null +++ b/pkg/testutils/buildutil/buildutil.bzl @@ -0,0 +1,64 @@ +load("@io_bazel_rules_go//go:def.bzl", "GoLibrary") + +# This file contains a single macro disallowed_imports_test which internally +# generates a sh_test which ensures that the label provided as the first arg +# does not import any of the labels provided as a list in the second arg. + +# _DepsInfo is used in the _deps_aspect to pick up all the transitive +# dependencies of a go package. +_DepsInfo = provider( + fields = {'deps' : 'depset of targets'} +) + +def _deps_aspect_impl(target, ctx): + return [_DepsInfo( + deps = depset( + [target], + transitive = [dep[_DepsInfo].deps for dep in ctx.rule.attr.deps], + ) + )] + +_deps_aspect = aspect( + implementation = _deps_aspect_impl, + attr_aspects = ['deps'], + provides = [_DepsInfo], +) + +def _deps_rule_impl(ctx): + deps = {k: None for k in ctx.attr.src[_DepsInfo].deps.to_list()} + data = "" + failed = [p for p in ctx.attr.disallowed if p in deps] + if failed: + failures = [ + """echo >&2 "ERROR: {0} imports {1} +\tcheck: bazel query 'somepath({0}, {1})'"\ +""".format( + ctx.attr.src.label, d.label, + ) for d in failed + ] + data = "\n".join(failures + ["exit 1"]) + f = ctx.actions.declare_file(ctx.attr.name + "_deps_test.sh") + ctx.actions.write(f, data) + return [ + DefaultInfo(executable = f), + ] + +_deps_rule = rule( + implementation = _deps_rule_impl, + executable = True, + attrs = { + 'src' : attr.label(aspects = [_deps_aspect], providers = [GoLibrary]), + 'disallowed': attr.label_list(providers = [GoLibrary]), + }, +) + +def disallowed_imports_test(src, disallowed): + script = src.strip(":") + "_disallowed_imports_script" + _deps_rule(name = script, src = src, disallowed = disallowed) + native.sh_test( + name = src.strip(":") + "_disallowed_imports_test", + srcs = [":"+script], + tags = ["local"], + ) + +