diff --git a/src/test/py/bazel/BUILD b/src/test/py/bazel/BUILD index 4fa55fe5dd2a5d..38f6e588d27108 100644 --- a/src/test/py/bazel/BUILD +++ b/src/test/py/bazel/BUILD @@ -352,3 +352,17 @@ py_test( ":test_base", ], ) + +py_test( + name = "bzlmod_credentials_test", + size = "large", + srcs = ["bzlmod/bzlmod_credentials_test.py"], + tags = [ + "no_windows", # test uses a Python script as a credential helper + "requires-network", + ], + deps = [ + ":bzlmod_test_utils", + ":test_base", + ], +) diff --git a/src/test/py/bazel/bzlmod/bzlmod_credentials_test.py b/src/test/py/bazel/bzlmod/bzlmod_credentials_test.py new file mode 100644 index 00000000000000..bace01983132db --- /dev/null +++ b/src/test/py/bazel/bzlmod/bzlmod_credentials_test.py @@ -0,0 +1,170 @@ +# pylint: disable=g-backslash-continuation +# Copyright 2023 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests using credentials to connect to the bzlmod registry.""" + +import base64 +import os +import tempfile +import unittest + +from src.test.py.bazel import test_base +from src.test.py.bazel.bzlmod.test_utils import BazelRegistry +from src.test.py.bazel.bzlmod.test_utils import StaticHTTPServer + + +class BzlmodCredentialsTest(test_base.TestBase): + """Test class for using credentials to connect to the bzlmod registry.""" + + def setUp(self): + test_base.TestBase.setUp(self) + self.registries_work_dir = tempfile.mkdtemp(dir=self._test_cwd) + self.registry_root = os.path.join(self.registries_work_dir, 'main') + self.main_registry = BazelRegistry(self.registry_root) + self.main_registry.createCcModule('aaa', '1.0') + + self.ScratchFile( + '.bazelrc', + [ + # In ipv6 only network, this has to be enabled. + # 'startup --host_jvm_args=-Djava.net.preferIPv6Addresses=true', + 'common --enable_bzlmod', + # Disable yanked version check so we are not affected BCR changes. + 'common --allow_yanked_versions=all', + ], + ) + self.ScratchFile('WORKSPACE') + # The existence of WORKSPACE.bzlmod prevents WORKSPACE prefixes or suffixes + # from being used; this allows us to test built-in modules actually work + self.ScratchFile('WORKSPACE.bzlmod') + self.ScratchFile( + 'MODULE.bazel', + [ + 'bazel_dep(name = "aaa", version = "1.0")', + ], + ) + self.ScratchFile( + 'BUILD', + [ + 'cc_binary(', + ' name = "main",', + ' srcs = ["main.cc"],', + ' deps = ["@aaa//:lib_aaa"],', + ')', + ], + ) + self.ScratchFile( + 'main.cc', + [ + '#include "aaa.h"', + 'int main() {', + ' hello_aaa("main function");', + '}', + ], + ) + self.ScratchFile( + 'credhelper', + [ + '#!/usr/bin/env python3', + 'import sys', + 'if "127.0.0.1" in sys.stdin.read():', + ' print("""{"headers":{"Authorization":["Bearer TOKEN"]}}""")', + 'else:', + ' print("""{}""")', + ], + executable=True, + ) + self.ScratchFile( + '.netrc', + [ + 'machine 127.0.0.1', + 'login foo', + 'password bar', + ], + ) + + def testUnauthenticated(self): + with StaticHTTPServer(self.registry_root) as static_server: + _, stdout, _ = self.RunBazel([ + 'run', + '--registry=' + static_server.getURL(), + '--registry=https://bcr.bazel.build', + '//:main', + ]) + self.assertIn('main function => aaa@1.0', stdout) + + def testMissingCredentials(self): + with StaticHTTPServer( + self.registry_root, expected_auth='Bearer TOKEN' + ) as static_server: + _, _, stderr = self.RunBazel( + [ + 'run', + '--registry=' + static_server.getURL(), + '--registry=https://bcr.bazel.build', + '//:main', + ], + allow_failure=True, + ) + self.assertIn('GET returned 401 Unauthorized', '\n'.join(stderr)) + + def testCredentialsFromHelper(self): + with StaticHTTPServer( + self.registry_root, expected_auth='Bearer TOKEN' + ) as static_server: + _, stdout, _ = self.RunBazel([ + 'run', + '--experimental_credential_helper=%workspace%/credhelper', + '--registry=' + static_server.getURL(), + '--registry=https://bcr.bazel.build', + '//:main', + ]) + self.assertIn('main function => aaa@1.0', stdout) + + def testCredentialsFromNetrc(self): + expected_auth = 'Basic ' + base64.b64encode(b'foo:bar').decode('ascii') + + with StaticHTTPServer( + self.registry_root, expected_auth=expected_auth + ) as static_server: + _, stdout, _ = self.RunBazel( + [ + 'run', + '--registry=' + static_server.getURL(), + '--registry=https://bcr.bazel.build', + '//:main', + ], + env_add={'NETRC': self.Path('.netrc')}, + ) + self.assertIn('main function => aaa@1.0', stdout) + + def testCredentialsFromHelperOverrideNetrc(self): + with StaticHTTPServer( + self.registry_root, expected_auth='Bearer TOKEN' + ) as static_server: + _, stdout, _ = self.RunBazel( + [ + 'run', + '--experimental_credential_helper=%workspace%/credhelper', + '--registry=' + static_server.getURL(), + '--registry=https://bcr.bazel.build', + '//:main', + ], + env_add={'NETRC': self.Path('.netrc')}, + ) + self.assertIn('main function => aaa@1.0', stdout) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/test/py/bazel/bzlmod/test_utils.py b/src/test/py/bazel/bzlmod/test_utils.py index dd15b9082e2f72..aa90f0eeb87c13 100644 --- a/src/test/py/bazel/bzlmod/test_utils.py +++ b/src/test/py/bazel/bzlmod/test_utils.py @@ -16,11 +16,14 @@ """Test utils for Bzlmod.""" import base64 +import functools import hashlib +import http.server import json import os import pathlib import shutil +import threading import urllib.request import zipfile @@ -299,3 +302,62 @@ def createLocalPathModule(self, name, version, path, deps=None): with module_dir.joinpath('source.json').open('w') as f: json.dump(source, f, indent=4, sort_keys=True) + + +class StaticHTTPServer: + """An HTTP server serving static files, optionally with authentication.""" + + def __init__(self, root_directory, expected_auth=None): + self.root_directory = root_directory + self.expected_auth = expected_auth + + def __enter__(self): + address = ('localhost', 0) # assign random port + handler = functools.partial( + _Handler, self.root_directory, self.expected_auth + ) + self.httpd = http.server.HTTPServer(address, handler) + self.thread = threading.Thread(target=self.httpd.serve_forever, daemon=True) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.httpd.shutdown() + self.thread.join() + + def getURL(self): + return 'http://{}:{}'.format(*self.httpd.server_address) + + +class _Handler(http.server.SimpleHTTPRequestHandler): + """A SimpleHTTPRequestHandler with authentication.""" + + # Note: until Python 3.6, SimpleHTTPRequestHandler was only able to serve + # files from the working directory. A 'directory' parameter was added in + # Python 3.7, but sadly our CI builds are stuck with Python 3.6. Instead, + # we monkey-patch translate_path() to rewrite the path. + + def __init__(self, root_directory, expected_auth, *args, **kwargs): + self.root_directory = root_directory + self.expected_auth = expected_auth + super().__init__(*args, **kwargs) + + def translate_path(self, path): + abs_path = super().translate_path(path) + rel_path = os.path.relpath(abs_path, os.getcwd()) + return os.path.join(self.root_directory, rel_path) + + def check_auth(self): + auth_header = self.headers.get('Authorization', None) + if auth_header != self.expected_auth: + self.send_error(http.HTTPStatus.UNAUTHORIZED) + return False + return True + + def do_HEAD(self): + if self.check_auth(): + return super().do_HEAD() + + def do_GET(self): + if self.check_auth(): + return super().do_GET()