diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index d5c0fc03..42f9043d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -52,10 +52,6 @@ jobs:
         with:
           name: test-results
           path: test-results/
-      - name: yapf
-        run: |
-          pip install yapf
-          yapf --diff --recursive nss_cache nsscache
       - name: pylint
         run: |
           pip install pylint
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 00000000..075b12e3
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,25 @@
+name: lint
+
+on:
+  push:
+    tags:
+      - v*
+    branches:
+      - main
+  pull_request:
+
+permissions:
+  # There is no "none-all" setting, but
+  # https://docs.github.com/en/actions/reference/authentication-in-a-workflow#using-the-github_token-in-a-workflow
+  # implies that the token still gets created, and any permission
+  # not mentioned here is set to `none`.
+  actions: none
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: psf/black@stable
+        with:
+          version: "~= 22.0"
diff --git a/examples/authorized-keys-command.py b/examples/authorized-keys-command.py
index 0f1c04c4..b456d72d 100755
--- a/examples/authorized-keys-command.py
+++ b/examples/authorized-keys-command.py
@@ -44,20 +44,34 @@
 import copy
 import textwrap
 
-DEFAULT_SSHKEY_CACHE = '/etc/sshkey.cache'
+DEFAULT_SSHKEY_CACHE = "/etc/sshkey.cache"
 
-REGEX_BASE64 = r'(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?'
+REGEX_BASE64 = r"(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?"
 # All of the SSH blobs start with 3 null bytes, which encode to 'AAAA' in base64
-REGEX_BASE64_START3NULL = r'AAAA' + REGEX_BASE64
+REGEX_BASE64_START3NULL = r"AAAA" + REGEX_BASE64
 # This regex needs a lot of work
-KEYTYPE_REGEX_STRICT = r'\b(?:ssh-(?:rsa|dss|ed25519)|ecdsa-sha2-nistp(?:256|384|521))\b'
+KEYTYPE_REGEX_STRICT = (
+    r"\b(?:ssh-(?:rsa|dss|ed25519)|ecdsa-sha2-nistp(?:256|384|521))\b"
+)
 # Docs:
 # http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml#ssh-parameters-19
 # RFC6187, etc
-KEYTYPE_REGEX_LAZY_NOX509 = r'\b(?:(?:spki|pgp|x509|x509v3)-)?(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b'
-KEYTYPE_REGEX_LAZY_X509 = r'\bx509(?:v3)?-(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b'
-X509_WORDDN = r'(?:(?i)(?:Distinguished[ _-]?Name|DN|Subject)[=:]?)'  # case insensitive!
-KEY_REGEX = r'(.*)\s*(?:(' + KEYTYPE_REGEX_LAZY_NOX509 + r')\s+(' + REGEX_BASE64_START3NULL + r')\s*(.*)|(' + KEYTYPE_REGEX_LAZY_X509 + r')\s+(' + X509_WORDDN + '.*))'
+KEYTYPE_REGEX_LAZY_NOX509 = r"\b(?:(?:spki|pgp|x509|x509v3)-)?(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b"
+KEYTYPE_REGEX_LAZY_X509 = r"\bx509(?:v3)?-(?:(?:ssh|sign)-(?:rsa|dss|ed25519)|ecdsa-[0-9a-z-]+|rsa2048-sha256)(?:-cert-v01@openssh\.com|\@ssh\.com)?\b"
+X509_WORDDN = (
+    r"(?:(?i)(?:Distinguished[ _-]?Name|DN|Subject)[=:]?)"  # case insensitive!
+)
+KEY_REGEX = (
+    r"(.*)\s*(?:("
+    + KEYTYPE_REGEX_LAZY_NOX509
+    + r")\s+("
+    + REGEX_BASE64_START3NULL
+    + r")\s*(.*)|("
+    + KEYTYPE_REGEX_LAZY_X509
+    + r")\s+("
+    + X509_WORDDN
+    + ".*))"
+)
 # Group 1: options
 # Branch 1:
@@ -76,16 +90,16 @@
 
 def warning(*objs):
     """Helper function for output to stderr."""
-    print('WARNING: ', *objs, file=sys.stderr)
+    print("WARNING: ", *objs, file=sys.stderr)
 
 
 def parse_key(full_key_line):
     """Explode an authorized_keys line including options into the various parts."""
-    #print(KEY_REGEX)
+    # print(KEY_REGEX)
     m = re.match(KEY_REGEX, full_key_line)
     if m is None:
-        warning('Failed to match', full_key_line)
+        warning("Failed to match", full_key_line)
         return (None, None, None, None)
     options = m.group(1)
     key_type = m.group(2)
@@ -98,24 +112,24 @@ def parse_key(full_key_line):
     return (options, key_type, blob, comment)
 
 
-def fingerprint_key(keyblob, fingerprint_format='SHA256'):
+def fingerprint_key(keyblob, fingerprint_format="SHA256"):
     """Generate SSH key fingerprints, using the requested format."""
     # Don't try to fingerprint x509 blobs
-    if keyblob is None or not keyblob.startswith('AAAA'):
+    if keyblob is None or not keyblob.startswith("AAAA"):
         return None
     try:
         binary_blob = base64.b64decode(keyblob)
     except (TypeError, ValueError) as e:
         warning(e, keyblob)
         return None
-    if fingerprint_format == 'MD5':
+    if fingerprint_format == "MD5":
         raw = hashlib.md5(binary_blob).digest()
-        return 'MD5:' + ':'.join('{:02x}'.format(ord(c)) for c in raw)
-    elif fingerprint_format in ['SHA256', 'SHA512', 'SHA1']:
+        return "MD5:" + ":".join("{:02x}".format(c) for c in raw)
+    elif fingerprint_format in ["SHA256", "SHA512", "SHA1"]:
         h = hashlib.new(fingerprint_format)
         h.update(binary_blob)
         raw = h.digest()
-        return fingerprint_format + ':' + base64.b64encode(raw).rstrip('=')
+        return fingerprint_format + ":" + base64.b64encode(raw).decode("ascii").rstrip("=")
     return None
 
 
@@ -123,11 +137,11 @@
 def detect_fingerprint_format(fpr):
     """Given a fingerprint, try to detect what fingerprint format is used."""
     if fpr is None:
         return None
-    for prefix in ['SHA256', 'SHA512', 'SHA1', 'MD5']:
-        if fpr.startswith(prefix + ':'):
+    for prefix in ["SHA256", "SHA512", "SHA1", "MD5"]:
+        if fpr.startswith(prefix + ":"):
             return prefix
-    if re.match(r'^(MD5:)?([0-9a-f]{2}:)+[0-9a-f]{2}$', fpr) is not None:
-        return 'MD5'
+    if re.match(r"^(MD5:)?([0-9a-f]{2}:)+[0-9a-f]{2}$", fpr) is not None:
+        return "MD5"
     # Cannot detect the format
     return None
 
 
@@ -136,128 +150,119 @@
 def validate_key(candidate_key, conditions, strict=False):
     # pylint: disable=invalid-name,line-too-long,too-many-locals
     """Validate a potential authorized_key line against multiple conditions."""
     # Explode the key
-    (candidate_key_options, \
-     candidate_key_type, \
-     candidate_key_blob, \
-     candidate_key_comment) = parse_key(candidate_key)
+    (
+        candidate_key_options,
+        candidate_key_type,
+        candidate_key_blob,
+        candidate_key_comment,
+    ) = parse_key(candidate_key)
     # Set up our conditions with their defaults
-    key_type = conditions.get('key_type', None)
-    key_blob = conditions.get('key_blob', None)
-    key_fingerprint = conditions.get('key_fingerprint', None)
-    key_options_re = conditions.get('key_options_re', None)
-    key_comment_re = conditions.get('key_comment_re', None)
+    key_type = conditions.get("key_type", None)
+    key_blob = conditions.get("key_blob", None)
+    key_fingerprint = conditions.get("key_fingerprint", None)
+    key_options_re = conditions.get("key_options_re", None)
+    key_comment_re = conditions.get("key_comment_re", None)
     # Try to detect the fingerprint format
     fingerprint_format = detect_fingerprint_format(key_fingerprint)
     # Force MD5 prefix on old fingerprints
-    if fingerprint_format is 'MD5':
-        if not key_fingerprint.startswith('MD5:'):
-            key_fingerprint = 'MD5:' + key_fingerprint
+    if fingerprint_format == "MD5":
+        if not key_fingerprint.startswith("MD5:"):
+            key_fingerprint = "MD5:" + key_fingerprint
     # The OpenSSH base64 fingerprints drop the trailing padding; ensure we do
     # the same on provided input
-    if fingerprint_format is not 'MD5' \
-        and key_fingerprint is not None:
-        key_fingerprint = key_fingerprint.rstrip('=')
+    if fingerprint_format != "MD5" and key_fingerprint is not None:
+        key_fingerprint = key_fingerprint.rstrip("=")
     # Build the fingerprint for the candidate key
     # (the func does the padding strip as well)
-    candidate_key_fingerprint = \
-        fingerprint_key(candidate_key_blob,
-                        fingerprint_format)
+    candidate_key_fingerprint = fingerprint_key(candidate_key_blob, fingerprint_format)
     match = True
     strict_pass = False
-    if key_type is not None and \
-        candidate_key_type is not None:
+    if key_type is not None and candidate_key_type is not None:
         strict_pass = True
-        match = match and \
-            (candidate_key_type == key_type)
+        match = match and (candidate_key_type == key_type)
-    if key_fingerprint is not None and \
-        candidate_key_fingerprint is not None:
+    if key_fingerprint is not None and candidate_key_fingerprint is not None:
         strict_pass = True
-        match = match and \
-            (candidate_key_fingerprint == key_fingerprint)
+        match = match and (candidate_key_fingerprint == key_fingerprint)
-    if key_blob is not None and \
-        candidate_key_blob is not None:
+    if key_blob is not None and candidate_key_blob is not None:
         strict_pass = True
-        match = match and \
-            (candidate_key_blob == key_blob)
+        match = match and (candidate_key_blob == key_blob)
-    if key_comment_re is not None and \
-        candidate_key_comment is not None:
+    if key_comment_re is not None and candidate_key_comment is not None:
         strict_pass = True
-        match = match and \
-            key_comment_re.search(candidate_key_comment) is not None
+        match = match and key_comment_re.search(candidate_key_comment) is not None
     if key_options_re is not None:
         strict_pass = True
-        match = match and \
-            key_options_re.search(candidate_key_options) is not None
+        match = match and key_options_re.search(candidate_key_options) is not None
     if strict:
         return match and strict_pass
     return match
 
 
-PROG_EPILOG = textwrap.dedent("""\
+PROG_EPILOG = textwrap.dedent(
+    """\
 Strict match requires that at least one condition match.
 Conditions marked with X may not work correctly with X509 authorized_keys lines.
-""") -PROG_DESC = 'OpenSSH AuthorizedKeysCommand to read from cached keys file' +""" +) +PROG_DESC = "OpenSSH AuthorizedKeysCommand to read from cached keys file" -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - prog='AUTHKEYCMD', + prog="AUTHKEYCMD", description=PROG_DESC, epilog=PROG_EPILOG, formatter_class=argparse.RawDescriptionHelpFormatter, - add_help=False) + add_help=False, + ) # Arguments - group = parser.add_argument_group('Mandatory arguments') - group.add_argument('username', - metavar='USERNAME', - nargs='?', - type=str, - help='Username') - group.add_argument('--username', - metavar='USERNAME', - dest='username_opt', - type=str, - help='Username (alternative form)') + group = parser.add_argument_group("Mandatory arguments") + group.add_argument( + "username", metavar="USERNAME", nargs="?", type=str, help="Username" + ) + group.add_argument( + "--username", + metavar="USERNAME", + dest="username_opt", + type=str, + help="Username (alternative form)", + ) # Conditions - group = parser.add_argument_group('Match Conditions (optional)') - group.add_argument('--key-type', - metavar='KEY-TYPE', - type=str, - help='Key type') - group.add_argument('--key-fingerprint', - '--key-fp', - metavar='KEY-FP', - type=str, - help='Key fingerprint X') - group.add_argument('--key-blob', - metavar='KEY-BLOB', - type=str, - help='Key blob (Base64 section) X') - group.add_argument('--key-comment-re', - metavar='REGEX', - type=str, - help='Regex to match on comments X') - group.add_argument('--key-options-re', - metavar='REGEX', - type=str, - help='Regex to match on options') + group = parser.add_argument_group("Match Conditions (optional)") + group.add_argument("--key-type", metavar="KEY-TYPE", type=str, help="Key type") + group.add_argument( + "--key-fingerprint", + "--key-fp", + metavar="KEY-FP", + type=str, + help="Key fingerprint X", + ) + group.add_argument( + "--key-blob", metavar="KEY-BLOB", type=str, help="Key blob (Base64 section) X" + ) + group.add_argument( + "--key-comment-re", + metavar="REGEX", + type=str, + help="Regex to match on comments X", + ) + group.add_argument( + "--key-options-re", metavar="REGEX", type=str, help="Regex to match on options" + ) # Setup parameters: - group = parser.add_argument_group('Misc settings') + group = parser.add_argument_group("Misc settings") group.add_argument( - '--cache-file', - metavar='FILENAME', + "--cache-file", + metavar="FILENAME", default=DEFAULT_SSHKEY_CACHE, - type=argparse.FileType('r'), - help='Cache file [%s]' % (DEFAULT_SSHKEY_CACHE,), + type=argparse.FileType("r"), + help="Cache file [%s]" % (DEFAULT_SSHKEY_CACHE,), + ) + group.add_argument( + "--strict", action="store_true", default=False, help="Strict match required" ) - group.add_argument('--strict', - action='store_true', - default=False, - help='Strict match required') - group.add_argument('--help', action='help', default=False, help='This help') + group.add_argument("--help", action="help", default=False, help="This help") # Fire it all args = parser.parse_args() @@ -265,21 +270,19 @@ def validate_key(candidate_key, conditions, strict=False): lst = [args.username, args.username_opt] cnt = lst.count(None) if cnt == 2: - parser.error('Username was not specified') + parser.error("Username was not specified") elif cnt == 0: - parser.error( - 'Username must be specified either as an option XOR argument.') + parser.error("Username must be specified either as an option XOR argument.") else: args.username = [x for x in lst if x is not 
None][0] # Strict makes no sense without at least one condition being specified if args.strict: d = copy.copy(vars(args)) - for k in ['cache_file', 'strict', 'username']: + for k in ["cache_file", "strict", "username"]: d.pop(k, None) if not any(v is not None for v in list(d.values())): - parser.error( - 'At least one condition must be specified with --strict') + parser.error("At least one condition must be specified with --strict") if args.key_comment_re is not None: args.key_comment_re = re.compile(args.key_comment_re) @@ -288,28 +291,28 @@ def validate_key(candidate_key, conditions, strict=False): try: key_conditions = { - 'key_options_re': args.key_options_re, - 'key_type': args.key_type, - 'key_blob': args.key_blob, - 'key_fingerprint': args.key_fingerprint, - 'key_comment_re': args.key_comment_re, + "key_options_re": args.key_options_re, + "key_type": args.key_type, + "key_blob": args.key_blob, + "key_fingerprint": args.key_fingerprint, + "key_comment_re": args.key_comment_re, } with args.cache_file as f: for line in f: - (username, key) = line.split(':', 1) + (username, key) = line.split(":", 1) if username != args.username: continue key = key.strip() - if key.startswith('[') and key.endswith(']'): + if key.startswith("[") and key.endswith("]"): # Python array, but handle it safely! keys = [i.strip() for i in literal_eval(key)] else: # Raw key keys = [key.strip()] for k in keys: - if validate_key(candidate_key=k, - conditions=key_conditions, - strict=args.strict): + if validate_key( + candidate_key=k, conditions=key_conditions, strict=args.strict + ): print(k) except IOError as err: if err.errno in [errno.EPERM, errno.ENOENT]: diff --git a/nss_cache/__init__.py b/nss_cache/__init__.py index 798a256f..b71c516e 100644 --- a/nss_cache/__init__.py +++ b/nss_cache/__init__.py @@ -22,7 +22,9 @@ nss_cache package. """ -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) -__version__ = '0.48' +__version__ = "0.48" diff --git a/nss_cache/app.py b/nss_cache/app.py index d96ddcc0..77328cf4 100644 --- a/nss_cache/app.py +++ b/nss_cache/app.py @@ -19,8 +19,10 @@ responsible for updating or building local persistent cache. 
""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import logging import logging.handlers @@ -52,9 +54,9 @@ class NssCacheLogger(BaseLoggingClass): def __init__(self, name): logging.Logger.__init__(self, name) logging.VERBOSE = logging.INFO - 1 - logging.addLevelName(logging.VERBOSE, 'VERBOSE') + logging.addLevelName(logging.VERBOSE, "VERBOSE") logging.DEBUG2 = logging.DEBUG - 1 - logging.addLevelName(logging.DEBUG2, 'DEBUG2') + logging.addLevelName(logging.DEBUG2, "DEBUG2") def verbose(self, msg, *args, **kwargs): self.log(logging.VERBOSE, msg, args, kwargs) @@ -80,10 +82,12 @@ def __init__(self): except ValueError: is_tty = False if is_tty: - format_str = ('%(levelname)-8s %(asctime)-15s ' - '%(filename)s:%(lineno)d: ' - '%(funcName)s: ' - '%(message)s') + format_str = ( + "%(levelname)-8s %(asctime)-15s " + "%(filename)s:%(lineno)d: " + "%(funcName)s: " + "%(message)s" + ) logging.basicConfig(format=format_str) # python2.3's basicConfig doesn't let you set the default level logger = logging.getLogger() @@ -91,15 +95,18 @@ def __init__(self): else: facility = logging.handlers.SysLogHandler.LOG_DAEMON try: - handler = logging.handlers.SysLogHandler(address='/dev/log', - facility=facility) + handler = logging.handlers.SysLogHandler( + address="/dev/log", facility=facility + ) except socket.error: - print('/dev/log could not be opened; falling back on stderr.') + print("/dev/log could not be opened; falling back on stderr.") # Omitting an argument to StreamHandler results in sys.stderr being # used. handler = logging.StreamHandler() - format_str = (os.path.basename(sys.argv[0]) + - '[%(process)d]: %(levelname)s %(message)s') + format_str = ( + os.path.basename(sys.argv[0]) + + "[%(process)d]: %(levelname)s %(message)s" + ) fmt = logging.Formatter(format_str) handler.setFormatter(fmt) handler.setLevel(level=logging.INFO) @@ -116,32 +123,37 @@ def _GetParser(self): # OptionParser is from standard python module optparse OptionParser """ - usage = ('nsscache synchronises a local NSS cache against a ' - 'remote data source.\n' - '\n' - 'Usage: nsscache [global options] command [command options]\n' - '\n' - 'commands:\n') + usage = ( + "nsscache synchronises a local NSS cache against a " + "remote data source.\n" + "\n" + "Usage: nsscache [global options] command [command options]\n" + "\n" + "commands:\n" + ) command_descriptions = [] for (name, cls) in list(command.__dict__.items()): # skip the command base object - if name == 'Command': + if name == "Command": continue - if hasattr(cls, 'Help'): + if hasattr(cls, "Help"): short_help = cls().Help(short=True) - command_descriptions.append(' %-21s %.40s' % - (name.lower(), short_help.lower())) - - usage += '\n'.join(command_descriptions) - version_string = ('nsscache ' + nss_cache.__version__ + '\n' - '\n' - 'Copyright (c) 2007 Google, Inc.\n' - 'This is free software; see the source for copying ' - 'conditions. There is NO\n' - 'warranty; not even for MERCHANTABILITY or FITNESS ' - 'FOR A PARTICULAR PURPOSE.\n' - '\n' - 'Written by Jamie Wilkinson and Vasilios Hoffman.') + command_descriptions.append( + " %-21s %.40s" % (name.lower(), short_help.lower()) + ) + + usage += "\n".join(command_descriptions) + version_string = ( + "nsscache " + nss_cache.__version__ + "\n" + "\n" + "Copyright (c) 2007 Google, Inc.\n" + "This is free software; see the source for copying " + "conditions. 
There is NO\n" + "warranty; not even for MERCHANTABILITY or FITNESS " + "FOR A PARTICULAR PURPOSE.\n" + "\n" + "Written by Jamie Wilkinson and Vasilios Hoffman." + ) parser = optparse.OptionParser(usage, version=version_string) @@ -150,19 +162,19 @@ def _GetParser(self): # Add options. parser.set_defaults(verbose=False, debug=False) - parser.add_option('-v', - '--verbose', - action='store_true', - help='enable verbose output') - parser.add_option('-d', - '--debug', - action='store_true', - help='enable debugging output') - parser.add_option('-c', - '--config-file', - type='string', - help='read configuration from FILE', - metavar='FILE') + parser.add_option( + "-v", "--verbose", action="store_true", help="enable verbose output" + ) + parser.add_option( + "-d", "--debug", action="store_true", help="enable debugging output" + ) + parser.add_option( + "-c", + "--config-file", + type="string", + help="read configuration from FILE", + metavar="FILE", + ) # filthy monkeypatch hack to remove the prepended 'usage: ' # TODO(jaq): we really ought to subclass OptionParser instead... @@ -214,22 +226,21 @@ def Run(self, args, env): if options.config_file: conf.config_file = options.config_file - self.log.info('using nss_cache library, version %s', - nss_cache.__version__) - self.log.debug('library path is %r', nss_cache.__file__) + self.log.info("using nss_cache library, version %s", nss_cache.__version__) + self.log.debug("library path is %r", nss_cache.__file__) # Identify the command to dispatch. if not args: - print('No command given') + print("No command given") self.parser.print_help() return os.EX_USAGE # print global help if command is 'help' with no argument - if len(args) == 1 and args[0] == 'help': + if len(args) == 1 and args[0] == "help": self.parser.print_help() return os.EX_OK - self.log.debug('args: %r' % args) + self.log.debug("args: %r" % args) command_name = args.pop(0) - self.log.debug('command: %r' % command_name) + self.log.debug("command: %r" % command_name) # Load the configuration from file. config.LoadConfig(conf) @@ -238,15 +249,15 @@ def Run(self, args, env): try: command_callable = getattr(command, command_name.capitalize()) except AttributeError: - self.log.warning('%s is not implemented', command_name) - print(('command %r is not implemented' % command_name)) + self.log.warning("%s is not implemented", command_name) + print(("command %r is not implemented" % command_name)) self.parser.print_help() return os.EX_SOFTWARE try: retval = command_callable().Run(conf=conf, args=args) except error.SourceUnavailable as e: - self.log.error('Problem with configured data source: %s', e) + self.log.error("Problem with configured data source: %s", e) return os.EX_TEMPFAIL return retval diff --git a/nss_cache/app_test.py b/nss_cache/app_test.py index bd2a5f74..f5f793c6 100644 --- a/nss_cache/app_test.py +++ b/nss_cache/app_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for nss_cache/app.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import logging import io @@ -33,9 +33,8 @@ def setUp(self): dev_null = io.StringIO() self.stdout = sys.stdout sys.stdout = dev_null - self.srcdir = os.path.normpath( - os.path.join(os.path.dirname(__file__), '..')) - self.conf_filename = os.path.join(self.srcdir, 'nsscache.conf') + self.srcdir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..")) + self.conf_filename = os.path.join(self.srcdir, "nsscache.conf") def tearDown(self): sys.stdout = self.stdout @@ -46,42 +45,42 @@ def testRun(self): def testParseGlobalOptions(self): a = app.NssCacheApp() - (options, args) = a.parser.parse_args(['-d', '-v', 'command']) + (options, args) = a.parser.parse_args(["-d", "-v", "command"]) self.assertNotEqual(None, options.debug) self.assertNotEqual(None, options.verbose) - self.assertEqual(['command'], args) + self.assertEqual(["command"], args) def testParseCommandLineDebug(self): a = app.NssCacheApp() - (options, args) = a.parser.parse_args(['-d']) + (options, args) = a.parser.parse_args(["-d"]) self.assertNotEqual(None, options.debug) - (options, args) = a.parser.parse_args(['--debug']) + (options, args) = a.parser.parse_args(["--debug"]) self.assertNotEqual(None, options.debug) - a.Run(['-d'], {}) + a.Run(["-d"], {}) self.assertEqual(logging.DEBUG, a.log.getEffectiveLevel()) def testParseCommandLineVerbose(self): a = app.NssCacheApp() - (options, args) = a.parser.parse_args(['-v']) + (options, args) = a.parser.parse_args(["-v"]) self.assertNotEqual(None, options.verbose) self.assertEqual([], args) - (options, args) = a.parser.parse_args(['--verbose']) + (options, args) = a.parser.parse_args(["--verbose"]) self.assertNotEqual(None, options.verbose) self.assertEqual([], args) - a.Run(['-v'], {}) + a.Run(["-v"], {}) self.assertEqual(logging.INFO, a.log.getEffectiveLevel()) def testParseCommandLineVerboseDebug(self): a = app.NssCacheApp() - a.Run(['-v', '-d'], {}) + a.Run(["-v", "-d"], {}) self.assertEqual(logging.DEBUG, a.log.getEffectiveLevel()) def testParseCommandLineConfigFile(self): a = app.NssCacheApp() - (options, args) = a.parser.parse_args(['-c', 'file']) + (options, args) = a.parser.parse_args(["-c", "file"]) self.assertNotEqual(None, options.config_file) self.assertEqual([], args) - (options, args) = a.parser.parse_args(['--config-file', 'file']) + (options, args) = a.parser.parse_args(["--config-file", "file"]) self.assertNotEqual(None, options.config_file) self.assertEqual([], args) @@ -90,7 +89,7 @@ def testBadOptionsCauseNoExit(self): stderr_buffer = io.StringIO() old_stderr = sys.stderr sys.stderr = stderr_buffer - self.assertEqual(2, a.Run(['--invalid'], {})) + self.assertEqual(2, a.Run(["--invalid"], {})) sys.stderr = old_stderr def testHelpOptionPrintsGlobalHelp(self): @@ -98,17 +97,16 @@ def testHelpOptionPrintsGlobalHelp(self): a = app.NssCacheApp() old_stdout = sys.stdout sys.stdout = stdout_buffer - self.assertEqual(0, a.Run(['--help'], {})) + self.assertEqual(0, a.Run(["--help"], {})) sys.stdout = old_stdout self.assertNotEqual(0, stdout_buffer.tell()) - (prelude, usage, commands, - options) = stdout_buffer.getvalue().split('\n\n') - self.assertTrue(prelude.startswith('nsscache synchronises')) - expected_str = 'Usage: nsscache [global options] command [command options]' + (prelude, usage, commands, options) = stdout_buffer.getvalue().split("\n\n") + self.assertTrue(prelude.startswith("nsscache synchronises")) + 
expected_str = "Usage: nsscache [global options] command [command options]" self.assertEqual(expected_str, usage) - self.assertTrue(commands.startswith('commands:')) - self.assertTrue(options.startswith('Options:')) - self.assertTrue(options.find('show this help message and exit') >= 0) + self.assertTrue(commands.startswith("commands:")) + self.assertTrue(options.startswith("Options:")) + self.assertTrue(options.find("show this help message and exit") >= 0) def testHelpCommandOutput(self): # trap stdout into a StringIO @@ -116,11 +114,10 @@ def testHelpCommandOutput(self): a = app.NssCacheApp() old_stdout = sys.stdout sys.stdout = stdout_buffer - self.assertEqual(0, a.Run(['help'], {})) + self.assertEqual(0, a.Run(["help"], {})) sys.stdout = old_stdout self.assertNotEqual(0, stdout_buffer.tell()) - self.assertTrue( - stdout_buffer.getvalue().find('nsscache synchronises') >= 0) + self.assertTrue(stdout_buffer.getvalue().find("nsscache synchronises") >= 0) def testRunBadArgsPrintsGlobalHelp(self): # trap stdout into a StringIO @@ -129,11 +126,12 @@ def testRunBadArgsPrintsGlobalHelp(self): sys.stdout = stdout_buffer # verify bad arguments calls help return_code = app.NssCacheApp().Run( - ['blarg'], {'NSSCACHE_CONFIG': self.conf_filename}) + ["blarg"], {"NSSCACHE_CONFIG": self.conf_filename} + ) sys.stdout = old_stdout assert return_code == 70 # EX_SOFTWARE - assert stdout_buffer.getvalue().find('enable debugging') >= 0 + assert stdout_buffer.getvalue().find("enable debugging") >= 0 -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/caches/cache_factory.py b/nss_cache/caches/cache_factory.py index 2a7e99e0..8f49bdfc 100644 --- a/nss_cache/caches/cache_factory.py +++ b/nss_cache/caches/cache_factory.py @@ -20,7 +20,7 @@ reliable. 
""" -__author__ = 'springer@google.com (Matthew Springer)' +__author__ = "springer@google.com (Matthew Springer)" import logging @@ -44,7 +44,7 @@ def RegisterImplementation(cache_name, map_name, cache): """ global _cache_implementations if cache_name not in _cache_implementations: - logging.info('Registering [%s] cache for [%s].', cache_name, map_name) + logging.info("Registering [%s] cache for [%s].", cache_name, map_name) _cache_implementations[cache_name] = {} _cache_implementations[cache_name][map_name] = cache @@ -67,17 +67,17 @@ def Create(conf, map_name, automount_mountpoint=None): """ global _cache_implementations if not _cache_implementations: - raise RuntimeError('no cache implementations exist') - cache_name = conf['name'] + raise RuntimeError("no cache implementations exist") + cache_name = conf["name"] if cache_name not in _cache_implementations: - raise RuntimeError('cache not implemented: %r' % (cache_name,)) + raise RuntimeError("cache not implemented: %r" % (cache_name,)) if map_name not in _cache_implementations[cache_name]: - raise RuntimeError('map %r not supported by cache %r' % - (map_name, cache_name)) + raise RuntimeError("map %r not supported by cache %r" % (map_name, cache_name)) return _cache_implementations[cache_name][map_name]( - conf, map_name, automount_mountpoint=automount_mountpoint) + conf, map_name, automount_mountpoint=automount_mountpoint + ) files.RegisterAllImplementations(RegisterImplementation) diff --git a/nss_cache/caches/cache_factory_test.py b/nss_cache/caches/cache_factory_test.py index 261dba34..c26920d9 100644 --- a/nss_cache/caches/cache_factory_test.py +++ b/nss_cache/caches/cache_factory_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for out cache factory.""" -__author__ = 'springer@google.com (Matthew Springer)' +__author__ = "springer@google.com (Matthew Springer)" import unittest @@ -24,30 +24,29 @@ class TestCacheFactory(unittest.TestCase): - def testRegister(self): - class DummyCache(caches.Cache): pass old_cache_implementations = cache_factory._cache_implementations cache_factory._cache_implementations = {} - cache_factory.RegisterImplementation('dummy', 'dummy', DummyCache) + cache_factory.RegisterImplementation("dummy", "dummy", DummyCache) self.assertEqual(1, len(cache_factory._cache_implementations)) - self.assertEqual(1, len(cache_factory._cache_implementations['dummy'])) - self.assertEqual(DummyCache, - cache_factory._cache_implementations['dummy']['dummy']) + self.assertEqual(1, len(cache_factory._cache_implementations["dummy"])) + self.assertEqual( + DummyCache, cache_factory._cache_implementations["dummy"]["dummy"] + ) cache_factory._cache_implementations = old_cache_implementations def testCreateWithNoImplementations(self): old_cache_implementations = cache_factory._cache_implementations cache_factory._cache_implementations = {} - self.assertRaises(RuntimeError, cache_factory.Create, {}, 'map_name') + self.assertRaises(RuntimeError, cache_factory.Create, {}, "map_name") cache_factory._cache_implementations = old_cache_implementations def testThatRegularImplementationsArePresent(self): self.assertEqual(len(cache_factory._cache_implementations), 1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/caches/caches.py b/nss_cache/caches/caches.py index 16604fef..34bfb7bc 100644 --- a/nss_cache/caches/caches.py +++ b/nss_cache/caches/caches.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, 
USA. """Base class of cache for nsscache.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import errno import logging @@ -70,7 +70,7 @@ def __init__(self, conf, map_name, automount_mountpoint=None): self.log = logging.getLogger(__name__) # Store config info self.conf = conf - self.output_dir = conf.get('dir', '.') + self.output_dir = conf.get("dir", ".") self.automount_mountpoint = automount_mountpoint self.map_name = map_name @@ -88,31 +88,35 @@ def __init__(self, conf, map_name, automount_mountpoint=None): elif map_name == config.MAP_AUTOMOUNT: self.data = automount.AutomountMap() else: - raise error.UnsupportedMap('Cache does not support %s' % map_name) + raise error.UnsupportedMap("Cache does not support %s" % map_name) def _Begin(self): """Start a write transaction.""" - self.log.debug('Output dir: %s', self.output_dir) - self.log.debug('CWD: %s', os.getcwd()) + self.log.debug("Output dir: %s", self.output_dir) + self.log.debug("CWD: %s", os.getcwd()) try: self.temp_cache_file = tempfile.NamedTemporaryFile( delete=False, - prefix='nsscache-cache-file-', - dir=os.path.join(os.getcwd(), self.output_dir)) + prefix="nsscache-cache-file-", + dir=os.path.join(os.getcwd(), self.output_dir), + ) self.temp_cache_filename = self.temp_cache_file.name - self.log.debug('opened temporary cache filename %r', - self.temp_cache_filename) + self.log.debug( + "opened temporary cache filename %r", self.temp_cache_filename + ) except OSError as e: if e.errno == errno.EACCES: self.log.info( - 'Got OSError (%s) when trying to create temporary file', e) - raise error.PermissionDenied('OSError: ' + str(e)) + "Got OSError (%s) when trying to create temporary file", e + ) + raise error.PermissionDenied("OSError: " + str(e)) raise def _Rollback(self): """Rollback a write transaction.""" - self.log.debug('rolling back, deleting temp cache file %r', - self.temp_cache_filename) + self.log.debug( + "rolling back, deleting temp cache file %r", self.temp_cache_filename + ) self.temp_cache_file.close() # Safe file remove (ignore "no such file or directory" errors): try: @@ -140,7 +144,7 @@ def _Commit(self): os.fsync(self.temp_cache_file.fileno()) self.temp_cache_file.close() else: - self.log.debug('temp cache file was already closed before Commit') + self.log.debug("temp cache file was already closed before Commit") # We emulate the permissions of our source map to avoid bugs where # permissions may differ (usually w/shadow map) # Catch the case where the source file may not exist for some reason and @@ -153,15 +157,21 @@ def _Commit(self): os.chown(self.temp_cache_filename, uid, gid) except OSError as e: if e.errno == errno.ENOENT: - if self.map_name == 'sshkey': - os.chmod(self.temp_cache_filename, - stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + if self.map_name == "sshkey": + os.chmod( + self.temp_cache_filename, + stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH, + ) else: os.chmod( - self.temp_cache_filename, stat.S_IRUSR | stat.S_IWUSR | - stat.S_IRGRP | stat.S_IROTH) - self.log.debug('committing temporary cache file %r to %r', - self.temp_cache_filename, self.GetCacheFilename()) + self.temp_cache_filename, + stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH, + ) + self.log.debug( + "committing temporary cache file %r to %r", + self.temp_cache_filename, + self.GetCacheFilename(), + ) os.rename(self.temp_cache_filename, self.GetCacheFilename()) return True @@ -172,7 +182,7 @@ def GetCacheFilename(self): def GetCompatFilename(self): """Return the filename 
where the normal (not-cache) map would be.""" # TODO(jaq): Probably shouldn't hard code '/etc' here. - return os.path.join('/etc', self.map_name) + return os.path.join("/etc", self.map_name) def GetMap(self, cache_filename=None): """Returns the map from the cache. @@ -185,8 +195,9 @@ def GetMap(self, cache_filename=None): Raises: NotImplementedError: We should have been implemented by child. """ - raise NotImplementedError('%s must implement this method!' % - self.__class__.__name__) + raise NotImplementedError( + "%s must implement this method!" % self.__class__.__name__ + ) def GetMapLocation(self): """Return the location of the Map in this cache. @@ -197,8 +208,9 @@ def GetMapLocation(self): Raises: NotImplementedError: We should have been implemented by child. """ - raise NotImplementedError('%s must implement this method!' % - self.__class__.__name__) + raise NotImplementedError( + "%s must implement this method!" % self.__class__.__name__ + ) def WriteMap(self, map_data=None, force_write=False): """Write a map to disk. @@ -220,10 +232,10 @@ def WriteMap(self, map_data=None, force_write=False): # N.B. Write is destructive, len(writable_map) == 0 now. # Asserting this isn't good for the unit tests, though. - #assert 0 == len(writable_map), "self.Write should be destructive." + # assert 0 == len(writable_map), "self.Write should be destructive." if entries_written is None: - self.log.warning('cache write failed, exiting') + self.log.warning("cache write failed, exiting") return 1 if force_write or self.Verify(entries_written): @@ -234,7 +246,7 @@ def WriteMap(self, map_data=None, force_write=False): self.WriteIndex() return 0 - self.log.warning('verification failed, exiting') + self.log.warning("verification failed, exiting") return 1 def WriteIndex(self): diff --git a/nss_cache/caches/caches_test.py b/nss_cache/caches/caches_test.py index 85c72870..952d10cf 100644 --- a/nss_cache/caches/caches_test.py +++ b/nss_cache/caches/caches_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for caches/caches.py.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import os import platform @@ -30,7 +30,7 @@ class FakeCacheCls(caches.Cache): - CACHE_FILENAME = 'shadow' + CACHE_FILENAME = "shadow" def __init__(self, config, map_name): super(FakeCacheCls, self).__init__(config, map_name) @@ -39,15 +39,14 @@ def Write(self, map_data): return 0 def GetCacheFilename(self): - return os.path.join(self.output_dir, self.CACHE_FILENAME + '.test') + return os.path.join(self.output_dir, self.CACHE_FILENAME + ".test") class TestCls(unittest.TestCase): - def setUp(self): self.workdir = tempfile.mkdtemp() - self.config = {'dir': self.workdir} - if platform.system() == 'FreeBSD': + self.config = {"dir": self.workdir} + if platform.system() == "FreeBSD": # FreeBSD doesn't have a shadow file self.shadow = config.MAP_PASSWORD else: @@ -57,7 +56,7 @@ def tearDown(self): os.rmdir(self.workdir) def testCopyOwnerMissing(self): - expected = os.stat(os.path.join('/etc', self.shadow)) + expected = os.stat(os.path.join("/etc", self.shadow)) expected = stat.S_IMODE(expected.st_mode) cache = FakeCacheCls(config=self.config, map_name=self.shadow) cache._Begin() @@ -67,7 +66,7 @@ def testCopyOwnerMissing(self): os.unlink(cache.GetCacheFilename()) def testCopyOwnerPresent(self): - expected = os.stat(os.path.join('/etc/', self.shadow)) + expected = os.stat(os.path.join("/etc/", self.shadow)) expected = stat.S_IMODE(expected.st_mode) cache = FakeCacheCls(config=self.config, map_name=self.shadow) cache._Begin() @@ -78,17 +77,15 @@ def testCopyOwnerPresent(self): class TestCache(unittest.TestCase): - def testWriteMap(self): cache_map = caches.Cache({}, config.MAP_PASSWORD, None) - with mock.patch.object(cache_map, 'Write') as write, mock.patch.object( - cache_map, - 'Verify') as verify, mock.patch.object(cache_map, - '_Commit') as commit: - write.return_value = 'entries_written' + with mock.patch.object(cache_map, "Write") as write, mock.patch.object( + cache_map, "Verify" + ) as verify, mock.patch.object(cache_map, "_Commit") as commit: + write.return_value = "entries_written" verify.return_value = True - self.assertEqual(0, cache_map.WriteMap('writable_map')) + self.assertEqual(0, cache_map.WriteMap("writable_map")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/caches/files.py b/nss_cache/caches/files.py index ee520158..31ae86c6 100644 --- a/nss_cache/caches/files.py +++ b/nss_cache/caches/files.py @@ -21,8 +21,10 @@ format created here. 
""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import configparser import errno @@ -47,28 +49,28 @@ def LongestLength(l): # Load suffix config variables parser = configparser.ConfigParser() for i in sys.argv: - if ('nsscache.conf') in i: + if ("nsscache.conf") in i: # Remove '--config-file=' from the string - if ('--config-file') in i: + if ("--config-file") in i: i = i[14:] parser.read(i) - elif os.path.isfile('/etc/nsscache.conf'): - parser.read('/etc/nsscache.conf') + elif os.path.isfile("/etc/nsscache.conf"): + parser.read("/etc/nsscache.conf") else: # Config in nsscache folder - parser.read('nsscache.conf') -prefix = parser.get('suffix', 'prefix', fallback='') -suffix = parser.get('suffix', 'suffix', fallback='') + parser.read("nsscache.conf") +prefix = parser.get("suffix", "prefix", fallback="") +suffix = parser.get("suffix", "suffix", fallback="") def RegisterAllImplementations(register_callback): """Register our cache classes independently from the import scheme.""" - register_callback('files', 'passwd', FilesPasswdMapHandler) - register_callback('files', 'sshkey', FilesSshkeyMapHandler) - register_callback('files', 'group', FilesGroupMapHandler) - register_callback('files', 'shadow', FilesShadowMapHandler) - register_callback('files', 'netgroup', FilesNetgroupMapHandler) - register_callback('files', 'automount', FilesAutomountMapHandler) + register_callback("files", "passwd", FilesPasswdMapHandler) + register_callback("files", "sshkey", FilesSshkeyMapHandler) + register_callback("files", "group", FilesGroupMapHandler) + register_callback("files", "shadow", FilesShadowMapHandler) + register_callback("files", "netgroup", FilesNetgroupMapHandler) + register_callback("files", "automount", FilesAutomountMapHandler) class FilesCache(caches.Cache): @@ -91,17 +93,16 @@ def __init__(self, conf, map_name, automount_mountpoint=None): automount_mountpoint: A string containing the automount mountpoint, used only by automount maps. """ - super(FilesCache, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesCache, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) # Documented in nsscache.conf example. - self.cache_filename_suffix = conf.get('cache_filename_suffix', 'cache') + self.cache_filename_suffix = conf.get("cache_filename_suffix", "cache") # Store a dict of indexes, each containing a dict of keys to line, position # tuples. self._indices = {} - if hasattr(self, '_INDEX_ATTRIBUTES'): + if hasattr(self, "_INDEX_ATTRIBUTES"): for index in self._INDEX_ATTRIBUTES: self._indices[index] = {} @@ -121,10 +122,9 @@ def GetMap(self, cache_filename=None): if cache_filename is None: cache_filename = self.GetCacheFilename() - self.log.debug('Opening %r for reading existing cache', cache_filename) + self.log.debug("Opening %r for reading existing cache", cache_filename) if not os.path.exists(cache_filename): - self.log.warning( - 'Cache file does not exist, using an empty map instead') + self.log.warning("Cache file does not exist, using an empty map instead") else: cache_file = open(cache_filename) data = self.map_parser.GetMap(cache_file, data) @@ -146,18 +146,20 @@ def Verify(self, written_keys): Raises: EmptyMap: The cache being verified is empty. 
""" - self.log.debug('verification starting on %r', self.temp_cache_filename) + self.log.debug("verification starting on %r", self.temp_cache_filename) cache_data = self.GetMap(self.temp_cache_filename) map_entry_count = len(cache_data) - self.log.debug('entry count: %d', map_entry_count) + self.log.debug("entry count: %d", map_entry_count) if map_entry_count <= 0: # We have read in an empty map, yet we expect that earlier we # should have written more. Uncaught disk full or other error? - self.log.error('The files cache being verified "%r" is empty.', - self.temp_cache_filename) - raise error.EmptyMap(self.temp_cache_filename + ' is empty') + self.log.error( + 'The files cache being verified "%r" is empty.', + self.temp_cache_filename, + ) + raise error.EmptyMap(self.temp_cache_filename + " is empty") cache_keys = set() # Use PopItem() so we free our memory if multiple maps are Verify()ed. @@ -171,27 +173,31 @@ def Verify(self, written_keys): missing_from_cache = written_keys - cache_keys if missing_from_cache: - self.log.warning('verify failed: %d missing from the on-disk cache', - len(missing_from_cache)) + self.log.warning( + "verify failed: %d missing from the on-disk cache", + len(missing_from_cache), + ) if len(missing_from_cache) < 1000: - self.log.debug('keys missing from the on-disk cache: %r', - missing_from_cache) + self.log.debug( + "keys missing from the on-disk cache: %r", missing_from_cache + ) else: - self.log.debug('More than 1000 keys missing from cache. ' - 'Not printing.') + self.log.debug( + "More than 1000 keys missing from cache. " "Not printing." + ) self._Rollback() return False missing_from_map = cache_keys - written_keys if missing_from_map: self.log.warning( - 'verify failed: %d keys found, unexpected in the on-disk ' - 'cache', len(missing_from_map)) + "verify failed: %d keys found, unexpected in the on-disk " "cache", + len(missing_from_map), + ) if len(missing_from_map) < 1000: - self.log.debug('keys missing from map: %r', missing_from_map) + self.log.debug("keys missing from map: %r", missing_from_map) else: - self.log.debug( - 'More than 1000 keys missing from map. Not printing.') + self.log.debug("More than 1000 keys missing from map. Not printing.") self._Rollback() return False @@ -218,8 +224,7 @@ def Write(self, map_data): while 1: entry = map_data.PopItem() for index in self._indices: - self._indices[index][str(getattr( - entry, index))] = str(write_offset) + self._indices[index][str(getattr(entry, index))] = str(write_offset) write_offset += self._WriteData(self.temp_cache_file, entry) written_keys.update(self._ExpectedKeysForEntry(entry)) except KeyError: @@ -235,23 +240,22 @@ def GetCacheFilename(self): """Return the final destination pathname of the cache file.""" cache_filename_target = self.CACHE_FILENAME if self.cache_filename_suffix: - cache_filename_target += '.' + self.cache_filename_suffix + cache_filename_target += "." 
+ self.cache_filename_suffix return os.path.join(self.output_dir, cache_filename_target) def WriteIndex(self): """Generate an index for libnss-cache from this map.""" for index_name in self._indices: # index file write to tmp file first, magic string ".ix" - tmp_index_filename = '%s.ix%s.tmp' % (self.GetCacheFilename(), - index_name) - self.log.debug('Writing index %s', tmp_index_filename) + tmp_index_filename = "%s.ix%s.tmp" % (self.GetCacheFilename(), index_name) + self.log.debug("Writing index %s", tmp_index_filename) index = self._indices[index_name] key_length = LongestLength(list(index.keys())) pos_length = LongestLength(list(index.values())) max_length = key_length + pos_length # Open for write/truncate - index_file = open(tmp_index_filename, 'w') + index_file = open(tmp_index_filename, "w") # setup permissions try: shutil.copymode(self.GetCompatFilename(), tmp_index_filename) @@ -262,36 +266,38 @@ def WriteIndex(self): except OSError as e: if e.errno == errno.ENOENT: os.chmod( - tmp_index_filename, stat.S_IRUSR | stat.S_IWUSR | - stat.S_IRGRP | stat.S_IROTH) + tmp_index_filename, + stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH, + ) for key in sorted(index): pos = index[key] - index_line = ('%s\0%s\0%s\n' % - (key, pos, '\0' * - (max_length - len(key) - len(pos)))) + index_line = "%s\0%s\0%s\n" % ( + key, + pos, + "\0" * (max_length - len(key) - len(pos)), + ) index_file.write(index_line) index_file.close() for index_name in self._indices: # rename tmp index file to target index file in order to # prevent getting user info fail during update index. - tmp_index_filename = '%s.ix%s.tmp' % (self.GetCacheFilename(), - index_name) - index_filename = '%s.ix%s' % (self.GetCacheFilename(), index_name) + tmp_index_filename = "%s.ix%s.tmp" % (self.GetCacheFilename(), index_name) + index_filename = "%s.ix%s" % (self.GetCacheFilename(), index_name) os.rename(tmp_index_filename, index_filename) class FilesSshkeyMapHandler(FilesCache): """Concrete class for updating a nss_files module sshkey cache.""" - CACHE_FILENAME = 'sshkey' - _INDEX_ATTRIBUTES = ('name',) + + CACHE_FILENAME = "sshkey" + _INDEX_ATTRIBUTES = ("name",) def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_SSHKEY - super(FilesSshkeyMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesSshkeyMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesSshkeyMapParser() def _ExpectedKeysForEntry(self, entry): @@ -315,23 +321,23 @@ def _WriteData(self, target, entry): Returns: Number of bytes written to the target. 
""" - sshkey_entry = '%s:%s' % (entry.name, entry.sshkey) - target.write(sshkey_entry.encode() + b'\n') + sshkey_entry = "%s:%s" % (entry.name, entry.sshkey) + target.write(sshkey_entry.encode() + b"\n") return len(sshkey_entry) + 1 class FilesPasswdMapHandler(FilesCache): """Concrete class for updating a nss_files module passwd cache.""" - CACHE_FILENAME = 'passwd' - _INDEX_ATTRIBUTES = ('name', 'uid') + + CACHE_FILENAME = "passwd" + _INDEX_ATTRIBUTES = ("name", "uid") def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_PASSWORD - super(FilesPasswdMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesPasswdMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesPasswdMapParser() def _ExpectedKeysForEntry(self, entry): @@ -355,25 +361,31 @@ def _WriteData(self, target, entry): Returns: Number of bytes written to the target. """ - password_entry = '%s:%s:%d:%d:%s:%s:%s' % ( - entry.name, entry.passwd, entry.uid, entry.gid, entry.gecos, - entry.dir, entry.shell) - target.write(password_entry.encode() + b'\n') + password_entry = "%s:%s:%d:%d:%s:%s:%s" % ( + entry.name, + entry.passwd, + entry.uid, + entry.gid, + entry.gecos, + entry.dir, + entry.shell, + ) + target.write(password_entry.encode() + b"\n") return len(password_entry) + 1 class FilesGroupMapHandler(FilesCache): """Concrete class for updating a nss_files module group cache.""" - CACHE_FILENAME = 'group' - _INDEX_ATTRIBUTES = ('name', 'gid') + + CACHE_FILENAME = "group" + _INDEX_ATTRIBUTES = ("name", "gid") def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_GROUP - super(FilesGroupMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesGroupMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesGroupMapParser() def _ExpectedKeysForEntry(self, entry): @@ -389,24 +401,28 @@ def _ExpectedKeysForEntry(self, entry): def _WriteData(self, target, entry): """Write a GroupMapEntry to the target cache.""" - group_entry = '%s:%s:%d:%s' % (entry.name, entry.passwd, entry.gid, - ','.join(entry.members)) - target.write(group_entry.encode() + b'\n') + group_entry = "%s:%s:%d:%s" % ( + entry.name, + entry.passwd, + entry.gid, + ",".join(entry.members), + ) + target.write(group_entry.encode() + b"\n") return len(group_entry) + 1 class FilesShadowMapHandler(FilesCache): """Concrete class for updating a nss_files module shadow cache.""" - CACHE_FILENAME = 'shadow' - _INDEX_ATTRIBUTES = ('name',) + + CACHE_FILENAME = "shadow" + _INDEX_ATTRIBUTES = ("name",) def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_SHADOW - super(FilesShadowMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesShadowMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesShadowMapParser() def _ExpectedKeysForEntry(self, entry): @@ -422,26 +438,33 @@ def _ExpectedKeysForEntry(self, entry): def _WriteData(self, target, entry): """Write a ShadowMapEntry to the target cache.""" - shadow_entry = '%s:%s:%s:%s:%s:%s:%s:%s:%s' % ( - entry.name, entry.passwd, entry.lstchg or '', entry.min or - '', entry.max or '', entry.warn or '', entry.inact or - 
'', entry.expire or '', entry.flag or '') - target.write(shadow_entry.encode() + b'\n') + shadow_entry = "%s:%s:%s:%s:%s:%s:%s:%s:%s" % ( + entry.name, + entry.passwd, + entry.lstchg or "", + entry.min or "", + entry.max or "", + entry.warn or "", + entry.inact or "", + entry.expire or "", + entry.flag or "", + ) + target.write(shadow_entry.encode() + b"\n") return len(shadow_entry) + 1 class FilesNetgroupMapHandler(FilesCache): """Concrete class for updating a nss_files module netgroup cache.""" - CACHE_FILENAME = 'netgroup' - _TUPLE_RE = re.compile(r'^\((.*?),(.*?),(.*?)\)$') # Do this only once. + + CACHE_FILENAME = "netgroup" + _TUPLE_RE = re.compile(r"^\((.*?),(.*?),(.*?)\)$") # Do this only once. def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_NETGROUP - super(FilesNetgroupMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesNetgroupMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesNetgroupMapParser() def _ExpectedKeysForEntry(self, entry): @@ -458,34 +481,33 @@ def _ExpectedKeysForEntry(self, entry): def _WriteData(self, target, entry): """Write a NetgroupMapEntry to the target cache.""" if entry.entries: - netgroup_entry = '%s %s' % (entry.name, entry.entries) + netgroup_entry = "%s %s" % (entry.name, entry.entries) else: netgroup_entry = entry.name - target.write(netgroup_entry.encode() + b'\n') + target.write(netgroup_entry.encode() + b"\n") return len(netgroup_entry) + 1 class FilesAutomountMapHandler(FilesCache): """Concrete class for updating a nss_files module automount cache.""" + CACHE_FILENAME = None # we have multiple files, set as we update. def __init__(self, conf, map_name=None, automount_mountpoint=None): if map_name is None: map_name = config.MAP_AUTOMOUNT - super(FilesAutomountMapHandler, - self).__init__(conf, - map_name, - automount_mountpoint=automount_mountpoint) + super(FilesAutomountMapHandler, self).__init__( + conf, map_name, automount_mountpoint=automount_mountpoint + ) self.map_parser = file_formats.FilesAutomountMapParser() if automount_mountpoint is None: # we are dealing with the master map - self.CACHE_FILENAME = 'auto.master' + self.CACHE_FILENAME = "auto.master" else: # turn /auto into auto.auto, and /usr/local into /auto.usr_local - automount_mountpoint = automount_mountpoint.lstrip('/') - self.CACHE_FILENAME = 'auto.%s' % automount_mountpoint.replace( - '/', '_') + automount_mountpoint = automount_mountpoint.lstrip("/") + self.CACHE_FILENAME = "auto.%s" % automount_mountpoint.replace("/", "_") def _ExpectedKeysForEntry(self, entry): """Generate a list of expected cache keys for this type of map. 
@@ -503,16 +525,15 @@ def _WriteData(self, target, entry): # Modify suffix after mountpoint for autofs pattern = re.compile(prefix) if entry.options is not None: - if prefix != '': - if (pattern.match(entry.location)): # Found string with regex - entry.location = re.sub(r'({0})'.format(prefix), - r'{0}'.format(suffix), - entry.location) - automount_entry = '%s %s %s' % (entry.key, entry.options, - entry.location) + if prefix != "": + if pattern.match(entry.location): # Found string with regex + entry.location = re.sub( + r"({0})".format(prefix), r"{0}".format(suffix), entry.location + ) + automount_entry = "%s %s %s" % (entry.key, entry.options, entry.location) else: - automount_entry = '%s %s' % (entry.key, entry.location) - target.write(automount_entry.encode() + b'\n') + automount_entry = "%s %s" % (entry.key, entry.location) + target.write(automount_entry.encode() + b"\n") return len(automount_entry) + 1 def GetMapLocation(self): diff --git a/nss_cache/caches/files_test.py b/nss_cache/caches/files_test.py index 0e4c9856..c5302254 100644 --- a/nss_cache/caches/files_test.py +++ b/nss_cache/caches/files_test.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/caches/files.py.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import os import shutil @@ -35,11 +37,10 @@ class TestFilesCache(unittest.TestCase): - def setUp(self): super(TestFilesCache, self).setUp() self.workdir = tempfile.mkdtemp() - self.config = {'dir': self.workdir} + self.config = {"dir": self.workdir} def tearDown(self): super(TestFilesCache, self).tearDown() @@ -51,28 +52,28 @@ def testInstantiation(self): def testWrite(self): cache = files.FilesPasswdMapHandler(self.config) - entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) pmap = passwd.PasswdMap([entry]) written = cache.Write(pmap) - self.assertTrue('foo' in written) + self.assertTrue("foo" in written) self.assertFalse(entry in pmap) # we emptied pmap to avoid mem leaks self.assertFalse(cache.temp_cache_file.closed) def testCacheFilenameSuffixOption(self): - new_config = {'cache_filename_suffix': 'blarg'} + new_config = {"cache_filename_suffix": "blarg"} new_config.update(self.config) cache = files.FilesCache(new_config, config.MAP_PASSWORD) - cache.CACHE_FILENAME = 'test' - self.assertEqual(os.path.join(self.workdir, 'test.blarg'), - cache.GetCacheFilename()) + cache.CACHE_FILENAME = "test" + self.assertEqual( + os.path.join(self.workdir, "test.blarg"), cache.GetCacheFilename() + ) - cache.temp_cache_file = open(os.path.join(self.workdir, 'pre-commit'), - 'w') - cache.temp_cache_file.write('\n') - cache.temp_cache_filename = os.path.join(self.workdir, 'pre-commit') + cache.temp_cache_file = open(os.path.join(self.workdir, "pre-commit"), "w") + cache.temp_cache_file.write("\n") + cache.temp_cache_filename = os.path.join(self.workdir, "pre-commit") cache._Commit() - expected_cache_filename = os.path.join(self.workdir, 'test.blarg') + expected_cache_filename = os.path.join(self.workdir, "test.blarg") self.assertTrue(os.path.exists(expected_cache_filename)) def testWritePasswdEntry(self): @@ -81,17 +82,16 @@ def testWritePasswdEntry(self): file_mock = mock.create_autospec(sys.stdout) map_entry = passwd.PasswdMapEntry() - map_entry.name = 'root' - map_entry.passwd = 
'x' + map_entry.name = "root" + map_entry.passwd = "x" map_entry.uid = 0 map_entry.gid = 0 - map_entry.gecos = 'Rootsy' - map_entry.dir = '/root' - map_entry.shell = '/bin/bash' + map_entry.gecos = "Rootsy" + map_entry.dir = "/root" + map_entry.shell = "/bin/bash" cache._WriteData(file_mock, map_entry) - file_mock.write.assert_called_with( - b'root:x:0:0:Rootsy:/root:/bin/bash\n') + file_mock.write.assert_called_with(b"root:x:0:0:Rootsy:/root:/bin/bash\n") def testWriteGroupEntry(self): """We correctly write a typical entry in /etc/group format.""" @@ -99,13 +99,13 @@ def testWriteGroupEntry(self): file_mock = mock.create_autospec(sys.stdout) map_entry = group.GroupMapEntry() - map_entry.name = 'root' - map_entry.passwd = 'x' + map_entry.name = "root" + map_entry.passwd = "x" map_entry.gid = 0 - map_entry.members = ['zero_cool', 'acid_burn'] + map_entry.members = ["zero_cool", "acid_burn"] cache._WriteData(file_mock, map_entry) - file_mock.write.assert_called_with(b'root:x:0:zero_cool,acid_burn\n') + file_mock.write.assert_called_with(b"root:x:0:zero_cool,acid_burn\n") def testWriteShadowEntry(self): """We correctly write a typical entry in /etc/shadow format.""" @@ -113,11 +113,11 @@ def testWriteShadowEntry(self): file_mock = mock.create_autospec(sys.stdout) map_entry = shadow.ShadowMapEntry() - map_entry.name = 'root' - map_entry.passwd = '$1$zomgmd5support' + map_entry.name = "root" + map_entry.passwd = "$1$zomgmd5support" cache._WriteData(file_mock, map_entry) - file_mock.write.assert_called_with(b'root:$1$zomgmd5support:::::::\n') + file_mock.write.assert_called_with(b"root:$1$zomgmd5support:::::::\n") def testWriteNetgroupEntry(self): """We correctly write a typical entry in /etc/netgroup format.""" @@ -125,12 +125,13 @@ def testWriteNetgroupEntry(self): file_mock = mock.create_autospec(sys.stdout) map_entry = netgroup.NetgroupMapEntry() - map_entry.name = 'administrators' - map_entry.entries = 'unix_admins noc_monkeys (-,zero_cool,)' + map_entry.name = "administrators" + map_entry.entries = "unix_admins noc_monkeys (-,zero_cool,)" cache._WriteData(file_mock, map_entry) file_mock.write.assert_called_with( - b'administrators unix_admins noc_monkeys (-,zero_cool,)\n') + b"administrators unix_admins noc_monkeys (-,zero_cool,)\n" + ) def testWriteAutomountEntry(self): """We correctly write a typical entry in /etc/auto.* format.""" @@ -138,129 +139,129 @@ def testWriteAutomountEntry(self): file_mock = mock.create_autospec(sys.stdout) map_entry = automount.AutomountMapEntry() - map_entry.key = 'scratch' - map_entry.options = '-tcp,rw,intr,bg' - map_entry.location = 'fileserver:/scratch' + map_entry.key = "scratch" + map_entry.options = "-tcp,rw,intr,bg" + map_entry.location = "fileserver:/scratch" cache._WriteData(file_mock, map_entry) file_mock.write.assert_called_with( - b'scratch -tcp,rw,intr,bg fileserver:/scratch\n') + b"scratch -tcp,rw,intr,bg fileserver:/scratch\n" + ) file_mock = mock.create_autospec(sys.stdout) map_entry = automount.AutomountMapEntry() - map_entry.key = 'scratch' + map_entry.key = "scratch" map_entry.options = None - map_entry.location = 'fileserver:/scratch' + map_entry.location = "fileserver:/scratch" cache._WriteData(file_mock, map_entry) - file_mock.write.assert_called_with(b'scratch fileserver:/scratch\n') + file_mock.write.assert_called_with(b"scratch fileserver:/scratch\n") def testAutomountSetsFilename(self): """We set the correct filename based on mountpoint information.""" # also tests GetMapLocation() because it uses it :) - conf = {'dir': 
self.workdir, 'cache_filename_suffix': ''} + conf = {"dir": self.workdir, "cache_filename_suffix": ""} cache = files.FilesAutomountMapHandler(conf) - self.assertEqual(cache.GetMapLocation(), - '%s/auto.master' % self.workdir) + self.assertEqual(cache.GetMapLocation(), "%s/auto.master" % self.workdir) - cache = files.FilesAutomountMapHandler(conf, - automount_mountpoint='/home') - self.assertEqual(cache.GetMapLocation(), '%s/auto.home' % self.workdir) + cache = files.FilesAutomountMapHandler(conf, automount_mountpoint="/home") + self.assertEqual(cache.GetMapLocation(), "%s/auto.home" % self.workdir) - cache = files.FilesAutomountMapHandler(conf, - automount_mountpoint='/usr/meh') - self.assertEqual(cache.GetMapLocation(), - '%s/auto.usr_meh' % self.workdir) + cache = files.FilesAutomountMapHandler(conf, automount_mountpoint="/usr/meh") + self.assertEqual(cache.GetMapLocation(), "%s/auto.usr_meh" % self.workdir) def testCacheFileDoesNotExist(self): """Make sure we just get an empty map rather than exception.""" - conf = {'dir': self.workdir, 'cache_filename_suffix': ''} + conf = {"dir": self.workdir, "cache_filename_suffix": ""} cache = files.FilesAutomountMapHandler(conf) - self.assertFalse( - os.path.exists(os.path.join(self.workdir, 'auto.master'))) + self.assertFalse(os.path.exists(os.path.join(self.workdir, "auto.master"))) data = cache.GetMap() self.assertFalse(data) def testIndexCreation(self): cache = files.FilesPasswdMapHandler(self.config) entries = [ - passwd.PasswdMapEntry(dict(name='foo', uid=10, gid=10)), - passwd.PasswdMapEntry(dict(name='bar', uid=11, gid=11)), - passwd.PasswdMapEntry(dict(name='quux', uid=12, gid=11)), + passwd.PasswdMapEntry(dict(name="foo", uid=10, gid=10)), + passwd.PasswdMapEntry(dict(name="bar", uid=11, gid=11)), + passwd.PasswdMapEntry(dict(name="quux", uid=12, gid=11)), ] pmap = passwd.PasswdMap(entries) cache.Write(pmap) cache.WriteIndex() - index_filename = cache.GetCacheFilename() + '.ixname' - self.assertTrue(os.path.exists(index_filename), - 'Index not created %s' % index_filename) + index_filename = cache.GetCacheFilename() + ".ixname" + self.assertTrue( + os.path.exists(index_filename), "Index not created %s" % index_filename + ) with open(index_filename) as f: - self.assertEqual('bar\x0015\x00\x00\n', f.readline()) - self.assertEqual('foo\x000\x00\x00\x00\n', f.readline()) - self.assertEqual('quux\x0030\x00\n', f.readline()) - - index_filename = cache.GetCacheFilename() + '.ixuid' - self.assertTrue(os.path.exists(index_filename), - 'Index not created %s' % index_filename) + self.assertEqual("bar\x0015\x00\x00\n", f.readline()) + self.assertEqual("foo\x000\x00\x00\x00\n", f.readline()) + self.assertEqual("quux\x0030\x00\n", f.readline()) + + index_filename = cache.GetCacheFilename() + ".ixuid" + self.assertTrue( + os.path.exists(index_filename), "Index not created %s" % index_filename + ) with open(index_filename) as f: - self.assertEqual('10\x000\x00\x00\n', f.readline()) - self.assertEqual('11\x0015\x00\n', f.readline()) - self.assertEqual('12\x0030\x00\n', f.readline()) + self.assertEqual("10\x000\x00\x00\n", f.readline()) + self.assertEqual("11\x0015\x00\n", f.readline()) + self.assertEqual("12\x0030\x00\n", f.readline()) def testWriteCacheAndIndex(self): cache = files.FilesPasswdMapHandler(self.config) entries = [ - passwd.PasswdMapEntry(dict(name='foo', uid=10, gid=10)), - passwd.PasswdMapEntry(dict(name='bar', uid=11, gid=11)), + passwd.PasswdMapEntry(dict(name="foo", uid=10, gid=10)), + passwd.PasswdMapEntry(dict(name="bar", uid=11, 
gid=11)), ] pmap = passwd.PasswdMap(entries) written = cache.Write(pmap) cache.WriteIndex() - self.assertTrue('foo' in written) - self.assertTrue('bar' in written) - index_filename = cache.GetCacheFilename() + '.ixname' - self.assertTrue(os.path.exists(index_filename), - 'Index not created %s' % index_filename) - index_filename = cache.GetCacheFilename() + '.ixuid' - self.assertTrue(os.path.exists(index_filename), - 'Index not created %s' % index_filename) + self.assertTrue("foo" in written) + self.assertTrue("bar" in written) + index_filename = cache.GetCacheFilename() + ".ixname" + self.assertTrue( + os.path.exists(index_filename), "Index not created %s" % index_filename + ) + index_filename = cache.GetCacheFilename() + ".ixuid" + self.assertTrue( + os.path.exists(index_filename), "Index not created %s" % index_filename + ) entries = [ - passwd.PasswdMapEntry(dict(name='foo', uid=10, gid=10)), - passwd.PasswdMapEntry(dict(name='bar', uid=11, gid=11)), - passwd.PasswdMapEntry(dict(name='quux', uid=12, gid=11)), + passwd.PasswdMapEntry(dict(name="foo", uid=10, gid=10)), + passwd.PasswdMapEntry(dict(name="bar", uid=11, gid=11)), + passwd.PasswdMapEntry(dict(name="quux", uid=12, gid=11)), ] pmap = passwd.PasswdMap(entries) written = cache.Write(pmap) - self.assertTrue('foo' in written) - self.assertTrue('bar' in written) - self.assertTrue('quux' in written) + self.assertTrue("foo" in written) + self.assertTrue("bar" in written) + self.assertTrue("quux" in written) - index_filename = cache.GetCacheFilename() + '.ixname' + index_filename = cache.GetCacheFilename() + ".ixname" with open(index_filename) as f: - self.assertEqual('bar\x0015\x00\n', f.readline()) - self.assertEqual('foo\x000\x00\x00\n', f.readline()) + self.assertEqual("bar\x0015\x00\n", f.readline()) + self.assertEqual("foo\x000\x00\x00\n", f.readline()) - index_filename = cache.GetCacheFilename() + '.ixuid' + index_filename = cache.GetCacheFilename() + ".ixuid" with open(index_filename) as f: - self.assertEqual('10\x000\x00\x00\n', f.readline()) - self.assertEqual('11\x0015\x00\n', f.readline()) + self.assertEqual("10\x000\x00\x00\n", f.readline()) + self.assertEqual("11\x0015\x00\n", f.readline()) cache.WriteIndex() - index_filename = cache.GetCacheFilename() + '.ixname' + index_filename = cache.GetCacheFilename() + ".ixname" with open(index_filename) as f: - self.assertEqual('bar\x0015\x00\x00\n', f.readline()) - self.assertEqual('foo\x000\x00\x00\x00\n', f.readline()) - self.assertEqual('quux\x0030\x00\n', f.readline()) + self.assertEqual("bar\x0015\x00\x00\n", f.readline()) + self.assertEqual("foo\x000\x00\x00\x00\n", f.readline()) + self.assertEqual("quux\x0030\x00\n", f.readline()) - index_filename = cache.GetCacheFilename() + '.ixuid' + index_filename = cache.GetCacheFilename() + ".ixuid" with open(index_filename) as f: - self.assertEqual('10\x000\x00\x00\n', f.readline()) - self.assertEqual('11\x0015\x00\n', f.readline()) - self.assertEqual('12\x0030\x00\n', f.readline()) + self.assertEqual("10\x000\x00\x00\n", f.readline()) + self.assertEqual("11\x0015\x00\n", f.readline()) + self.assertEqual("12\x0030\x00\n", f.readline()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/command.py b/nss_cache/command.py index f1a33266..3b6e9cb1 100644 --- a/nss_cache/command.py +++ b/nss_cache/command.py @@ -13,8 +13,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Command objects.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import inspect from io import StringIO @@ -50,6 +52,7 @@ class Command(object): summary, then a complete description of the command. This is used as part of the help system. """ + # Well known exit codes. We reserve anything 30 and under for the # number of failed NSS maps (~15 defined under modern linux/glibc # implementations of named services. add fudge facter of 2 until I @@ -63,7 +66,7 @@ def __init__(self): # Setup logging. self.log = logging.getLogger(__name__) if self.__doc__ == Command.__doc__: - self.log.warning('No help message set for %r', self) + self.log.warning("No help message set for %r", self) # Setup command parser. self.parser = self._GetParser() # Attribute used to hold optional lock object. @@ -89,13 +92,14 @@ def _GetParser(self): parser.disable_interspersed_args() # commonly used options - parser.add_option('-m', - '--map', - action='append', - type='string', - dest='maps', - help='map to operate on, can be' - ' supplied multiple times') + parser.add_option( + "-m", + "--map", + action="append", + type="string", + dest="maps", + help="map to operate on, can be" " supplied multiple times", + ) return parser @@ -113,8 +117,9 @@ def Run(self, conf, args): 0 if the command was successful non-zero shell error code if not. """ - raise NotImplementedError('command %r not implemented' % - self.__class__.__name__) + raise NotImplementedError( + "command %r not implemented" % self.__class__.__name__ + ) def _Lock(self, path=None, force=False): """Grab a system-wide lock for this command. @@ -149,17 +154,17 @@ def Help(self, short=False): """Return the help message for this command.""" if self.__doc__ is Command.__doc__: return None - help_text = inspect.getdoc(self) + '\n' + help_text = inspect.getdoc(self) + "\n" if short: # only use the short summary first line - help_text = help_text.split('\n')[0] + help_text = help_text.split("\n")[0] else: # lose the short summary first line - help_text = '\n'.join(help_text.split('\n')[2:]) + help_text = "\n".join(help_text.split("\n")[2:]) help_buffer = StringIO() self.parser.print_help(file=help_buffer) # lose the first line, which is the usage line - help_text += '\n'.join(help_buffer.getvalue().split('\n')[1:]) + help_text += "\n".join(help_buffer.getvalue().split("\n")[1:]) return help_text @@ -173,35 +178,40 @@ class Update(Command): def __init__(self): """Initialize the argument parser for this command object.""" super(Update, self).__init__() - self.parser.add_option('-f', - '--full', - action='store_false', - help='force a full update from the data source', - dest='incremental', - default=True) - self.parser.add_option('-s', - '--sleep', - action='store', - type='int', - default=False, - dest='delay', - help='number of seconds to sleep before' - ' executing command') self.parser.add_option( - '--force-write', - action='store_true', + "-f", + "--full", + action="store_false", + help="force a full update from the data source", + dest="incremental", + default=True, + ) + self.parser.add_option( + "-s", + "--sleep", + action="store", + type="int", default=False, - dest='force_write', - help='force the update to write new maps, overriding' - ' safety checks, such as refusing to write empty' - 'maps.') + dest="delay", + help="number of seconds to sleep before" " executing command", + ) self.parser.add_option( - 
'--force-lock', - action='store_true', + "--force-write", + action="store_true", default=False, - dest='force_lock', - help='forcibly acquire the lock, and issue a SIGTERM' - 'to any nsscache process holding the lock.') + dest="force_write", + help="force the update to write new maps, overriding" + " safety checks, such as refusing to write empty" + " maps.", + ) + self.parser.add_option( + "--force-lock", + action="store_true", + default=False, + dest="force_lock", + help="forcibly acquire the lock, and issue a SIGTERM" + " to any nsscache process holding the lock.", + ) def Run(self, conf, args): """Run the Update command. @@ -221,28 +231,26 @@ def Run(self, conf, args): return e.code if options.maps: - self.log.info('Setting configured maps to %s', options.maps) + self.log.info("Setting configured maps to %s", options.maps) conf.maps = options.maps if not options.incremental: - self.log.debug('performing FULL update of caches') + self.log.debug("performing FULL update of caches") else: - self.log.debug('performing INCREMENTAL update of caches') + self.log.debug("performing INCREMENTAL update of caches") if options.delay: - self.log.info('Delaying %d seconds before executing', options.delay) + self.log.info("Delaying %d seconds before executing", options.delay) time.sleep(options.delay) - return self.UpdateMaps(conf, - incremental=options.incremental, - force_write=options.force_write, - force_lock=options.force_lock) + return self.UpdateMaps( + conf, + incremental=options.incremental, + force_write=options.force_write, + force_lock=options.force_lock, + ) - def UpdateMaps(self, - conf, - incremental, - force_write=False, - force_lock=False): + def UpdateMaps(self, conf, incremental, force_write=False, force_lock=False): """Update each configured map. For each configured map, create a source and cache object and @@ -259,20 +267,19 @@ def UpdateMaps(self, """ # Grab a lock before we continue! if not self._Lock(path=conf.lockfile, force=force_lock): - self.log.error('Failed to acquire lock, aborting!') + self.log.error("Failed to acquire lock, aborting!") return self.ERR_LOCK retval = 0 for map_name in conf.maps: if map_name not in conf.options: - self.log.error('No such map name defined in config: %s', - map_name) + self.log.error("No such map name defined in config: %s", map_name) return 1 if incremental: - self.log.info('Updating and verifying %s cache.', map_name) + self.log.info("Updating and verifying %s cache.", map_name) else: - self.log.info('Rebuilding and verifying %s cache.', map_name) + self.log.info("Rebuilding and verifying %s cache.", map_name) cache_options = conf.options[map_name].cache source_options = conf.options[map_name].source @@ -291,10 +298,11 @@ def UpdateMaps(self, # startup directory, but we convert them to absolute paths so that future # temp dirs do not mess with our output routines.
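The comment above is the heart of the chdir dance: relative paths from the config are resolved against the current working directory, so they would go stale the moment the updater chdirs into its per-map scratch directory. A minimal illustration of why the paths are pinned first (standalone sketch with made-up values, not the module's own code):

```python
import os
import tempfile

cache_dir = "cache"  # hypothetical relative setting from nsscache.conf
os.makedirs(cache_dir, exist_ok=True)
cache_dir = os.path.abspath(cache_dir)  # pin it to the startup directory

# A per-map scratch directory is created inside the cache directory...
tempdir = tempfile.mkdtemp(dir=cache_dir, prefix="nsscache-passwd-")
os.chdir(tempdir)

# ...and because cache_dir was made absolute first, it still points at the
# right place even though the process is now inside the scratch directory.
assert os.path.isdir(cache_dir)
```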
old_cwd = os.getcwd() - tempdir = tempfile.mkdtemp(dir=cache_options['dir'], - prefix='nsscache-%s-' % map_name) - if not os.path.isabs(cache_options['dir']): - cache_options['dir'] = os.path.abspath(cache_options['dir']) + tempdir = tempfile.mkdtemp( + dir=cache_options["dir"], prefix="nsscache-%s-" % map_name + ) + if not os.path.isabs(cache_options["dir"]): + cache_options["dir"] = os.path.abspath(cache_options["dir"]) if not os.path.isabs(conf.timestamp_dir): conf.timestamp_dir = os.path.abspath(conf.timestamp_dir) if not os.path.isabs(tempdir): @@ -306,30 +314,29 @@ def UpdateMaps(self, try: source = source_factory.Create(source_options) - updater = self._Updater(map_name, source, cache_options, - conf) + updater = self._Updater(map_name, source, cache_options, conf) if incremental: - self.log.info('Updating and verifying %s cache.', - map_name) + self.log.info("Updating and verifying %s cache.", map_name) else: - self.log.info('Rebuilding and verifying %s cache.', - map_name) + self.log.info("Rebuilding and verifying %s cache.", map_name) - retval = updater.UpdateFromSource(source, - incremental=incremental, - force_write=force_write) + retval = updater.UpdateFromSource( + source, incremental=incremental, force_write=force_write + ) except error.PermissionDenied: self.log.error( - 'Permission denied: could not update map %r. Aborting', - map_name) + "Permission denied: could not update map %r. Aborting", + map_name, + ) retval += 1 except (error.EmptyMap, error.InvalidMap) as e: self.log.error(e) retval += 1 except error.InvalidMerge as e: - self.log.warning('Could not merge map %r: %s. Skipping.', - map_name, e) + self.log.warning( + "Could not merge map %r: %s. Skipping.", map_name, e + ) finally: # Start chdir cleanup os.chdir(old_cwd) @@ -343,25 +350,24 @@ def _Updater(self, map_name, source, cache_options, conf): # to determine which type of updater the source uses. At the moment # there's only two, so not a huge deal. If we add another we should # refactor though. 
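One plausible shape for the refactor that comment anticipates is a dispatch table keyed on the source's updater kind and whether the map is automount, replacing the nested if/else that follows. This is purely an illustrative sketch, not code from this change:

```python
# Illustrative only: collapse the two-way branching into a lookup table.
UPDATERS = {
    ("file", True): "FileAutomountUpdater",
    ("file", False): "FileMapUpdater",
    ("map", True): "AutomountUpdater",
    ("map", False): "MapUpdater",
}


def pick_updater(updater_kind, map_name):
    """Return the updater class name for a source kind and map name."""
    return UPDATERS[(updater_kind, map_name == "automount")]


assert pick_updater("file", "passwd") == "FileMapUpdater"
assert pick_updater("map", "automount") == "AutomountUpdater"
```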
- if hasattr(source, 'UPDATER') and source.UPDATER == config.UPDATER_FILE: + if hasattr(source, "UPDATER") and source.UPDATER == config.UPDATER_FILE: if map_name == config.MAP_AUTOMOUNT: return files_updater.FileAutomountUpdater( - map_name, conf.timestamp_dir, cache_options) + map_name, conf.timestamp_dir, cache_options + ) else: - return files_updater.FileMapUpdater(map_name, - conf.timestamp_dir, - cache_options, - can_do_incremental=True) + return files_updater.FileMapUpdater( + map_name, conf.timestamp_dir, cache_options, can_do_incremental=True + ) else: if map_name == config.MAP_AUTOMOUNT: - return map_updater.AutomountUpdater(map_name, - conf.timestamp_dir, - cache_options) + return map_updater.AutomountUpdater( + map_name, conf.timestamp_dir, cache_options + ) else: - return map_updater.MapUpdater(map_name, - conf.timestamp_dir, - cache_options, - can_do_incremental=True) + return map_updater.MapUpdater( + map_name, conf.timestamp_dir, cache_options, can_do_incremental=True + ) class Verify(Command): @@ -389,27 +395,26 @@ def Run(self, conf, args): return e.code if options.maps: - self.log.info('Setting configured maps to %s', options.maps) + self.log.info("Setting configured maps to %s", options.maps) conf.maps = options.maps (warnings, errors) = (0, 0) - self.log.info('Verifying program and system configuration.') + self.log.info("Verifying program and system configuration.") (config_warnings, config_errors) = config.VerifyConfiguration(conf) warnings += config_warnings errors += config_errors - self.log.info('Verifying data sources.') + self.log.info("Verifying data sources.") errors += self.VerifySources(conf) - self.log.info('Verifying data caches.') + self.log.info("Verifying data caches.") errors += self.VerifyMaps(conf) - self.log.info('Verification result: %d warnings, %d errors', warnings, - errors) + self.log.info("Verification result: %d warnings, %d errors", warnings, errors) if warnings + errors: - self.log.info('Verification failed!') + self.log.info("Verification failed!") else: - self.log.info('Verification passed!') + self.log.info("Verification passed!") return warnings + errors @@ -434,33 +439,33 @@ def VerifyMaps(self, conf): retval = 0 for map_name in conf.maps: - self.log.info('Verifying map: %s.', map_name) + self.log.info("Verifying map: %s.", map_name) # The netgroup map does not have an enumerator, # to test this we'd have to loop over the loaded cache map # and verify each entry is retrievable via getent directly. # TODO(blaed): apply fix from comment to allow for netgroup checking if map_name == config.MAP_NETGROUP: - self.log.info(('The netgroup map does not support enumeration, ' - 'skipping.')) + self.log.info( + ("The netgroup map does not support enumeration, " "skipping.") + ) continue # Automount maps do not support getent, we'll have to come up with # a good way to verify these. 
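Both TODOs above point the same way: for maps that cannot be enumerated, verification could loop over the cached entries and look each one up individually instead of comparing whole maps. A hedged sketch of that per-entry approach (not part of this change; `getent netgroup NAME` does accept a single key):

```python
import subprocess


def verify_by_lookup(database, cached_keys):
    """Count cached keys that getent cannot resolve one at a time."""
    missing = 0
    for key in cached_keys:
        proc = subprocess.run(
            ["getent", database, key], capture_output=True, text=True
        )
        if proc.returncode != 0:  # getent exits non-zero for an absent key
            missing += 1
    return missing


# e.g. verify_by_lookup("netgroup", ["administrators"])
```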
if map_name == config.MAP_AUTOMOUNT: self.log.info( - ('The automount map does not support enumeration, ' - 'skipping.')) + ("The automount map does not support enumeration, " "skipping.") + ) continue try: nss_map = nss.GetMap(map_name) except error.UnsupportedMap: - self.log.warning('Verification of %s map is unsupported!', - map_name) + self.log.warning("Verification of %s map is unsupported!", map_name) continue - self.log.debug('built NSS map of %d entries', len(nss_map)) + self.log.debug("built NSS map of %d entries", len(nss_map)) cache_options = conf.options[map_name].cache cache = cache_factory.Create(cache_options, map_name) @@ -468,11 +473,11 @@ def VerifyMaps(self, conf): try: cache_map = cache.GetMap() except error.CacheNotFound: - self.log.error('Cache missing!') + self.log.error("Cache missing!") retval += 1 continue - self.log.debug('built cache map of %d entries', len(cache_map)) + self.log.debug("built cache map of %d entries", len(cache_map)) # cache_map is a subset of nss_map due to possible other maps, # e.g. files, nis, ldap, etc. @@ -480,14 +485,17 @@ def VerifyMaps(self, conf): for map_entry in cache_map: if map_entry not in nss_map: self.log.info( - 'The following entry is present in the cache ' - 'but not availible via NSS! %s', map_entry.name) - self.log.debug('missing entry data: %s', map_entry) + "The following entry is present in the cache " + "but not available via NSS! %s", + map_entry.name, + ) + self.log.debug("missing entry data: %s", map_entry) missing_entries += 1 if missing_entries > 0: - self.log.warning('Missing %d entries in %s map', - missing_entries, map_name) + self.log.warning( + "Missing %d entries in %s map", missing_entries, map_name + ) retval += 1 return retval @@ -506,13 +514,13 @@ def VerifySources(self, conf): try: source = source_factory.Create(source_options) except error.SourceUnavailable as e: - self.log.debug('map %s dumps source error %s', map_name, e) - self.log.error('Map %s is unvavailable!', map_name) + self.log.debug("map %s dumps source error %s", map_name, e) + self.log.error("Map %s is unavailable!", map_name) retval += 1 continue retval += source.Verify() else: - self.log.error('No sources configured for any maps!') + self.log.error("No sources configured for any maps!") retval += 1 return retval @@ -543,15 +551,15 @@ def Run(self, conf, args): help_text = self.Help() else: help_command = args.pop() - print(('Usage: nsscache [global options] %s [options]' % - help_command)) + print(("Usage: nsscache [global options] %s [options]" % help_command)) print() try: - callable_action = getattr(inspect.getmodule(self), - help_command.capitalize()) + callable_action = getattr( + inspect.getmodule(self), help_command.capitalize() + ) help_text = callable_action().Help() except AttributeError: - print(('command %r is not implemented' % help_command)) + print(("command %r is not implemented" % help_command)) return 1 print(help_text) @@ -583,30 +591,31 @@ def Run(self, conf, args): return e.code if options.maps: - self.log.info('Setting configured maps to %s', options.maps) + self.log.info("Setting configured maps to %s", options.maps) conf.maps = options.maps (warnings, errors) = (0, 0) - self.log.info('Verifying program and system configuration.') + self.log.info("Verifying program and system configuration.") (config_warnings, config_errors) = config.VerifyConfiguration(conf) warnings += config_warnings errors += config_errors - self.log.info('Verifying data sources.') + self.log.info("Verifying data sources.") errors += 
Verify().VerifySources(conf) - self.log.info('verification: %d warnings, %d errors', warnings, errors) + self.log.info("verification: %d warnings, %d errors", warnings, errors) # Exit and report if config or source failed verification, because # we cannot reliably build a cache if either of these are faulty. if errors > 0: - self.log.error('Too many errors in verification tests failed;' - ' repair aborted!') + self.log.error( + "Too many errors in verification tests failed;" " repair aborted!" + ) return 1 # Rebuild local cache in full, which also verifies each cache. - self.log.info('Rebuilding and verifying caches: %s.', conf.maps) + self.log.info("Rebuilding and verifying caches: %s.", conf.maps) return Update().UpdateMaps(conf=conf, incremental=False) @@ -619,24 +628,31 @@ class Status(Command): def __init__(self): super(Status, self).__init__() - self.parser.add_option('--epoch', - action='store_true', - help='show timestamps in UNIX epoch time', - dest='epoch', - default=False) - self.parser.add_option('--template', - action='store', - help='Set format for output', - metavar='FORMAT', - dest='template', - default='NSS map: %(map)s\n%(key)s: %(value)s') - self.parser.add_option('--automount-template', - action='store', - help='Set format for automount output', - metavar='FORMAT', - dest='automount_template', - default=('NSS map: %(map)s\nAutomount map: ' - '%(automount)s\n%(key)s: %(value)s')) + self.parser.add_option( + "--epoch", + action="store_true", + help="show timestamps in UNIX epoch time", + dest="epoch", + default=False, + ) + self.parser.add_option( + "--template", + action="store", + help="Set format for output", + metavar="FORMAT", + dest="template", + default="NSS map: %(map)s\n%(key)s: %(value)s", + ) + self.parser.add_option( + "--automount-template", + action="store", + help="Set format for automount output", + metavar="FORMAT", + dest="automount_template", + default=( + "NSS map: %(map)s\nAutomount map: " "%(automount)s\n%(key)s: %(value)s" + ), + ) def Run(self, conf, args): """Run the Status command. @@ -657,33 +673,31 @@ def Run(self, conf, args): return e.code if options.maps: - self.log.info('Setting configured maps to %s', options.maps) + self.log.info("Setting configured maps to %s", options.maps) conf.maps = options.maps for map_name in conf.maps: # Hardcoded to support the two-tier structure of automount maps if map_name == config.MAP_AUTOMOUNT: - value_list = self.GetAutomountMapMetadata(conf, - epoch=options.epoch) - self.log.debug('Value list: %r', value_list) + value_list = self.GetAutomountMapMetadata(conf, epoch=options.epoch) + self.log.debug("Value list: %r", value_list) for value_dict in value_list: - self.log.debug('Value dict: %r', value_dict) + self.log.debug("Value dict: %r", value_dict) output = options.automount_template % value_dict print(output) else: for value_dict in self.GetSingleMapMetadata( - map_name, conf, epoch=options.epoch): - self.log.debug('Value dict: %r', value_dict) + map_name, conf, epoch=options.epoch + ): + self.log.debug("Value dict: %r", value_dict) output = options.template % value_dict print(output) return os.EX_OK - def GetSingleMapMetadata(self, - map_name, - conf, - automount_mountpoint=None, - epoch=False): + def GetSingleMapMetadata( + self, map_name, conf, automount_mountpoint=None, epoch=False + ): """Return metadata from map specified. 
Args: @@ -698,20 +712,22 @@ def GetSingleMapMetadata(self, """ cache_options = conf.options[map_name].cache - updater = map_updater.MapUpdater(map_name, conf.timestamp_dir, - cache_options, automount_mountpoint) + updater = map_updater.MapUpdater( + map_name, conf.timestamp_dir, cache_options, automount_mountpoint + ) - modify_dict = {'key': 'last-modify-timestamp', 'map': map_name} - update_dict = {'key': 'last-update-timestamp', 'map': map_name} + modify_dict = {"key": "last-modify-timestamp", "map": map_name} + update_dict = {"key": "last-update-timestamp", "map": map_name} if map_name == config.MAP_AUTOMOUNT: # have to find out *which* automount map from a cache object! cache = cache_factory.Create( cache_options, config.MAP_AUTOMOUNT, - automount_mountpoint=automount_mountpoint) + automount_mountpoint=automount_mountpoint, + ) automount = cache.GetMapLocation() - modify_dict['automount'] = automount - update_dict['automount'] = automount + modify_dict["automount"] = automount + update_dict["automount"] = automount last_modify_timestamp = updater.GetModifyTimestamp() or 0 last_update_timestamp = updater.GetUpdateTimestamp() or 0 @@ -721,17 +737,19 @@ def GetSingleMapMetadata(self, # the only place such a conversion is appropriate. if last_modify_timestamp: last_modify_timestamp = time.asctime( - time.localtime(last_modify_timestamp)) + time.localtime(last_modify_timestamp) + ) else: - last_modify_timestamp = 'Unknown' + last_modify_timestamp = "Unknown" if last_update_timestamp: last_update_timestamp = time.asctime( - time.localtime(last_update_timestamp)) + time.localtime(last_update_timestamp) + ) else: - last_update_timestamp = 'Unknown' + last_update_timestamp = "Unknown" - modify_dict['value'] = last_modify_timestamp - update_dict['value'] = last_update_timestamp + modify_dict["value"] = last_modify_timestamp + update_dict["value"] = last_update_timestamp return [modify_dict, update_dict] @@ -755,21 +773,21 @@ def GetAutomountMapMetadata(self, conf, epoch=False): # get the value_dict for the master map, note that automount_mountpoint=None # defaults to the master map! - values = self.GetSingleMapMetadata(map_name, - conf, - automount_mountpoint=None, - epoch=epoch) + values = self.GetSingleMapMetadata( + map_name, conf, automount_mountpoint=None, epoch=epoch + ) value_list.extend(values) # now get the contents of the master map, and get the status for each map # we find - cache = cache_factory.Create(cache_options, - config.MAP_AUTOMOUNT, - automount_mountpoint=None) + cache = cache_factory.Create( + cache_options, config.MAP_AUTOMOUNT, automount_mountpoint=None + ) master_map = cache.GetMap() for map_entry in master_map: values = self.GetSingleMapMetadata( - map_name, conf, automount_mountpoint=map_entry.key, epoch=epoch) + map_name, conf, automount_mountpoint=map_entry.key, epoch=epoch + ) value_list.extend(values) return value_list diff --git a/nss_cache/command_test.py b/nss_cache/command_test.py index f9044106..830bdba6 100644 --- a/nss_cache/command_test.py +++ b/nss_cache/command_test.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for nss_cache/command.py.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import grp import os @@ -54,12 +56,12 @@ def testRunCommand(self): c = command.Command() self.assertRaises(NotImplementedError, c.Run, [], {}) - @unittest.skip('badly mocked') + @unittest.skip("badly mocked") def testLock(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") mock_lock = lock.PidFile(filename=None) - mock_lock.Lock(force=False).AndReturn('LOCK') - mock_lock.Lock(force=False).AndReturn('MORLOCK') + mock_lock.Lock(force=False).AndReturn("LOCK") + mock_lock.Lock(force=False).AndReturn("MORLOCK") mock_lock.Locked().AndReturn(True) mock_lock.Unlock() @@ -68,26 +70,26 @@ def testLock(self): c = command.Command() # First test that we create a lock and lock it. - self.assertEqual('LOCK', c._Lock()) + self.assertEqual("LOCK", c._Lock()) # Then we test that we lock the existing one a second time. - self.assertEqual('MORLOCK', c._Lock()) + self.assertEqual("MORLOCK", c._Lock()) - @unittest.skip('badly mocked') + @unittest.skip("badly mocked") def testForceLock(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") mock_lock = lock.PidFile(filename=None) - mock_lock.Lock(force=True).AndReturn('LOCK') + mock_lock.Lock(force=True).AndReturn("LOCK") mock_lock.Locked().AndReturn(True) mock_lock.Unlock() self.mox.ReplayAll() c = command.Command() - self.assertEqual('LOCK', c._Lock(force=True)) + self.assertEqual("LOCK", c._Lock(force=True)) - @unittest.skip('badly mocked') + @unittest.skip("badly mocked") def testUnlock(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") mock_lock = lock.PidFile(filename=None) mock_lock.Lock(force=False).AndReturn(True) mock_lock.Locked().AndReturn(True) @@ -105,7 +107,6 @@ def testCommandHelp(self): self.assertEqual(None, c.Help()) def testDummyCommand(self): - class Dummy(command.Command): """Dummy docstring for dummy command.""" @@ -130,18 +131,18 @@ class DummyConfig(object): self.conf = DummyConfig() self.conf.options = { config.MAP_PASSWORD: config.MapOptions(), - config.MAP_AUTOMOUNT: config.MapOptions() + config.MAP_AUTOMOUNT: config.MapOptions(), } self.conf.options[config.MAP_PASSWORD].cache = { - 'name': 'dummy', - 'dir': self.workdir + "name": "dummy", + "dir": self.workdir, } - self.conf.options[config.MAP_PASSWORD].source = {'name': 'dummy'} + self.conf.options[config.MAP_PASSWORD].source = {"name": "dummy"} self.conf.options[config.MAP_AUTOMOUNT].cache = { - 'name': 'dummy', - 'dir': self.workdir + "name": "dummy", + "dir": self.workdir, } - self.conf.options[config.MAP_AUTOMOUNT].source = {'name': 'dummy'} + self.conf.options[config.MAP_AUTOMOUNT].source = {"name": "dummy"} self.conf.timestamp_dir = self.workdir self.conf.lockfile = None @@ -160,11 +161,10 @@ def testHelp(self): def testRunWithNoParameters(self): c = command.Update() - self.mox.StubOutWithMock(c, 'UpdateMaps') - c.UpdateMaps(self.conf, - incremental=True, - force_lock=False, - force_write=False).AndReturn(0) + self.mox.StubOutWithMock(c, "UpdateMaps") + c.UpdateMaps( + self.conf, incremental=True, force_lock=False, force_write=False + ).AndReturn(0) self.mox.ReplayAll() self.assertEqual(0, c.Run(self.conf, [])) @@ -176,146 +176,155 @@ def testRunWithBadParameters(self): dev_null = 
StringIO() stderr = sys.stderr sys.stderr = dev_null - self.assertEqual(2, c.Run(None, ['--invalid'])) + self.assertEqual(2, c.Run(None, ["--invalid"])) sys.stderr = stderr def testRunWithFlags(self): c = command.Update() - self.mox.StubOutWithMock(c, 'UpdateMaps') - c.UpdateMaps(self.conf, - incremental=False, - force_lock=True, - force_write=True).AndReturn(0) + self.mox.StubOutWithMock(c, "UpdateMaps") + c.UpdateMaps( + self.conf, incremental=False, force_lock=True, force_write=True + ).AndReturn(0) self.mox.ReplayAll() self.assertEqual( 0, - c.Run(self.conf, [ - '-m', config.MAP_PASSWORD, '-f', '--force-write', '--force-lock' - ])) - self.assertEqual(['passwd'], self.conf.maps) + c.Run( + self.conf, + ["-m", config.MAP_PASSWORD, "-f", "--force-write", "--force-lock"], + ), + ) + self.assertEqual(["passwd"], self.conf.maps) def testUpdateSingleMaps(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") lock_mock = lock.PidFile(filename=None) lock_mock.Lock(force=False).AndReturn(True) lock_mock.Locked().AndReturn(True) lock_mock.Unlock() self.conf.maps = [config.MAP_PASSWORD] - self.conf.cache = 'dummy' + self.conf.cache = "dummy" modify_stamp = 1 - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) passwd_map = passwd.PasswdMap([map_entry]) passwd_map.SetModifyTimestamp(modify_stamp) source_mock = self.mox.CreateMock(source.Source) - source_mock.GetMap(config.MAP_PASSWORD, - location=None).AndReturn(passwd_map) + source_mock.GetMap(config.MAP_PASSWORD, location=None).AndReturn(passwd_map) - self.mox.StubOutWithMock(source_factory, 'Create') - source_factory.Create(self.conf.options[ - config.MAP_PASSWORD].source).AndReturn(source_mock) + self.mox.StubOutWithMock(source_factory, "Create") + source_factory.Create(self.conf.options[config.MAP_PASSWORD].source).AndReturn( + source_mock + ) cache_mock = self.mox.CreateMock(caches.Cache) cache_mock.WriteMap(map_data=passwd_map, force_write=False).AndReturn(0) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(self.conf.options[config.MAP_PASSWORD].cache, - config.MAP_PASSWORD).AndReturn(cache_mock) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + self.conf.options[config.MAP_PASSWORD].cache, config.MAP_PASSWORD + ).AndReturn(cache_mock) self.mox.ReplayAll() c = command.Update() self.assertEqual( - 0, c.UpdateMaps(self.conf, incremental=True, force_write=False)) + 0, c.UpdateMaps(self.conf, incremental=True, force_write=False) + ) def testUpdateAutomounts(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") lock_mock = lock.PidFile(filename=None) lock_mock.Lock(force=False).AndReturn(True) lock_mock.Locked().AndReturn(True) lock_mock.Unlock() self.conf.maps = [config.MAP_AUTOMOUNT] - self.conf.cache = 'dummy' + self.conf.cache = "dummy" modify_stamp = 1 map_entry = automount.AutomountMapEntry() - map_entry.key = '/home' - map_entry.location = 'foo' + map_entry.key = "/home" + map_entry.location = "foo" automount_map = automount.AutomountMap([map_entry]) automount_map.SetModifyTimestamp(modify_stamp) source_mock = self.mox.CreateMock(source.Source) source_mock.GetAutomountMasterMap().AndReturn(automount_map) - source_mock.GetMap(config.MAP_AUTOMOUNT, - location='foo').AndReturn(automount_map) + source_mock.GetMap(config.MAP_AUTOMOUNT, location="foo").AndReturn( + automount_map + ) - 
self.mox.StubOutWithMock(source_factory, 'Create') - source_factory.Create(self.conf.options[ - config.MAP_PASSWORD].source).AndReturn(source_mock) + self.mox.StubOutWithMock(source_factory, "Create") + source_factory.Create(self.conf.options[config.MAP_PASSWORD].source).AndReturn( + source_mock + ) cache_mock = self.mox.CreateMock(caches.Cache) - cache_mock.GetMapLocation().AndReturn('home') - cache_mock.WriteMap(map_data=automount_map, - force_write=False).AndReturn(0) - cache_mock.WriteMap(map_data=automount_map, - force_write=False).AndReturn(0) - - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(self.conf.options[config.MAP_AUTOMOUNT].cache, - config.MAP_AUTOMOUNT, - automount_mountpoint='/home').AndReturn(cache_mock) - cache_factory.Create(self.conf.options[config.MAP_AUTOMOUNT].cache, - config.MAP_AUTOMOUNT, - automount_mountpoint=None).AndReturn(cache_mock) + cache_mock.GetMapLocation().AndReturn("home") + cache_mock.WriteMap(map_data=automount_map, force_write=False).AndReturn(0) + cache_mock.WriteMap(map_data=automount_map, force_write=False).AndReturn(0) + + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + self.conf.options[config.MAP_AUTOMOUNT].cache, + config.MAP_AUTOMOUNT, + automount_mountpoint="/home", + ).AndReturn(cache_mock) + cache_factory.Create( + self.conf.options[config.MAP_AUTOMOUNT].cache, + config.MAP_AUTOMOUNT, + automount_mountpoint=None, + ).AndReturn(cache_mock) self.mox.ReplayAll() c = command.Update() self.assertEqual( - 0, c.UpdateMaps(self.conf, incremental=True, force_write=False)) + 0, c.UpdateMaps(self.conf, incremental=True, force_write=False) + ) def testUpdateMapsTrapsPermissionDenied(self): - self.mox.StubOutWithMock(map_updater.MapUpdater, 'UpdateFromSource') - map_updater.MapUpdater.UpdateFromSource(mox.IgnoreArg(), - incremental=True, - force_write=False).AndRaise( - error.PermissionDenied) + self.mox.StubOutWithMock(map_updater.MapUpdater, "UpdateFromSource") + map_updater.MapUpdater.UpdateFromSource( + mox.IgnoreArg(), incremental=True, force_write=False + ).AndRaise(error.PermissionDenied) - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") lock_mock = lock.PidFile(filename=None) lock_mock.Lock(force=False).AndReturn(True) lock_mock.Locked().AndReturn(True) lock_mock.Unlock() self.conf.maps = [config.MAP_PASSWORD] - self.conf.cache = 'dummy' + self.conf.cache = "dummy" modify_stamp = 1 - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) passwd_map = passwd.PasswdMap([map_entry]) passwd_map.SetModifyTimestamp(modify_stamp) source_mock = self.mox.CreateMock(source.Source) - self.mox.StubOutWithMock(source_factory, 'Create') - source_factory.Create(self.conf.options[ - config.MAP_PASSWORD].source).AndReturn(source_mock) + self.mox.StubOutWithMock(source_factory, "Create") + source_factory.Create(self.conf.options[config.MAP_PASSWORD].source).AndReturn( + source_mock + ) cache_mock = self.mox.CreateMock(caches.Cache) - self.mox.StubOutWithMock(cache_factory, 'Create') + self.mox.StubOutWithMock(cache_factory, "Create") self.mox.ReplayAll() c = command.Update() self.assertEqual( - 1, c.UpdateMaps(self.conf, incremental=True, force_write=False)) + 1, c.UpdateMaps(self.conf, incremental=True, force_write=False) + ) def testUpdateMapsCanForceLock(self): - self.mox.StubOutClassWithMocks(lock, 'PidFile') + self.mox.StubOutClassWithMocks(lock, "PidFile") 
lock_mock = lock.PidFile(filename=None) lock_mock.Lock(force=True).AndReturn(False) lock_mock.Locked().AndReturn(True) @@ -324,73 +333,76 @@ def testUpdateMapsCanForceLock(self): self.mox.ReplayAll() c = command.Update() - self.assertEqual(c.UpdateMaps(self.conf, False, force_lock=True), - c.ERR_LOCK) + self.assertEqual(c.UpdateMaps(self.conf, False, force_lock=True), c.ERR_LOCK) def testSleep(self): - self.mox.StubOutWithMock(time, 'sleep') + self.mox.StubOutWithMock(time, "sleep") time.sleep(1) c = command.Update() - self.mox.StubOutWithMock(c, 'UpdateMaps') - c.UpdateMaps(self.conf, - incremental=True, - force_lock=mox.IgnoreArg(), - force_write=mox.IgnoreArg()).AndReturn(0) + self.mox.StubOutWithMock(c, "UpdateMaps") + c.UpdateMaps( + self.conf, + incremental=True, + force_lock=mox.IgnoreArg(), + force_write=mox.IgnoreArg(), + ).AndReturn(0) self.mox.ReplayAll() - c.Run(self.conf, ['-s', '1']) + c.Run(self.conf, ["-s", "1"]) def testForceWriteFlag(self): c = command.Update() (options, _) = c.parser.parse_args([]) self.assertEqual(False, options.force_write) - (options, _) = c.parser.parse_args(['--force-write']) + (options, _) = c.parser.parse_args(["--force-write"]) self.assertEqual(True, options.force_write) def testForceLockFlag(self): c = command.Update() (options, _) = c.parser.parse_args([]) self.assertEqual(False, options.force_lock) - (options, _) = c.parser.parse_args(['--force-lock']) + (options, _) = c.parser.parse_args(["--force-lock"]) self.assertEqual(True, options.force_lock) def testForceWriteFlagCallsUpdateMapsWithForceWriteTrue(self): c = command.Update() - self.mox.StubOutWithMock(c, 'UpdateMaps') - c.UpdateMaps(self.conf, - incremental=mox.IgnoreArg(), - force_lock=mox.IgnoreArg(), - force_write=True).AndReturn(0) + self.mox.StubOutWithMock(c, "UpdateMaps") + c.UpdateMaps( + self.conf, + incremental=mox.IgnoreArg(), + force_lock=mox.IgnoreArg(), + force_write=True, + ).AndReturn(0) self.mox.ReplayAll() - self.assertEqual(0, c.Run(self.conf, ['--force-write'])) + self.assertEqual(0, c.Run(self.conf, ["--force-write"])) def testForceLockFlagCallsUpdateMapsWithForceLockTrue(self): c = command.Update() - self.mox.StubOutWithMock(c, 'UpdateMaps') - c.UpdateMaps(self.conf, - incremental=mox.IgnoreArg(), - force_lock=True, - force_write=mox.IgnoreArg()).AndReturn(0) + self.mox.StubOutWithMock(c, "UpdateMaps") + c.UpdateMaps( + self.conf, + incremental=mox.IgnoreArg(), + force_lock=True, + force_write=mox.IgnoreArg(), + ).AndReturn(0) self.mox.ReplayAll() - self.assertEqual(0, c.Run(self.conf, ['--force-lock'])) + self.assertEqual(0, c.Run(self.conf, ["--force-lock"])) def testUpdateMapsWithBadMapName(self): c = command.Update() - self.mox.StubOutWithMock(c, '_Lock') + self.mox.StubOutWithMock(c, "_Lock") c._Lock(force=False, path=None).AndReturn(True) self.mox.ReplayAll() # Create an invalid map name. - self.assertEqual( - 1, c.Run(self.conf, ['-m', config.MAP_PASSWORD + 'invalid'])) + self.assertEqual(1, c.Run(self.conf, ["-m", config.MAP_PASSWORD + "invalid"])) class TestVerifyCommand(mox.MoxTestBase): - def setUp(self): super(TestVerifyCommand, self).setUp() @@ -398,7 +410,7 @@ class DummyConfig(object): pass class DummySource(source.Source): - name = 'dummy' + name = "dummy" def Verify(self): return 0 @@ -414,8 +426,8 @@ def Verify(self): # Create a config with a section for a passwd map. 
self.conf = DummyConfig() self.conf.options = {config.MAP_PASSWORD: config.MapOptions()} - self.conf.options[config.MAP_PASSWORD].cache = {'name': 'dummy'} - self.conf.options[config.MAP_PASSWORD].source = {'name': 'dummy'} + self.conf.options[config.MAP_PASSWORD].cache = {"name": "dummy"} + self.conf.options[config.MAP_PASSWORD].source = {"name": "dummy"} self.original_verify_configuration = config.VerifyConfiguration self.original_getmap = nss.GetMap @@ -425,12 +437,12 @@ def Verify(self): # Setup maps used by VerifyMap testing. big_map = passwd.PasswdMap() map_entry1 = passwd.PasswdMapEntry() - map_entry1.name = 'foo' + map_entry1.name = "foo" map_entry1.uid = 10 map_entry1.gid = 10 big_map.Add(map_entry1) map_entry2 = passwd.PasswdMapEntry() - map_entry2.name = 'bar' + map_entry2.name = "bar" map_entry2.uid = 20 map_entry2.gid = 20 big_map.Add(map_entry2) @@ -459,7 +471,6 @@ def testHelp(self): self.assertNotEqual(None, c.Help()) def testRunWithNoParameters(self): - def FakeVerifyConfiguration(conf): """Assert that we call VerifyConfiguration correctly.""" self.assertEqual(conf, self.conf) @@ -486,11 +497,10 @@ def testRunWithBadParameters(self): dev_null = StringIO() stderr = sys.stderr sys.stderr = dev_null - self.assertEqual(2, c.Run(None, ['--invalid'])) + self.assertEqual(2, c.Run(None, ["--invalid"])) sys.stderr = stderr def testRunWithParameters(self): - def FakeVerifyConfiguration(conf): """Assert that we call VerifyConfiguration correctly.""" self.assertEqual(conf, self.conf) @@ -506,19 +516,20 @@ def FakeVerifyMaps(conf): c = command.Verify() c.VerifyMaps = FakeVerifyMaps - self.assertEqual(0, c.Run(self.conf, ['-m', config.MAP_PASSWORD])) + self.assertEqual(0, c.Run(self.conf, ["-m", config.MAP_PASSWORD])) def testVerifyMapsSucceedsOnGoodMaps(self): cache_mock = self.mox.CreateMock(caches.Cache) cache_mock.GetMap().AndReturn(self.small_map) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(self.conf.options[config.MAP_PASSWORD].cache, - config.MAP_PASSWORD).AndReturn(cache_mock) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + self.conf.options[config.MAP_PASSWORD].cache, config.MAP_PASSWORD + ).AndReturn(cache_mock) self.conf.maps = [config.MAP_PASSWORD] - self.mox.StubOutWithMock(nss, 'GetMap') + self.mox.StubOutWithMock(nss, "GetMap") nss.GetMap(config.MAP_PASSWORD).AndReturn(self.big_map) self.mox.ReplayAll() @@ -531,13 +542,14 @@ def testVerifyMapsBad(self): cache_mock = self.mox.CreateMock(caches.Cache) cache_mock.GetMap().AndReturn(self.big_map) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(self.conf.options[config.MAP_PASSWORD].cache, - config.MAP_PASSWORD).AndReturn(cache_mock) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + self.conf.options[config.MAP_PASSWORD].cache, config.MAP_PASSWORD + ).AndReturn(cache_mock) self.conf.maps = [config.MAP_PASSWORD] - self.mox.StubOutWithMock(nss, 'GetMap') + self.mox.StubOutWithMock(nss, "GetMap") nss.GetMap(config.MAP_PASSWORD).AndReturn(self.small_map) self.mox.ReplayAll() @@ -550,13 +562,14 @@ def testVerifyMapsException(self): cache_mock = self.mox.CreateMock(caches.Cache) cache_mock.GetMap().AndRaise(error.CacheNotFound) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(self.conf.options[config.MAP_PASSWORD].cache, - config.MAP_PASSWORD).AndReturn(cache_mock) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + self.conf.options[config.MAP_PASSWORD].cache, 
config.MAP_PASSWORD + ).AndReturn(cache_mock) self.conf.maps = [config.MAP_PASSWORD] - self.mox.StubOutWithMock(nss, 'GetMap') + self.mox.StubOutWithMock(nss, "GetMap") nss.GetMap(config.MAP_PASSWORD).AndReturn(self.small_map) self.mox.ReplayAll() @@ -566,11 +579,11 @@ def testVerifyMapsException(self): self.assertEqual(1, c.VerifyMaps(self.conf)) def testVerifyMapsSkipsNetgroups(self): - self.mox.StubOutWithMock(cache_factory, 'Create') + self.mox.StubOutWithMock(cache_factory, "Create") self.conf.maps = [config.MAP_NETGROUP] - self.mox.StubOutWithMock(nss, 'GetMap') + self.mox.StubOutWithMock(nss, "GetMap") self.mox.ReplayAll() @@ -582,7 +595,7 @@ def testVerifySourcesGood(self): source_mock = self.mox.CreateMock(source.Source) source_mock.Verify().AndReturn(0) - self.mox.StubOutWithMock(source_factory, 'Create') + self.mox.StubOutWithMock(source_factory, "Create") source_factory.Create(mox.IgnoreArg()).AndReturn(source_mock) self.conf.maps = [config.MAP_PASSWORD] @@ -597,9 +610,10 @@ def testVerifySourcesBad(self): source_mock = self.mox.CreateMock(source.Source) source_mock.Verify().AndReturn(1) - self.mox.StubOutWithMock(source_factory, 'Create') - source_factory.Create( - self.conf.options[config.MAP_PASSWORD].cache).AndReturn(source_mock) + self.mox.StubOutWithMock(source_factory, "Create") + source_factory.Create(self.conf.options[config.MAP_PASSWORD].cache).AndReturn( + source_mock + ) self.conf.maps = [config.MAP_PASSWORD] @@ -613,8 +627,7 @@ def testVerifySourcesTrapsSourceUnavailable(self): def FakeCreate(conf): """Stub routine returning a pmock to test VerifySources.""" - self.assertEqual(conf, - self.conf.options[config.MAP_PASSWORD].source) + self.assertEqual(conf, self.conf.options[config.MAP_PASSWORD].source) raise error.SourceUnavailable old_source_base_create = source_factory.Create @@ -627,14 +640,12 @@ def FakeCreate(conf): class TestRepairCommand(unittest.TestCase): - def setUp(self): - class DummyConfig(object): pass class DummySource(source.Source): - name = 'dummy' + name = "dummy" def Verify(self): return 0 @@ -644,8 +655,8 @@ def Verify(self): self.conf = DummyConfig() self.conf.options = {config.MAP_PASSWORD: config.MapOptions()} - self.conf.options[config.MAP_PASSWORD].cache = {'name': 'dummy'} - self.conf.options[config.MAP_PASSWORD].source = {'name': 'dummy'} + self.conf.options[config.MAP_PASSWORD].cache = {"name": "dummy"} + self.conf.options[config.MAP_PASSWORD].source = {"name": "dummy"} self.original_verify_configuration = config.VerifyConfiguration @@ -681,11 +692,10 @@ def testRunWithBadParameters(self): dev_null = StringIO() stderr = sys.stderr sys.stderr = dev_null - self.assertEqual(2, c.Run(None, ['--invalid'])) + self.assertEqual(2, c.Run(None, ["--invalid"])) sys.stderr = stderr def testRunWithParameters(self): - def FakeVerifyConfiguration(conf): """Assert that we call VerifyConfiguration correctly.""" self.assertEqual(conf, self.conf) @@ -695,11 +705,10 @@ def FakeVerifyConfiguration(conf): c = command.Repair() - self.assertEqual(1, c.Run(self.conf, ['-m', config.MAP_PASSWORD])) + self.assertEqual(1, c.Run(self.conf, ["-m", config.MAP_PASSWORD])) class TestHelpCommand(unittest.TestCase): - def setUp(self): self.stdout = sys.stdout sys.stdout = StringIO() @@ -717,11 +726,10 @@ def testRunWithNoParameters(self): def testRunHelpHelp(self): c = command.Help() - self.assertEqual(0, c.Run(None, ['help'])) + self.assertEqual(0, c.Run(None, ["help"])) class TestStatusCommand(mox.MoxTestBase): - def setUp(self): super(TestStatusCommand, self).setUp() @@ 
-729,14 +737,13 @@ class DummyConfig(object):
             pass

         class DummySource(source.Source):
-            name = 'dummy'
+            name = "dummy"

             def Verify(self):
                 return 0

         # stub out parts of update.MapUpdater
         class DummyUpdater(map_updater.MapUpdater):
-
             def GetModifyTimestamp(self):
                 return 1
@@ -747,15 +754,15 @@ def GetUpdateTimestamp(self):
         source_factory.RegisterImplementation(DummySource)

         self.conf = DummyConfig()
-        self.conf.timestamp_dir = 'TEST_DIR'
+        self.conf.timestamp_dir = "TEST_DIR"
         self.conf.options = {
             config.MAP_PASSWORD: config.MapOptions(),
-            config.MAP_AUTOMOUNT: config.MapOptions()
+            config.MAP_AUTOMOUNT: config.MapOptions(),
         }
-        self.conf.options[config.MAP_PASSWORD].cache = {'name': 'dummy'}
-        self.conf.options[config.MAP_PASSWORD].source = {'name': 'dummy'}
-        self.conf.options[config.MAP_AUTOMOUNT].cache = {'name': 'dummy'}
-        self.conf.options[config.MAP_AUTOMOUNT].source = {'name': 'dummy'}
+        self.conf.options[config.MAP_PASSWORD].cache = {"name": "dummy"}
+        self.conf.options[config.MAP_PASSWORD].source = {"name": "dummy"}
+        self.conf.options[config.MAP_AUTOMOUNT].cache = {"name": "dummy"}
+        self.conf.options[config.MAP_AUTOMOUNT].source = {"name": "dummy"}

         self.original_verify_configuration = config.VerifyConfiguration
         self.original_create = cache_factory.Create
@@ -786,7 +793,7 @@ def testRunWithBadParameters(self):
         dev_null = StringIO()
         stderr = sys.stderr
         sys.stderr = dev_null
-        self.assertEqual(2, c.Run(None, ['--invalid']))
+        self.assertEqual(2, c.Run(None, ["--invalid"]))
         sys.stderr = stderr

     def testEpochFormatParameter(self):
@@ -802,66 +809,62 @@ def testObeysMapsFlag(self):
         sys.stdout = stdout_buffer

         c = command.Status()
-        self.assertEqual(0, c.Run(self.conf, ['-m', 'passwd']))
+        self.assertEqual(0, c.Run(self.conf, ["-m", "passwd"]))
         sys.stdout = old_stdout

         self.assertNotEqual(0, len(stdout_buffer.getvalue()))
-        self.assertFalse(stdout_buffer.getvalue().find('group') >= 0)
+        self.assertFalse(stdout_buffer.getvalue().find("group") >= 0)

     def testGetSingleMapMetadata(self):
         # test both automount and non-automount maps.
         # cache mock is returned by FakeCreate() for automount maps
         cache_mock = self.mox.CreateMock(caches.Cache)
-        cache_mock.GetMapLocation().AndReturn('/etc/auto.master')
+        cache_mock.GetMapLocation().AndReturn("/etc/auto.master")

-        self.mox.StubOutWithMock(cache_factory, 'Create')
+        self.mox.StubOutWithMock(cache_factory, "Create")
         cache_factory.Create(
             self.conf.options[config.MAP_AUTOMOUNT].cache,
             config.MAP_AUTOMOUNT,
-            automount_mountpoint='automount_mountpoint').AndReturn(cache_mock)
+            automount_mountpoint="automount_mountpoint",
+        ).AndReturn(cache_mock)

         self.mox.ReplayAll()

         c = command.Status()

         values = c.GetSingleMapMetadata(config.MAP_PASSWORD, self.conf)

-        self.assertTrue('map' in values[0])
-        self.assertTrue('key' in values[0])
-        self.assertTrue('value' in values[0])
+        self.assertTrue("map" in values[0])
+        self.assertTrue("key" in values[0])
+        self.assertTrue("value" in values[0])

         values = c.GetSingleMapMetadata(
-            config.MAP_AUTOMOUNT,
-            self.conf,
-            automount_mountpoint='automount_mountpoint')
+            config.MAP_AUTOMOUNT, self.conf, automount_mountpoint="automount_mountpoint"
+        )

-        self.assertTrue('map' in values[0])
-        self.assertTrue('key' in values[0])
-        self.assertTrue('value' in values[0])
-        self.assertTrue('automount' in values[0])
+        self.assertTrue("map" in values[0])
+        self.assertTrue("key" in values[0])
+        self.assertTrue("value" in values[0])
+        self.assertTrue("automount" in values[0])

     def testGetSingleMapMetadataTimestampEpoch(self):
         c = command.Status()
-        values = c.GetSingleMapMetadata(config.MAP_PASSWORD,
-                                        self.conf,
-                                        epoch=True)
-        self.assertTrue('map' in values[0])
-        self.assertTrue('key' in values[0])
-        self.assertTrue('value' in values[0])
+        values = c.GetSingleMapMetadata(config.MAP_PASSWORD, self.conf, epoch=True)
+        self.assertTrue("map" in values[0])
+        self.assertTrue("key" in values[0])
+        self.assertTrue("value" in values[0])
         # values below are returned by dummyupdater
-        self.assertEqual(1, values[0]['value'])
-        self.assertEqual(2, values[1]['value'])
+        self.assertEqual(1, values[0]["value"])
+        self.assertEqual(2, values[1]["value"])

     def testGetSingleMapMetadataTimestampEpochFalse(self):
         # set the timezone so we get a consistent return value
-        os.environ['TZ'] = 'MST'
+        os.environ["TZ"] = "MST"
         time.tzset()

         c = command.Status()
-        values = c.GetSingleMapMetadata(config.MAP_PASSWORD,
-                                        self.conf,
-                                        epoch=False)
-        self.assertEqual('Wed Dec 31 17:00:02 1969', values[1]['value'])
+        values = c.GetSingleMapMetadata(config.MAP_PASSWORD, self.conf, epoch=False)
+        self.assertEqual("Wed Dec 31 17:00:02 1969", values[1]["value"])

     def testGetAutomountMapMetadata(self):
         # need to stub out GetSingleMapMetadata (tested above) and then
@@ -870,39 +873,38 @@ def testGetAutomountMapMetadata(self):
         # stub out GetSingleMapMetadata
         class DummyStatus(command.Status):
-
-            def GetSingleMapMetadata(self,
-                                     unused_map_name,
-                                     unused_conf,
-                                     automount_mountpoint=None,
-                                     epoch=False):
+            def GetSingleMapMetadata(
+                self,
+                unused_map_name,
+                unused_conf,
+                automount_mountpoint=None,
+                epoch=False,
+            ):
                 return {
-                    'map': 'map_name',
-                    'last-modify-timestamp': 'foo',
-                    'last-update-timestamp': 'bar'
+                    "map": "map_name",
+                    "last-modify-timestamp": "foo",
+                    "last-update-timestamp": "bar",
                 }

         # the master map to loop over
         master_map = automount.AutomountMap()
         master_map.Add(
-            automount.AutomountMapEntry({
-                'key': '/home',
-                'location': '/etc/auto.home'
-            }))
+            automount.AutomountMapEntry({"key": "/home", "location": "/etc/auto.home"})
+        )
         master_map.Add(
-            automount.AutomountMapEntry({
-                'key': '/auto',
-                'location': '/etc/auto.auto'
-            }))
+            automount.AutomountMapEntry({"key": "/auto", "location": "/etc/auto.auto"})
+        )

         # mock out a cache to return the master map
         cache_mock = self.mox.CreateMock(caches.Cache)
         cache_mock.GetMap().AndReturn(master_map)

-        self.mox.StubOutWithMock(cache_factory, 'Create')
-        cache_factory.Create(self.conf.options[config.MAP_AUTOMOUNT].cache,
-                             config.MAP_AUTOMOUNT,
-                             automount_mountpoint=None).AndReturn(cache_mock)
+        self.mox.StubOutWithMock(cache_factory, "Create")
+        cache_factory.Create(
+            self.conf.options[config.MAP_AUTOMOUNT].cache,
+            config.MAP_AUTOMOUNT,
+            automount_mountpoint=None,
+        ).AndReturn(cache_mock)

         self.mox.ReplayAll()

@@ -912,5 +914,5 @@ def GetSingleMapMetadata(self,
         self.assertEqual(9, len(value_list))


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/nss_cache/config.py b/nss_cache/config.py
index d53854c9..89c5b923 100644
--- a/nss_cache/config.py
+++ b/nss_cache/config.py
@@ -19,33 +19,33 @@
 and parsing for the nss_cache module.
 """

-__author__ = 'vasilios@google.com (Vasilios Hoffman)'
+__author__ = "vasilios@google.com (Vasilios Hoffman)"

 from configparser import ConfigParser
 import logging
 import re

 # known nss map types.
-MAP_PASSWORD = 'passwd'
-MAP_GROUP = 'group'
-MAP_SHADOW = 'shadow'
-MAP_NETGROUP = 'netgroup'
-MAP_AUTOMOUNT = 'automount'
-MAP_SSHKEY = 'sshkey'
+MAP_PASSWORD = "passwd"
+MAP_GROUP = "group"
+MAP_SHADOW = "shadow"
+MAP_NETGROUP = "netgroup"
+MAP_AUTOMOUNT = "automount"
+MAP_SSHKEY = "sshkey"

 # accepted commands.
-CMD_HELP = 'help'
-CMD_REPAIR = 'repair'
-CMD_STATUS = 'status'
-CMD_UPDATE = 'update'
-CMD_VERIFY = 'verify'
+CMD_HELP = "help"
+CMD_REPAIR = "repair"
+CMD_STATUS = "status"
+CMD_UPDATE = "update"
+CMD_VERIFY = "verify"

 # default file locations
-FILE_NSSWITCH = '/etc/nsswitch.conf'
+FILE_NSSWITCH = "/etc/nsswitch.conf"

 # update method types
-UPDATER_FILE = 'file'
-UPDATER_MAP = 'map'
+UPDATER_FILE = "file"
+UPDATER_MAP = "map"


 class Config(object):
@@ -63,14 +63,14 @@ class Config(object):
     """

     # default config file.
-    NSSCACHE_CONFIG = '/etc/nsscache.conf'
+    NSSCACHE_CONFIG = "/etc/nsscache.conf"

     # known config file option names
-    OPT_SOURCE = 'source'
-    OPT_CACHE = 'cache'
-    OPT_MAPS = 'maps'
-    OPT_LOCKFILE = 'lockfile'
-    OPT_TIMESTAMP_DIR = 'timestamp_dir'
+    OPT_SOURCE = "source"
+    OPT_CACHE = "cache"
+    OPT_MAPS = "maps"
+    OPT_LOCKFILE = "lockfile"
+    OPT_TIMESTAMP_DIR = "timestamp_dir"

     def __init__(self, env):
         """Initialize defaults for data we hold.
@@ -79,8 +79,8 @@ def __init__(self, env):
             env: dictionary of environment variables (typically os.environ)
         """
         # override constants based on ENV vars
-        if 'NSSCACHE_CONFIG' in env:
-            self.config_file = env['NSSCACHE_CONFIG']
+        if "NSSCACHE_CONFIG" in env:
+            self.config_file = env["NSSCACHE_CONFIG"]
         else:
             self.config_file = self.NSSCACHE_CONFIG

@@ -98,13 +98,19 @@ def __repr__(self):
         # self.options is of variable length so we are forced to do
         # some fugly concatenation here to print our config in a
         # readable fashion.
-        string = (('<Config: maps=%r, lockfile=%r, timestamp_dir=%r'
-                   ', config_file=%r') %
-                  (self.maps, self.lockfile, self.timestamp_dir,
-                   self.config_file))
+        string = (
+            "<Config: maps=%r, lockfile=%r, timestamp_dir=%r" ", config_file=%r"
+            % (
+                self.maps,
+                self.lockfile,
+                self.timestamp_dir,
+                self.config_file,
+            )
+        )
         for key in self.options:
-            string = '%s\n\t%s=%r' % (string, key, self.options[key])
-        return '%s\n>' % string
+            string = "%s\n\t%s=%r" % (string, key, self.options[key])
+        return "%s\n>" % string


 class MapOptions(object):
@@ -121,7 +127,7 @@ def __init__(self):

     def __repr__(self):
         """String representation of this object."""
-        return '<MapOptions cache=%r source=%r>' % (self.cache, self.source)
+        return "<MapOptions cache=%r source=%r>" % (self.cache, self.source)


 #
@@ -140,24 +146,25 @@ def LoadConfig(configuration):
     parser = ConfigParser()

     # load config file
-    configuration.log.debug('Attempting to parse configuration file: %s',
-                            configuration.config_file)
+    configuration.log.debug(
+        "Attempting to parse configuration file: %s", configuration.config_file
+    )
     parser.read(configuration.config_file)

     # these are required, and used as defaults for each section
-    default = 'DEFAULT'
+    default = "DEFAULT"
     default_source = FixValue(parser.get(default, Config.OPT_SOURCE))
     default_cache = FixValue(parser.get(default, Config.OPT_CACHE))

     # this is also required, but global only
     # TODO(v): make this default to /var/lib/nsscache before next release
     configuration.timestamp_dir = FixValue(
-        parser.get(default, Config.OPT_TIMESTAMP_DIR))
+        parser.get(default, Config.OPT_TIMESTAMP_DIR)
+    )

     # optional defaults
     if parser.has_option(default, Config.OPT_LOCKFILE):
-        configuration.lockfile = FixValue(
-            parser.get(default, Config.OPT_LOCKFILE))
+        configuration.lockfile = FixValue(parser.get(default, Config.OPT_LOCKFILE))

     if not configuration.maps:
         # command line did not override
@@ -165,7 +172,7 @@ def LoadConfig(configuration):
         # special case for empty string, or split(',') will return a
         # non-empty list
         if maplist:
-            configuration.maps = [m.strip() for m in maplist.split(',')]
+            configuration.maps = [m.strip() for m in maplist.split(",")]
         else:
             configuration.maps = []

@@ -196,16 +203,15 @@ def LoadConfig(configuration):
             map_options.cache.update(options)

         # used to instantiate the specific cache/source
-        map_options.source['name'] = source
-        map_options.cache['name'] = cache
+        map_options.source["name"] = source
+        map_options.cache["name"] = cache

         # save final MapOptions() in the parent config object
         configuration.options[map_name] = map_options

-    configuration.log.info('Configured maps are: %s',
-                           ', '.join(configuration.maps))
+    configuration.log.info("Configured maps are: %s", ", ".join(configuration.maps))

-    configuration.log.debug('loaded configuration: %r', configuration)
+    configuration.log.debug("loaded configuration: %r", configuration)


 def Options(items, name):
@@ -222,7 +228,7 @@ def Options(items, name):
       dictionary of option:value pairs
     """
     options = {}
-    option_re = re.compile(r'^%s_(.+)' % name)
+    option_re = re.compile(r"^%s_(.+)" % name)
     for item in items:
         match = option_re.match(item[0])
         if match:
@@ -244,8 +250,9 @@ def FixValue(value):
      fixed value
    """
    # Strip quotes if necessary.
-    if ((value.startswith('"') and value.endswith('"')) or
-            (value.startswith('\'') and value.endswith('\''))):
+    if (value.startswith('"') and value.endswith('"')) or (
+        value.startswith("'") and value.endswith("'")
+    ):
         value = value[1:-1]

     # Convert to float if necessary. Python converts between floats and ints
@@ -276,11 +283,11 @@ def ParseNSSwitchConf(nsswitch_filename):
       a dictionary keyed by map names and containing a list of sources
      for each map.
""" - with open(nsswitch_filename, 'r') as nsswitch_file: + with open(nsswitch_filename, "r") as nsswitch_file: nsswitch = {} - map_re = re.compile(r'^([a-z]+): *(.*)$') + map_re = re.compile(r"^([a-z]+): *(.*)$") for line in nsswitch_file: match = map_re.match(line) if match: @@ -306,31 +313,37 @@ def VerifyConfiguration(conf, nsswitch_filename=FILE_NSSWITCH): """ (warnings, errors) = (0, 0) if not conf.maps: - logging.error('No maps are configured.') + logging.error("No maps are configured.") errors += 1 # Verify that at least one supported module is configured in nsswitch.conf. nsswitch = ParseNSSwitchConf(nsswitch_filename) for configured_map in conf.maps: - if configured_map == 'sshkey': + if configured_map == "sshkey": continue - if conf.options[configured_map].cache['name'] == 'nssdb': - logging.error('nsscache no longer supports nssdb cache') + if conf.options[configured_map].cache["name"] == "nssdb": + logging.error("nsscache no longer supports nssdb cache") errors += 1 - if conf.options[configured_map].cache['name'] == 'files': - nss_module_name = 'files' - if ('cache_filename_suffix' in conf.options[configured_map].cache - and - conf.options[configured_map].cache['cache_filename_suffix'] - == 'cache'): + if conf.options[configured_map].cache["name"] == "files": + nss_module_name = "files" + if ( + "cache_filename_suffix" in conf.options[configured_map].cache + and conf.options[configured_map].cache["cache_filename_suffix"] + == "cache" + ): # We are configured for libnss-cache for this map. - nss_module_name = 'cache' + nss_module_name = "cache" else: - nss_module_name = 'cache' + nss_module_name = "cache" if nss_module_name not in nsswitch[configured_map]: - logging.warning(('nsscache is configured to build maps for %r, ' - 'but NSS is not configured (in %r) to use it'), - configured_map, nsswitch_filename) + logging.warning( + ( + "nsscache is configured to build maps for %r, " + "but NSS is not configured (in %r) to use it" + ), + configured_map, + nsswitch_filename, + ) warnings += 1 return (warnings, errors) diff --git a/nss_cache/config_test.py b/nss_cache/config_test.py index 7d4f2e02..2b76013d 100644 --- a/nss_cache/config_test.py +++ b/nss_cache/config_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/config.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import os import shutil @@ -29,12 +29,14 @@ class TestConfig(unittest.TestCase): """Unit tests for config.Config().""" def testConfigInit(self): - env = {'NSSCACHE_CONFIG': 'test.conf'} + env = {"NSSCACHE_CONFIG": "test.conf"} conf = config.Config(env) - self.assertEqual(conf.config_file, - env['NSSCACHE_CONFIG'], - msg='Failed to override NSSCACHE_CONFIG.') + self.assertEqual( + conf.config_file, + env["NSSCACHE_CONFIG"], + msg="Failed to override NSSCACHE_CONFIG.", + ) class TestMapOptions(unittest.TestCase): @@ -53,12 +55,10 @@ def setUp(self): # create a directory with a writeable copy of nsscache.conf in it self.workdir = tempfile.mkdtemp() # nsscache.conf is in the parent dir of this test. 
- self.srcdir = os.path.normpath( - os.path.join(os.path.dirname(__file__), '..')) - conf_filename = 'nsscache.conf' + self.srcdir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..")) + conf_filename = "nsscache.conf" self.conf_filename = os.path.join(self.workdir, conf_filename) - shutil.copy(os.path.join(self.srcdir, conf_filename), - self.conf_filename) + shutil.copy(os.path.join(self.srcdir, conf_filename), self.conf_filename) os.chmod(self.conf_filename, 0o640) # prepare a config object with this config @@ -69,43 +69,49 @@ def tearDown(self): shutil.rmtree(self.workdir) def testLoadConfigSingleMap(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = foo\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = foo\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - self.assertEqual(['foo'], self.conf.maps) + self.assertEqual(["foo"], self.conf.maps) def testLoadConfigTwoMaps(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = foo, bar\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = foo, bar\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - self.assertEqual(['foo', 'bar'], self.conf.maps) + self.assertEqual(["foo", "bar"], self.conf.maps) def testLoadConfigMapsWhitespace(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = foo, bar , baz\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = foo, bar , baz\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - self.assertEqual(['foo', 'bar', 'baz'], self.conf.maps) + self.assertEqual(["foo", "bar", "baz"], self.conf.maps) def testLoadConfigExample(self): """Test that we parse and load the example config. 
@@ -118,243 +124,265 @@ def testLoadConfigExample(self): """ conf = self.conf config.LoadConfig(conf) - passwd = conf.options['passwd'] - group = conf.options['group'] - shadow = conf.options['shadow'] - automount = conf.options['automount'] + passwd = conf.options["passwd"] + group = conf.options["group"] + shadow = conf.options["shadow"] + automount = conf.options["automount"] self.assertTrue(isinstance(passwd, config.MapOptions)) self.assertTrue(isinstance(group, config.MapOptions)) self.assertTrue(isinstance(shadow, config.MapOptions)) self.assertTrue(isinstance(automount, config.MapOptions)) - self.assertEqual(passwd.source['name'], 'ldap') - self.assertEqual(group.source['name'], 'ldap') - self.assertEqual(shadow.source['name'], 'ldap') - self.assertEqual(automount.source['name'], 'ldap') + self.assertEqual(passwd.source["name"], "ldap") + self.assertEqual(group.source["name"], "ldap") + self.assertEqual(shadow.source["name"], "ldap") + self.assertEqual(automount.source["name"], "ldap") - self.assertEqual(passwd.cache['name'], 'files') - self.assertEqual(group.cache['name'], 'files') - self.assertEqual(shadow.cache['name'], 'files') - self.assertEqual(automount.cache['name'], 'files') + self.assertEqual(passwd.cache["name"], "files") + self.assertEqual(group.cache["name"], "files") + self.assertEqual(shadow.cache["name"], "files") + self.assertEqual(automount.cache["name"], "files") - self.assertEqual(passwd.source['base'], 'ou=people,dc=example,dc=com') - self.assertEqual(passwd.source['filter'], '(objectclass=posixAccount)') + self.assertEqual(passwd.source["base"], "ou=people,dc=example,dc=com") + self.assertEqual(passwd.source["filter"], "(objectclass=posixAccount)") - self.assertEqual(group.source['base'], 'ou=group,dc=example,dc=com') - self.assertEqual(group.source['filter'], '(objectclass=posixGroup)') + self.assertEqual(group.source["base"], "ou=group,dc=example,dc=com") + self.assertEqual(group.source["filter"], "(objectclass=posixGroup)") def testLoadConfigOptionalDefaults(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = foo, bar , baz\n' - 'lockfile = foo\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = foo, bar , baz\n" + "lockfile = foo\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - self.assertEqual(self.conf.lockfile, 'foo') + self.assertEqual(self.conf.lockfile, "foo") def testLoadConfigStripQuotesFromStrings(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = "ldap"\n' # needs to be ldap due to magic - 'cache = \'b\'ar\'\n' - 'maps = quux\n' - 'timestamp_dir = foo\n' - 'ldap_tls_require_cert = \'blah\'\n' - '[quux]\n' - 'ldap_klingon = "qep\'a\' wa\'maH loS\'DIch"\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + 'source = "ldap"\n' # needs to be ldap due to magic + "cache = 'b'ar'\n" + "maps = quux\n" + "timestamp_dir = foo\n" + "ldap_tls_require_cert = 'blah'\n" + "[quux]\n" + "ldap_klingon = \"qep'a' wa'maH loS'DIch\"\n" + ) conf_file.close() config.LoadConfig(self.conf) - self.assertEqual('ldap', self.conf.options['quux'].source['name']) - self.assertEqual('b\'ar', self.conf.options['quux'].cache['name']) - self.assertEqual('blah', - self.conf.options['quux'].source['tls_require_cert']) - self.assertEqual('qep\'a\' wa\'maH loS\'DIch', - 
self.conf.options['quux'].source['klingon']) + self.assertEqual("ldap", self.conf.options["quux"].source["name"]) + self.assertEqual("b'ar", self.conf.options["quux"].cache["name"]) + self.assertEqual("blah", self.conf.options["quux"].source["tls_require_cert"]) + self.assertEqual( + "qep'a' wa'maH loS'DIch", self.conf.options["quux"].source["klingon"] + ) def testLoadConfigConvertsNumbers(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = foo\n' - 'timestamp_dir = foo\n' - 'foo_string = test\n' - 'foo_float = 1.23\n' - 'foo_int = 1\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = foo\n" + "timestamp_dir = foo\n" + "foo_string = test\n" + "foo_float = 1.23\n" + "foo_int = 1\n" + ) conf_file.close() config.LoadConfig(self.conf) - foo_dict = self.conf.options['foo'].source - self.assertTrue(isinstance(foo_dict['string'], str)) - self.assertTrue(isinstance(foo_dict['float'], float)) - self.assertTrue(isinstance(foo_dict['int'], int)) - self.assertEqual(foo_dict['string'], 'test') - self.assertEqual(foo_dict['float'], 1.23) - self.assertEqual(foo_dict['int'], 1) + foo_dict = self.conf.options["foo"].source + self.assertTrue(isinstance(foo_dict["string"], str)) + self.assertTrue(isinstance(foo_dict["float"], float)) + self.assertTrue(isinstance(foo_dict["int"], int)) + self.assertEqual(foo_dict["string"], "test") + self.assertEqual(foo_dict["float"], 1.23) + self.assertEqual(foo_dict["int"], 1) def testOptions(self): # check the empty case. - options = config.Options([], 'foo') + options = config.Options([], "foo") self.assertEqual(options, {}) # create a list like from ConfigParser.items() - items = [('maps', 'foo, bar, foobar'), ('ldap_uri', 'TEST_URI'), - ('source', 'foo'), ('cache', 'bar'), - ('ldap_base', 'TEST_BASE'), ('ldap_filter', 'TEST_FILTER')] + items = [ + ("maps", "foo, bar, foobar"), + ("ldap_uri", "TEST_URI"), + ("source", "foo"), + ("cache", "bar"), + ("ldap_base", "TEST_BASE"), + ("ldap_filter", "TEST_FILTER"), + ] - options = config.Options(items, 'ldap') + options = config.Options(items, "ldap") - self.assertTrue('uri' in options) - self.assertTrue('base' in options) - self.assertTrue('filter' in options) + self.assertTrue("uri" in options) + self.assertTrue("base" in options) + self.assertTrue("filter" in options) - self.assertEqual(options['uri'], 'TEST_URI') - self.assertEqual(options['base'], 'TEST_BASE') - self.assertEqual(options['filter'], 'TEST_FILTER') + self.assertEqual(options["uri"], "TEST_URI") + self.assertEqual(options["base"], "TEST_BASE") + self.assertEqual(options["filter"], "TEST_FILTER") def testParseNSSwitchConf(self): - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files cache\n') - nsswitch_file.write('group: files cache\n') - nsswitch_file.write('shadow: files cache\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files cache\n") + nsswitch_file.write("group: files cache\n") + nsswitch_file.write("shadow: files cache\n") nsswitch_file.close() expected_switch = { - 'passwd': ['files', 'cache'], - 'group': ['files', 'cache'], - 'shadow': ['files', 'cache'] + "passwd": ["files", "cache"], + "group": ["files", "cache"], + "shadow": ["files", "cache"], } - self.assertEqual(expected_switch, - 
config.ParseNSSwitchConf(nsswitch_filename)) + self.assertEqual(expected_switch, config.ParseNSSwitchConf(nsswitch_filename)) os.unlink(nsswitch_filename) def testVerifyConfiguration(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = passwd, group, shadow\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = passwd, group, shadow\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files cache\n') - nsswitch_file.write('group: files cache\n') - nsswitch_file.write('shadow: files cache\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files cache\n") + nsswitch_file.write("group: files cache\n") + nsswitch_file.write("shadow: files cache\n") nsswitch_file.close() - self.assertEqual((0, 0), - config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (0, 0), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) os.unlink(nsswitch_filename) def testVerifyConfigurationWithCache(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = files\n' - 'maps = passwd, group, shadow\n' - 'timestamp_dir = foo\n' - 'files_cache_filename_suffix = cache') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = files\n" + "maps = passwd, group, shadow\n" + "timestamp_dir = foo\n" + "files_cache_filename_suffix = cache" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: cache\n') - nsswitch_file.write('group: cache\n') - nsswitch_file.write('shadow: cache\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: cache\n") + nsswitch_file.write("group: cache\n") + nsswitch_file.write("shadow: cache\n") nsswitch_file.close() - self.assertEqual((0, 0), - config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (0, 0), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) os.unlink(nsswitch_filename) def testVerifyConfigurationWithFiles(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = files\n' - 'maps = passwd, group, shadow\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = files\n" + "maps = passwd, group, shadow\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files\n') - nsswitch_file.write('group: files\n') - nsswitch_file.write('shadow: files\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files\n") + nsswitch_file.write("group: files\n") + nsswitch_file.write("shadow: files\n") nsswitch_file.close() - self.assertEqual((0, 0), - 
config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (0, 0), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) os.unlink(nsswitch_filename) def testVerifyBadConfigurationWithCache(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = files\n' - 'maps = passwd, group, shadow\n' - 'timestamp_dir = foo\n' - 'files_cache_filename_suffix = cache') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = files\n" + "maps = passwd, group, shadow\n" + "timestamp_dir = foo\n" + "files_cache_filename_suffix = cache" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files\n') - nsswitch_file.write('group: files\n') - nsswitch_file.write('shadow: files\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files\n") + nsswitch_file.write("group: files\n") + nsswitch_file.write("shadow: files\n") nsswitch_file.close() - self.assertEqual((3, 0), - config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (3, 0), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) os.unlink(nsswitch_filename) def testVerifyBadConfigurationIncrementsWarningCount(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = passwd, group, shadow\n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = passwd, group, shadow\n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files ldap\n') - nsswitch_file.write('group: files cache\n') - nsswitch_file.write('shadow: files cache\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files ldap\n") + nsswitch_file.write("group: files cache\n") + nsswitch_file.write("shadow: files cache\n") nsswitch_file.close() - self.assertEqual((1, 0), - config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (1, 0), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) os.unlink(nsswitch_filename) def testVerifyNoMapConfigurationIsError(self): - conf_file = open(self.conf_filename, 'w') - conf_file.write('[DEFAULT]\n' - 'source = foo\n' - 'cache = foo\n' - 'maps = \n' - 'timestamp_dir = foo\n') + conf_file = open(self.conf_filename, "w") + conf_file.write( + "[DEFAULT]\n" + "source = foo\n" + "cache = foo\n" + "maps = \n" + "timestamp_dir = foo\n" + ) conf_file.close() config.LoadConfig(self.conf) - nsswitch_filename = os.path.join(self.workdir, 'nsswitch.conf') - nsswitch_file = open(nsswitch_filename, 'w') - nsswitch_file.write('passwd: files ldap\n') + nsswitch_filename = os.path.join(self.workdir, "nsswitch.conf") + nsswitch_file = open(nsswitch_filename, "w") + nsswitch_file.write("passwd: files ldap\n") nsswitch_file.close() - self.assertEqual((0, 1), - config.VerifyConfiguration(self.conf, - nsswitch_filename)) + self.assertEqual( + (0, 1), config.VerifyConfiguration(self.conf, nsswitch_filename) + ) 
os.unlink(nsswitch_filename) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/error.py b/nss_cache/error.py index 05c3043d..5e95a19e 100644 --- a/nss_cache/error.py +++ b/nss_cache/error.py @@ -15,61 +15,72 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Exception classes for nss_cache module.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" class Error(Exception): """Base exception class for nss_cache.""" + pass class CacheNotFound(Error): """Raised when a local cache is missing.""" + pass class CacheInvalid(Error): """Raised when a cache is invalid.""" + pass class CommandParseError(Error): """Raised when the command line fails to parse correctly.""" + pass class ConfigurationError(Error): """Raised when there is a problem with configuration values.""" + pass class EmptyMap(Error): """Raised when an empty map is discovered and one is not expected.""" + pass class NoConfigFound(Error): """Raised when no configuration file is loaded.""" + pass class PermissionDenied(Error): """Raised when nss_cache cannot access a resource.""" + pass class UnsupportedMap(Error): """Raised when trying to use an unsupported map type.""" + pass class InvalidMap(Error): """Raised when an invalid map is encountered.""" + pass class SourceUnavailable(Error): """Raised when a source is unavailable.""" + pass diff --git a/nss_cache/error_test.py b/nss_cache/error_test.py index e27ec03a..9d878ab3 100644 --- a/nss_cache/error_test.py +++ b/nss_cache/error_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/error.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import unittest @@ -119,5 +119,5 @@ def __init__(self): self.assertRaises(error.SourceUnavailable, Ooops) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/lock.py b/nss_cache/lock.py index 02b44c88..d766ca04 100644 --- a/nss_cache/lock.py +++ b/nss_cache/lock.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Lock management for nss_cache module.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import errno import fcntl @@ -56,9 +56,9 @@ class PidFile(object): be configured to work, but your mileage can and will vary. """ - STATE_DIR = '/var/run' - PROC_DIR = '/proc' - PROG_NAME = 'nsscache' + STATE_DIR = "/var/run" + PROC_DIR = "/proc" + PROG_NAME = "nsscache" def __init__(self, filename=None, pid=None): """Initialize the PidFile object.""" @@ -81,11 +81,11 @@ def __init__(self, filename=None, pid=None): # We were invoked from a python interpreter with # bad arguments, or otherwise loaded without sys.argv # being set. - self.log.critical('Can not determine lock file name!') - raise TypeError('missing required argument: filename') - self.filename = '%s/%s' % (self.STATE_DIR, basename) + self.log.critical("Can not determine lock file name!") + raise TypeError("missing required argument: filename") + self.filename = "%s/%s" % (self.STATE_DIR, basename) - self.log.debug('using %s for lock file', self.filename) + self.log.debug("using %s for lock file", self.filename) def __del__(self): """Release our pid file on object destruction.""" @@ -101,12 +101,11 @@ def _Open(self, filename=None): # will truncate, so we use 'a+' and seek. 
We don't truncate # the file because we haven't tested if it is locked by # another program yet, this is done later by fcntl module. - self._file = open(filename, 'a+') + self._file = open(filename, "a+") self._file.seek(0) # Set permissions. - os.chmod(filename, - stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) + os.chmod(filename, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) def Lock(self, force=False): """Open our pid file and lock it. @@ -123,8 +122,9 @@ def Lock(self, force=False): self._Open() except IOError as e: if e.errno == errno.EACCES: - self.log.warning('Permission denied opening lock file: %s', - self.filename) + self.log.warning( + "Permission denied opening lock file: %s", self.filename + ) return False raise @@ -137,7 +137,7 @@ def Lock(self, force=False): if e.errno in [errno.EACCES, errno.EAGAIN]: # Catch the error raised when the file is locked. if not force: - self.log.debug('%s already locked!', self.filename) + self.log.debug("%s already locked!", self.filename) return False else: # Otherwise re-raise it. @@ -145,7 +145,7 @@ def Lock(self, force=False): # Check if we need to forcibly re-try the lock. if not return_val and force: - self.log.debug('retrying lock.') + self.log.debug("retrying lock.") # Try to kill the process with the lock. self.SendTerm() # Clear the lock. @@ -155,10 +155,10 @@ def Lock(self, force=False): # Store the pid. self._file.truncate() - self._file.write('%s\n' % self.pid) + self._file.write("%s\n" % self.pid) self._file.flush() - self.log.debug('successfully locked %s', self.filename) + self.log.debug("successfully locked %s", self.filename) self._locked = True return return_val @@ -175,11 +175,11 @@ def SendTerm(self): pid = int(pid_content.strip()) except (AttributeError, ValueError) as e: self.log.warning( - 'Not sending TERM, could not parse pid file content: %r', - pid_content) + "Not sending TERM, could not parse pid file content: %r", pid_content + ) return - self.log.debug('retrieved pid %d' % pid) + self.log.debug("retrieved pid %d" % pid) # Reset the filehandle just in case. self._file.seek(0) @@ -187,12 +187,12 @@ def SendTerm(self): # By reading cmdline out of /proc we establish: # a) if a process with that pid exists. # b) what the command line is, to see if it included 'nsscache'. - proc_path = '%s/%i/cmdline' % (self.PROC_DIR, pid) + proc_path = "%s/%i/cmdline" % (self.PROC_DIR, pid) try: - proc_file = open(proc_path, 'r') + proc_file = open(proc_path, "r") except IOError as e: if e.errno == errno.ENOENT: - self.log.debug('process does not exist, skipping signal.') + self.log.debug("process does not exist, skipping signal.") return raise @@ -200,14 +200,15 @@ def SendTerm(self): proc_file.close() # See if it matches our program name regex. - cmd_re = re.compile(r'.*%s' % self.PROG_NAME) + cmd_re = re.compile(r".*%s" % self.PROG_NAME) if not cmd_re.match(cmdline): - self.log.debug('process is running but not %s, skipping signal', - self.PROG_NAME) + self.log.debug( + "process is running but not %s, skipping signal", self.PROG_NAME + ) return # Send a SIGTERM. - self.log.debug('sending SIGTERM to %i', pid) + self.log.debug("sending SIGTERM to %i", pid) os.kill(pid, signal.SIGTERM) # We are not paranoid about success, so we're done! 
@@ -215,7 +216,7 @@ def SendTerm(self): def ClearLock(self): """Delete the pid file to remove any locks on it.""" - self.log.debug('clearing old pid file: %s', self.filename) + self.log.debug("clearing old pid file: %s", self.filename) self._file.close() self._file = None os.remove(self.filename) diff --git a/nss_cache/lock_test.py b/nss_cache/lock_test.py index c1585f90..1c4e4cf0 100644 --- a/nss_cache/lock_test.py +++ b/nss_cache/lock_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/lock.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import builtins import errno @@ -48,7 +48,7 @@ class TestPidFile(unittest.TestCase): def setUp(self): super(TestPidFile, self).setUp() self.workdir = tempfile.mkdtemp() - self.filename = '%s/%s' % (self.workdir, 'pidfile') + self.filename = "%s/%s" % (self.workdir, "pidfile") def tearDown(self): shutil.rmtree(self.workdir) @@ -59,7 +59,7 @@ def testInit(self): pid = os.getpid() filename = os.path.basename(sys.argv[0]) - filename = '%s/%s' % (locker.STATE_DIR, filename) + filename = "%s/%s" % (locker.STATE_DIR, filename) self.assertTrue(isinstance(locker, lock.PidFile)) self.assertEqual(locker.pid, pid) @@ -69,12 +69,12 @@ def testInit(self): # also check the case where argv[0] is empty (interactively loaded) full_path = sys.argv[0] - sys.argv[0] = '' + sys.argv[0] = "" self.assertRaises(TypeError, lock.PidFile) sys.argv[0] = full_path def testHandleArgumentsProperly(self): - filename = 'TEST' + filename = "TEST" pid = 10 locker = lock.PidFile(filename=filename, pid=pid) self.assertEqual(locker.filename, filename) @@ -82,16 +82,18 @@ def testHandleArgumentsProperly(self): def testDestructorUnlocks(self): yes = lock.PidFile() - with mock.patch.object(yes, 'Locked') as locked, mock.patch.object( - yes, 'Unlock') as unlock: + with mock.patch.object(yes, "Locked") as locked, mock.patch.object( + yes, "Unlock" + ) as unlock: locked.return_value = True yes.__del__() # Destructor should unlock unlock.assert_called_once() no = lock.PidFile() - with mock.patch.object(no, 'Locked') as locked, mock.patch.object( - yes, 'Unlock') as unlock: + with mock.patch.object(no, "Locked") as locked, mock.patch.object( + yes, "Unlock" + ) as unlock: locked.return_value = False no.__del__() # No unlock needed if already not locked. @@ -104,36 +106,37 @@ def testOpenCreatesAppropriateFileWithPerms(self): self.assertTrue(os.path.exists(self.filename)) file_mode = os.stat(self.filename)[stat.ST_MODE] - correct_mode = (stat.S_IFREG | stat.S_IRUSR | stat.S_IWUSR | - stat.S_IRGRP | stat.S_IROTH) + correct_mode = ( + stat.S_IFREG | stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH + ) self.assertEqual(file_mode, correct_mode) os.remove(self.filename) def testLockCreatesPidfiles(self): locker = lock.PidFile() - with mock.patch.object(locker, '_Open') as open: + with mock.patch.object(locker, "_Open") as open: open.side_effect = NotImplementedError() self.assertRaises(NotImplementedError, locker.Lock) # Note that testing when self._file is not None is covered below. 
- @mock.patch('fcntl.lockf') + @mock.patch("fcntl.lockf") def testLockLocksWithFcntl(self, lockf): - locker = lock.PidFile(pid='PID') + locker = lock.PidFile(pid="PID") - with mock.patch.object(locker, '_file') as f: + with mock.patch.object(locker, "_file") as f: locker.Lock() self.assertTrue(locker._locked) lockf.assert_called_once_with(f, fcntl.LOCK_EX | fcntl.LOCK_NB) def testLockStoresPid(self): - locker = lock.PidFile(filename=self.filename, pid='PID') + locker = lock.PidFile(filename=self.filename, pid="PID") locker.Lock() - pid_file = open(self.filename, 'r') + pid_file = open(self.filename, "r") - self.assertEqual(pid_file.read(), 'PID\n') + self.assertEqual(pid_file.read(), "PID\n") pid_file.close() @@ -141,50 +144,52 @@ def testLockStoresPid(self): def testLockTrapsPermissionDeniedOnly(self): locker = lock.PidFile() - with mock.patch.object(locker, '_Open') as open: - open.side_effect = [ - IOError(errno.EACCES, ''), - IOError(errno.EIO, '') - ] + with mock.patch.object(locker, "_Open") as open: + open.side_effect = [IOError(errno.EACCES, ""), IOError(errno.EIO, "")] self.assertEqual(False, locker.Lock()) self.assertRaises(IOError, locker.Lock) def testForceLockTerminatesAndClearsLock(self): - locker = lock.PidFile(pid='PID') - with mock.patch.object(locker, 'SendTerm'), mock.patch.object( - locker, 'ClearLock'), mock.patch.object(locker, '_file') as f: - with mock.patch('fcntl.lockf') as lockf: + locker = lock.PidFile(pid="PID") + with mock.patch.object(locker, "SendTerm"), mock.patch.object( + locker, "ClearLock" + ), mock.patch.object(locker, "_file") as f: + with mock.patch("fcntl.lockf") as lockf: # This is a little weird due to recursion. # The first time through lockf throws an error and we retry the lock. # The 2nd time through we should fail, because lockf will still throw # an error, so we expect False back and the above mock objects # invoked. lockf.side_effect = [ - IOError(errno.EAGAIN, ''), - IOError(errno.EAGAIN, '') + IOError(errno.EAGAIN, ""), + IOError(errno.EAGAIN, ""), ] self.assertFalse(locker.Lock(force=True)) lockf.assert_has_calls( - (mock.call(locker._file, fcntl.LOCK_EX | fcntl.LOCK_NB), - mock.call(locker._file, fcntl.LOCK_EX | fcntl.LOCK_NB))) + ( + mock.call(locker._file, fcntl.LOCK_EX | fcntl.LOCK_NB), + mock.call(locker._file, fcntl.LOCK_EX | fcntl.LOCK_NB), + ) + ) def testSendTermMatchesCommandAndSendsTerm(self): locker = lock.PidFile() # Program mocks mock_re = mock.create_autospec(re.Pattern) mock_re.match.return_value = True - with mock.patch('re.compile') as regexp, mock.patch( - 'os.kill') as kill, mock.patch.object(locker, '_file') as f: - f.read.return_value = '1234' + with mock.patch("re.compile") as regexp, mock.patch( + "os.kill" + ) as kill, mock.patch.object(locker, "_file") as f: + f.read.return_value = "1234" regexp.return_value = mock_re # Create a file we open() in SendTerm(). 
- proc_dir = '%s/1234' % self.workdir - proc_filename = '%s/cmdline' % proc_dir + proc_dir = "%s/1234" % self.workdir + proc_filename = "%s/cmdline" % proc_dir os.mkdir(proc_dir) - proc_file = open(proc_filename, 'w') - proc_file.write('TEST') + proc_file = open(proc_filename, "w") + proc_file.write("TEST") proc_file.flush() proc_file.close() locker.PROC_DIR = self.workdir @@ -193,7 +198,7 @@ def testSendTermMatchesCommandAndSendsTerm(self): locker.SendTerm() # Assert the mocks - regexp.assert_called_with(r'.*nsscache') + regexp.assert_called_with(r".*nsscache") kill.assert_called_once_with(1234, signal.SIGTERM) f.read.assert_called() f.seek.assert_called_with(0) @@ -202,9 +207,8 @@ def testSendTermMatchesCommandAndSendsTerm(self): def testSendTermNoPid(self): locker = lock.PidFile() - with mock.patch.object(locker, - '_file') as f, mock.patch('os.kill') as kill: - f.read.return_value = '\n' + with mock.patch.object(locker, "_file") as f, mock.patch("os.kill") as kill: + f.read.return_value = "\n" locker.PROC = self.workdir locker.SendTerm() f.read.assert_called() @@ -212,8 +216,7 @@ def testSendTermNoPid(self): def testSendTermNonePid(self): locker = lock.PidFile() - with mock.patch.object(locker, - '_file') as f, mock.patch('os.kill') as kill: + with mock.patch.object(locker, "_file") as f, mock.patch("os.kill") as kill: f.read.return_value = None locker.PROC = self.workdir locker.SendTerm() @@ -222,12 +225,13 @@ def testSendTermNonePid(self): def testSendTermTrapsENOENT(self): locker = lock.PidFile() - with mock.patch.object(locker, '_file') as f, mock.patch( - 'os.kill') as kill, mock.patch('builtins.open') as mock_open: - f.read.return_value = '1234\n' - mock_open.side_effect = IOError(errno.ENOENT, '') + with mock.patch.object(locker, "_file") as f, mock.patch( + "os.kill" + ) as kill, mock.patch("builtins.open") as mock_open: + f.read.return_value = "1234\n" + mock_open.side_effect = IOError(errno.ENOENT, "") # self.workdir/1234/cmdline should not exist :) - self.assertFalse(os.path.exists('%s/1234/cmdline' % self.workdir)) + self.assertFalse(os.path.exists("%s/1234/cmdline" % self.workdir)) locker.PROC = self.workdir locker.SendTerm() f.read.assert_called() @@ -235,8 +239,8 @@ def testSendTermTrapsENOENT(self): def testClearLockRemovesPidFile(self): # Create a pid file. - pidfile = open(self.filename, 'w') - pidfile.write('foo') + pidfile = open(self.filename, "w") + pidfile.write("foo") pidfile.flush() locker = lock.PidFile(filename=self.filename) @@ -259,12 +263,12 @@ def testLockedPredicate(self): def testUnlockReleasesFcntlLock(self): locker = lock.PidFile() - locker._file = 'FILE_OBJECT' - with mock.patch('fcntl.lockf') as lockf: + locker._file = "FILE_OBJECT" + with mock.patch("fcntl.lockf") as lockf: locker.Unlock() self.assertFalse(locker._locked) - lockf.assert_called_once_with('FILE_OBJECT', fcntl.LOCK_UN) + lockf.assert_called_once_with("FILE_OBJECT", fcntl.LOCK_UN) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/automount.py b/nss_cache/maps/automount.py index 6b9252c8..977ebf88 100644 --- a/nss_cache/maps/automount.py +++ b/nss_cache/maps/automount.py @@ -21,7 +21,7 @@ AutomountMapEntry: A automount map entry based on the MapEntry class. 
""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" from nss_cache.maps import maps @@ -40,15 +40,16 @@ def __init__(self, iterable=None): def Add(self, entry): """Add a new object, verify it is a AutomountMapEntry object.""" if not isinstance(entry, AutomountMapEntry): - raise TypeError('Entry is not an AutomountMapEntry: %r' % entry) + raise TypeError("Entry is not an AutomountMapEntry: %r" % entry) return super(AutomountMap, self).Add(entry) class AutomountMapEntry(maps.MapEntry): """This class represents NSS automount map entries.""" - __slots__ = ('key', 'location', 'options') - _KEY = 'key' - _ATTRS = ('key', 'location', 'options') + + __slots__ = ("key", "location", "options") + _KEY = "key" + _ATTRS = ("key", "location", "options") def __init__(self, data=None): """Construct a AutomountMapEntry.""" diff --git a/nss_cache/maps/automount_test.py b/nss_cache/maps/automount_test.py index 40715ede..6995db9d 100644 --- a/nss_cache/maps/automount_test.py +++ b/nss_cache/maps/automount_test.py @@ -20,7 +20,7 @@ subclass is required to test the abstract class functionality. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import unittest @@ -35,34 +35,38 @@ def __init__(self, obj): """Set some default avalible data for testing.""" super(TestAutomountMap, self).__init__(obj) self._good_entry = automount.AutomountMapEntry() - self._good_entry.key = 'foo' - self._good_entry.options = '-tcp' - self._good_entry.location = 'nfsserver:/mah/stuff' + self._good_entry.key = "foo" + self._good_entry.options = "-tcp" + self._good_entry.location = "nfsserver:/mah/stuff" def testInit(self): """Construct an empty or seeded AutomountMap.""" - self.assertEqual(automount.AutomountMap, - type(automount.AutomountMap()), - msg='failed to create an empty AutomountMap') + self.assertEqual( + automount.AutomountMap, + type(automount.AutomountMap()), + msg="failed to create an empty AutomountMap", + ) amap = automount.AutomountMap([self._good_entry]) - self.assertEqual(self._good_entry, - amap.PopItem(), - msg='failed to seed AutomountMap with list') - self.assertRaises(TypeError, automount.AutomountMap, ['string']) + self.assertEqual( + self._good_entry, + amap.PopItem(), + msg="failed to seed AutomountMap with list", + ) + self.assertRaises(TypeError, automount.AutomountMap, ["string"]) def testAdd(self): """Add throws an error for objects it can't verify.""" amap = automount.AutomountMap() entry = self._good_entry - self.assertTrue(amap.Add(entry), msg='failed to append new entry.') + self.assertTrue(amap.Add(entry), msg="failed to append new entry.") - self.assertEqual(1, len(amap), msg='unexpected size for Map.') + self.assertEqual(1, len(amap), msg="unexpected size for Map.") ret_entry = amap.PopItem() - self.assertEqual(ret_entry, entry, msg='failed to pop correct entry.') + self.assertEqual(ret_entry, entry, msg="failed to pop correct entry.") pentry = passwd.PasswdMapEntry() - pentry.name = 'foo' + pentry.name = "foo" pentry.uid = 10 pentry.gid = 10 self.assertRaises(TypeError, amap.Add, pentry) @@ -73,35 +77,34 @@ class TestAutomountMapEntry(unittest.TestCase): def testInit(self): """Construct an empty and seeded AutomountMapEntry.""" - self.assertTrue(automount.AutomountMapEntry(), - msg='Could not create empty AutomountMapEntry') - seed = {'key': 'foo', 'location': '/dev/sda1'} + self.assertTrue( + automount.AutomountMapEntry(), + msg="Could not create empty AutomountMapEntry", 
+ ) + seed = {"key": "foo", "location": "/dev/sda1"} entry = automount.AutomountMapEntry(seed) - self.assertTrue(entry.Verify(), - msg='Could not verify seeded AutomountMapEntry') - self.assertEqual(entry.key, - 'foo', - msg='Entry returned wrong value for name') - self.assertEqual(entry.options, - None, - msg='Entry returned wrong value for options') - self.assertEqual(entry.location, - '/dev/sda1', - msg='Entry returned wrong value for location') + self.assertTrue(entry.Verify(), msg="Could not verify seeded AutomountMapEntry") + self.assertEqual(entry.key, "foo", msg="Entry returned wrong value for name") + self.assertEqual( + entry.options, None, msg="Entry returned wrong value for options" + ) + self.assertEqual( + entry.location, "/dev/sda1", msg="Entry returned wrong value for location" + ) def testAttributes(self): """Test that we can get and set all expected attributes.""" entry = automount.AutomountMapEntry() - entry.key = 'foo' - self.assertEqual(entry.key, 'foo', msg='Could not set attribute: key') - entry.options = 'noatime' - self.assertEqual(entry.options, - 'noatime', - msg='Could not set attribute: options') - entry.location = '/dev/ipod' - self.assertEqual(entry.location, - '/dev/ipod', - msg='Could not set attribute: location') + entry.key = "foo" + self.assertEqual(entry.key, "foo", msg="Could not set attribute: key") + entry.options = "noatime" + self.assertEqual( + entry.options, "noatime", msg="Could not set attribute: options" + ) + entry.location = "/dev/ipod" + self.assertEqual( + entry.location, "/dev/ipod", msg="Could not set attribute: location" + ) def testVerify(self): """Test that the object can verify it's attributes and itself.""" @@ -113,9 +116,9 @@ def testVerify(self): def testKey(self): """Key() should return the value of the 'key' attribute.""" entry = automount.AutomountMapEntry() - entry.key = 'foo' + entry.key = "foo" self.assertEqual(entry.Key(), entry.key) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/group.py b/nss_cache/maps/group.py index 42b0bf50..de403fe1 100644 --- a/nss_cache/maps/group.py +++ b/nss_cache/maps/group.py @@ -21,7 +21,7 @@ GroupMapEntry: A group map entry based on the MapEntry class. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" from nss_cache.maps import maps @@ -46,10 +46,11 @@ def Add(self, entry): class GroupMapEntry(maps.MapEntry): """This class represents NSS group map entries.""" + # Using slots saves us over 2x memory on large maps. - __slots__ = ('name', 'passwd', 'gid', 'members', 'groupmembers') - _KEY = 'name' - _ATTRS = ('name', 'passwd', 'gid', 'members', 'groupmembers') + __slots__ = ("name", "passwd", "gid", "members", "groupmembers") + _KEY = "name" + _ATTRS = ("name", "passwd", "gid", "members", "groupmembers") def __init__(self, data=None): """Construct a GroupMapEntry, setting reasonable defaults.""" @@ -63,7 +64,7 @@ def __init__(self, data=None): # Seed data with defaults if needed if self.passwd is None: - self.passwd = 'x' + self.passwd = "x" if self.members is None: self.members = [] if self.groupmembers is None: diff --git a/nss_cache/maps/group_test.py b/nss_cache/maps/group_test.py index 5fc95329..b7eca9cc 100644 --- a/nss_cache/maps/group_test.py +++ b/nss_cache/maps/group_test.py @@ -20,7 +20,7 @@ subclass is required to test the abstract class functionality. 
""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import unittest @@ -35,35 +35,37 @@ def __init__(self, obj): """Set some default avalible data for testing.""" super(TestGroupMap, self).__init__(obj) self._good_entry = group.GroupMapEntry() - self._good_entry.name = 'foo' - self._good_entry.passwd = 'x' + self._good_entry.name = "foo" + self._good_entry.passwd = "x" self._good_entry.gid = 10 - self._good_entry.members = ['foo', 'bar'] + self._good_entry.members = ["foo", "bar"] def testInit(self): """Construct an empty or seeded GroupMap.""" - self.assertEqual(group.GroupMap, - type(group.GroupMap()), - msg='failed to create an empty GroupMap') + self.assertEqual( + group.GroupMap, + type(group.GroupMap()), + msg="failed to create an empty GroupMap", + ) gmap = group.GroupMap([self._good_entry]) - self.assertEqual(self._good_entry, - gmap.PopItem(), - msg='failed to seed GroupMap with list') - self.assertRaises(TypeError, group.GroupMap, ['string']) + self.assertEqual( + self._good_entry, gmap.PopItem(), msg="failed to seed GroupMap with list" + ) + self.assertRaises(TypeError, group.GroupMap, ["string"]) def testAdd(self): """Add throws an error for objects it can't verify.""" gmap = group.GroupMap() entry = self._good_entry - self.assertTrue(gmap.Add(entry), msg='failed to append new entry.') + self.assertTrue(gmap.Add(entry), msg="failed to append new entry.") - self.assertEqual(1, len(gmap), msg='unexpected size for Map.') + self.assertEqual(1, len(gmap), msg="unexpected size for Map.") ret_entry = gmap.PopItem() - self.assertEqual(ret_entry, entry, msg='failed to pop correct entry.') + self.assertEqual(ret_entry, entry, msg="failed to pop correct entry.") pentry = passwd.PasswdMapEntry() - pentry.name = 'foo' + pentry.name = "foo" pentry.uid = 10 pentry.gid = 10 self.assertRaises(TypeError, gmap.Add, pentry) @@ -74,40 +76,31 @@ class TestGroupMapEntry(unittest.TestCase): def testInit(self): """Construct an empty and seeded GroupMapEntry.""" - self.assertTrue(group.GroupMapEntry(), - msg='Could not create empty GroupMapEntry') - seed = {'name': 'foo', 'gid': 10} + self.assertTrue( + group.GroupMapEntry(), msg="Could not create empty GroupMapEntry" + ) + seed = {"name": "foo", "gid": 10} entry = group.GroupMapEntry(seed) - self.assertTrue(entry.Verify(), - msg='Could not verify seeded PasswdMapEntry') - self.assertEqual(entry.name, - 'foo', - msg='Entry returned wrong value for name') - self.assertEqual(entry.passwd, - 'x', - msg='Entry returned wrong value for passwd') - self.assertEqual(entry.gid, - 10, - msg='Entry returned wrong value for gid') - self.assertEqual(entry.members, [], - msg='Entry returned wrong value for members') + self.assertTrue(entry.Verify(), msg="Could not verify seeded PasswdMapEntry") + self.assertEqual(entry.name, "foo", msg="Entry returned wrong value for name") + self.assertEqual(entry.passwd, "x", msg="Entry returned wrong value for passwd") + self.assertEqual(entry.gid, 10, msg="Entry returned wrong value for gid") + self.assertEqual( + entry.members, [], msg="Entry returned wrong value for members" + ) def testAttributes(self): """Test that we can get and set all expected attributes.""" entry = group.GroupMapEntry() - entry.name = 'foo' - self.assertEqual(entry.name, 'foo', msg='Could not set attribute: name') - entry.passwd = 'x' - self.assertEqual(entry.passwd, - 'x', - msg='Could not set attribute: passwd') + entry.name = "foo" + self.assertEqual(entry.name, "foo", msg="Could not set 
attribute: name") + entry.passwd = "x" + self.assertEqual(entry.passwd, "x", msg="Could not set attribute: passwd") entry.gid = 10 - self.assertEqual(entry.gid, 10, msg='Could not set attribute: gid') - members = ['foo', 'bar'] + self.assertEqual(entry.gid, 10, msg="Could not set attribute: gid") + members = ["foo", "bar"] entry.members = members - self.assertEqual(entry.members, - members, - msg='Could not set attribute: members') + self.assertEqual(entry.members, members, msg="Could not set attribute: members") def testVerify(self): """Test that the object can verify it's attributes and itself.""" @@ -119,9 +112,9 @@ def testVerify(self): def testKey(self): """Key() should return the value of the 'name' attribute.""" entry = group.GroupMapEntry() - entry.name = 'foo' + entry.name = "foo" self.assertEqual(entry.Key(), entry.name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/maps.py b/nss_cache/maps/maps.py index 990570cc..5dfe05b1 100644 --- a/nss_cache/maps/maps.py +++ b/nss_cache/maps/maps.py @@ -19,7 +19,7 @@ MapEntry: Abstract class representing an entry in a NSS map. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import logging @@ -69,7 +69,7 @@ def __init__(self, iterable=None, modify_time=None, update_time=None): TypeError: If the objects in the iterable are of the wrong type. """ if self.__class__ is Map: - raise TypeError('Map is an abstract class.') + raise TypeError("Map is an abstract class.") self._data = {} # The index preserves the order that entries are returned from the source # (e.g. the LDAP server.) It is not a set as sets are unordered. @@ -107,7 +107,7 @@ def __len__(self): return len(self._data) def __repr__(self): - return '<%s: %r>' % (self.__class__.__name__, self._data) + return "<%s: %r>" % (self.__class__.__name__, self._data) def Add(self, entry): """Add a MapEntry object to the Map and verify it (overwrites). @@ -124,11 +124,11 @@ def Add(self, entry): # Correct type? if not isinstance(entry, MapEntry): - raise TypeError('Not instance of MapEntry') + raise TypeError("Not instance of MapEntry") # Entry okay? if not entry.Verify(): - self.log.info('refusing to add entry, verify failed') + self.log.info("refusing to add entry, verify failed") return False # Add to index if not already there. 
@@ -136,8 +136,9 @@ def Add(self, entry): self._index.append(entry.Key()) else: self.log.warning( - 'duplicate key detected when adding to map: %r, overwritten', - entry.Key()) + "duplicate key detected when adding to map: %r, overwritten", + entry.Key(), + ) self._data[entry.Key()] = entry return True @@ -175,24 +176,27 @@ def Merge(self, other): """ if type(self) != type(other): raise TypeError( - 'Attempt to Merge() differently typed Maps: %r != %r' % - (type(self), type(other))) + "Attempt to Merge() differently typed Maps: %r != %r" + % (type(self), type(other)) + ) if other.GetModifyTimestamp() and self.GetModifyTimestamp(): if other.GetModifyTimestamp() < self.GetModifyTimestamp(): raise error.InvalidMerge( - 'Attempt to Merge a map with an older modify time into a newer one: ' - 'other: %s, self: %s' % - (other.GetModifyTimestamp(), self.GetModifyTimestamp())) + "Attempt to Merge a map with an older modify time into a newer one: " + "other: %s, self: %s" + % (other.GetModifyTimestamp(), self.GetModifyTimestamp()) + ) if other.GetUpdateTimestamp() and self.GetUpdateTimestamp(): if other.GetUpdateTimestamp() < self.GetUpdateTimestamp(): raise error.InvalidMerge( - 'Attempt to Merge a map with an older update time into a newer one: ' - 'other: %s, self: %s' % - (other.GetUpdateTimestamp(), self.GetUpdateTimestamp())) + "Attempt to Merge a map with an older update time into a newer one: " + "other: %s, self: %s" + % (other.GetUpdateTimestamp(), self.GetUpdateTimestamp()) + ) - self.log.info('merging from a map of %d entries', len(other)) + self.log.info("merging from a map of %d entries", len(other)) merge_count = 0 for their_entry in other: @@ -201,8 +205,7 @@ def Merge(self, other): if self.Add(their_entry): merge_count += 1 - self.log.info('%d of %d entries were new or modified', merge_count, - len(other)) + self.log.info("%d of %d entries were new or modified", merge_count, len(other)) if merge_count > 0: self.SetModifyTimestamp(other.GetModifyTimestamp()) @@ -240,7 +243,7 @@ def SetModifyTimestamp(self, value): if value is None or isinstance(value, int): self._last_modification_timestamp = value else: - raise TypeError('timestamp can only be int or None, not %r' % value) + raise TypeError("timestamp can only be int or None, not %r" % value) def GetModifyTimestamp(self): """Return last modification timestamp of this map. @@ -262,7 +265,7 @@ def SetUpdateTimestamp(self, value): if value is None or isinstance(value, int): self._last_update_timestamp = value else: - raise TypeError('timestamp can only be int or None, not %r', value) + raise TypeError("timestamp can only be int or None, not %r", value) def GetUpdateTimestamp(self): """Return last update timestamp of this map. @@ -283,8 +286,9 @@ class MapEntry(object): Attributes: log: A logging.Logger instance used for output. """ + # Using slots saves us over 2x memory on large maps. - __slots__ = ('_KEY', '_ATTRS', 'log') + __slots__ = ("_KEY", "_ATTRS", "log") # Overridden in the derived classes _KEY: str _ATTRS: set() @@ -300,7 +304,7 @@ def __init__(self, data=None, _KEY=None, _ATTRS=None): """ if self.__class__ is MapEntry: - raise TypeError('MapEntry is an abstract class.') + raise TypeError("MapEntry is an abstract class.") # Initialize from dict, if passed. 
if data is None: @@ -322,10 +326,10 @@ def __eq__(self, other): def __repr__(self): """String representation.""" - rep = '' + rep = "" for key in self._ATTRS: - rep = '%r:%r %s' % (key, getattr(self, key), rep) - return '<%s : %r>' % (self.__class__.__name__, rep.rstrip()) + rep = "%r:%r %s" % (key, getattr(self, key), rep) + return "<%s : %r>" % (self.__class__.__name__, rep.rstrip()) def Key(self): """Return unique identifier for this MapEntry object. diff --git a/nss_cache/maps/maps_test.py b/nss_cache/maps/maps_test.py index 640a5cb9..495083cc 100644 --- a/nss_cache/maps/maps_test.py +++ b/nss_cache/maps/maps_test.py @@ -19,8 +19,10 @@ base.py is specifically tested in passwd_test.py instead. """ -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import time import unittest @@ -36,7 +38,6 @@ def testIsAbstract(self): self.assertRaises(TypeError, maps.Map) def testModifyTimestamp(self): - class StubMap(maps.Map): pass @@ -49,7 +50,6 @@ class StubMap(maps.Map): self.assertEqual(None, foo.GetModifyTimestamp()) def testUpdateTimestamp(self): - class StubMap(maps.Map): pass @@ -70,5 +70,5 @@ def testIsAbstract(self): self.assertRaises(TypeError, maps.MapEntry) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/netgroup.py b/nss_cache/maps/netgroup.py index 9f6b8485..ed78bf07 100644 --- a/nss_cache/maps/netgroup.py +++ b/nss_cache/maps/netgroup.py @@ -36,7 +36,7 @@ and similar cases. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" from nss_cache.maps import maps @@ -68,9 +68,10 @@ class NetgroupMapEntry(maps.MapEntry): tuple is the equivalent of a null pointer from getnetgrent(), specifically a wildcard. """ - __slots__ = ('name', 'entries') - _KEY = 'name' - _ATTRS = ('name', 'entries') + + __slots__ = ("name", "entries") + _KEY = "name" + _ATTRS = ("name", "entries") def __init__(self, data=None): """Construct a NetgroupMapEntry.""" @@ -81,4 +82,4 @@ def __init__(self, data=None): # Seed data with defaults if needed if self.entries is None: - self.entries = '' + self.entries = "" diff --git a/nss_cache/maps/netgroup_test.py b/nss_cache/maps/netgroup_test.py index 5710773e..fbce1556 100644 --- a/nss_cache/maps/netgroup_test.py +++ b/nss_cache/maps/netgroup_test.py @@ -20,7 +20,7 @@ subclass is required to test the abstract class functionality. 
""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import unittest @@ -35,33 +35,35 @@ def __init__(self, obj): """Set some default avalible data for testing.""" super(TestNetgroupMap, self).__init__(obj) self._good_entry = netgroup.NetgroupMapEntry() - self._good_entry.name = 'foo' - self._good_entry.entries = [('-', 'bob', None), 'othernetgroup'] + self._good_entry.name = "foo" + self._good_entry.entries = [("-", "bob", None), "othernetgroup"] def testInit(self): """Construct an empty or seeded NetgroupMap.""" - self.assertEqual(netgroup.NetgroupMap, - type(netgroup.NetgroupMap()), - msg='failed to create an empty NetgroupMap') + self.assertEqual( + netgroup.NetgroupMap, + type(netgroup.NetgroupMap()), + msg="failed to create an empty NetgroupMap", + ) nmap = netgroup.NetgroupMap([self._good_entry]) - self.assertEqual(self._good_entry, - nmap.PopItem(), - msg='failed to seed NetgroupMap with list') - self.assertRaises(TypeError, netgroup.NetgroupMap, ['string']) + self.assertEqual( + self._good_entry, nmap.PopItem(), msg="failed to seed NetgroupMap with list" + ) + self.assertRaises(TypeError, netgroup.NetgroupMap, ["string"]) def testAdd(self): """Add throws an error for objects it can't verify.""" nmap = netgroup.NetgroupMap() entry = self._good_entry - self.assertTrue(nmap.Add(entry), msg='failed to append new entry.') + self.assertTrue(nmap.Add(entry), msg="failed to append new entry.") - self.assertEqual(1, len(nmap), msg='unexpected size for Map.') + self.assertEqual(1, len(nmap), msg="unexpected size for Map.") ret_entry = nmap.PopItem() - self.assertEqual(ret_entry, entry, msg='failed to pop correct entry.') + self.assertEqual(ret_entry, entry, msg="failed to pop correct entry.") pentry = passwd.PasswdMapEntry() - pentry.name = 'foo' + pentry.name = "foo" pentry.uid = 10 pentry.gid = 10 self.assertRaises(TypeError, nmap.Add, pentry) @@ -72,30 +74,26 @@ class TestNetgroupMapEntry(unittest.TestCase): def testInit(self): """Construct an empty and seeded NetgroupMapEntry.""" - self.assertTrue(netgroup.NetgroupMapEntry(), - msg='Could not create empty NetgroupMapEntry') - entries = ['bar', ('baz', '-', None)] - seed = {'name': 'foo', 'entries': entries} + self.assertTrue( + netgroup.NetgroupMapEntry(), msg="Could not create empty NetgroupMapEntry" + ) + entries = ["bar", ("baz", "-", None)] + seed = {"name": "foo", "entries": entries} entry = netgroup.NetgroupMapEntry(seed) - self.assertTrue(entry.Verify(), - msg='Could not verify seeded NetgroupMapEntry') - self.assertEqual(entry.name, - 'foo', - msg='Entry returned wrong value for name') - self.assertEqual(entry.entries, - entries, - msg='Entry returned wrong value for entries') + self.assertTrue(entry.Verify(), msg="Could not verify seeded NetgroupMapEntry") + self.assertEqual(entry.name, "foo", msg="Entry returned wrong value for name") + self.assertEqual( + entry.entries, entries, msg="Entry returned wrong value for entries" + ) def testAttributes(self): """Test that we can get and set all expected attributes.""" entry = netgroup.NetgroupMapEntry() - entry.name = 'foo' - self.assertEqual(entry.name, 'foo', msg='Could not set attribute: name') - entries = ['foo', '(-,bar,)'] + entry.name = "foo" + self.assertEqual(entry.name, "foo", msg="Could not set attribute: name") + entries = ["foo", "(-,bar,)"] entry.entries = entries - self.assertEqual(entry.entries, - entries, - msg='Could not set attribute: entries') + self.assertEqual(entry.entries, entries, msg="Could not 
set attribute: entries") def testVerify(self): """Test that the object can verify it's attributes and itself.""" @@ -107,9 +105,9 @@ def testVerify(self): def testKey(self): """Key() should return the value of the 'name' attribute.""" entry = netgroup.NetgroupMapEntry() - entry.name = 'foo' + entry.name = "foo" self.assertEqual(entry.Key(), entry.name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/passwd.py b/nss_cache/maps/passwd.py index 902f3a61..5f46396e 100644 --- a/nss_cache/maps/passwd.py +++ b/nss_cache/maps/passwd.py @@ -21,7 +21,7 @@ PasswdMapEntry: A passwd map entry based on the MapEntry class. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" from nss_cache.maps import maps @@ -52,10 +52,11 @@ def Add(self, entry): class PasswdMapEntry(maps.MapEntry): """This class represents NSS passwd map entries.""" + # Using slots saves us over 2x memory on large maps. - __slots__ = ('name', 'uid', 'gid', 'passwd', 'gecos', 'dir', 'shell') - _KEY = 'name' - _ATTRS = ('name', 'uid', 'gid', 'passwd', 'gecos', 'dir', 'shell') + __slots__ = ("name", "uid", "gid", "passwd", "gecos", "dir", "shell") + _KEY = "name" + _ATTRS = ("name", "uid", "gid", "passwd", "gecos", "dir", "shell") def __init__(self, data=None): """Construct a PasswdMapEntry, setting reasonable defaults.""" @@ -71,10 +72,10 @@ def __init__(self, data=None): # Seed data with defaults if still empty if self.passwd is None: - self.passwd = 'x' + self.passwd = "x" if self.gecos is None: - self.gecos = '' + self.gecos = "" if self.dir is None: - self.dir = '' + self.dir = "" if self.shell is None: - self.shell = '' + self.shell = "" diff --git a/nss_cache/maps/passwd_test.py b/nss_cache/maps/passwd_test.py index d4b820f6..cba65fbd 100644 --- a/nss_cache/maps/passwd_test.py +++ b/nss_cache/maps/passwd_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for passwd.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import time import unittest @@ -31,38 +31,40 @@ class TestPasswdMap(unittest.TestCase): def setUp(self): """Set some default avalible data for testing.""" self._good_entry = passwd.PasswdMapEntry() - self._good_entry.name = 'foo' - self._good_entry.passwd = 'x' + self._good_entry.name = "foo" + self._good_entry.passwd = "x" self._good_entry.uid = 10 self._good_entry.gid = 10 - self._good_entry.gecos = 'How Now Brown Cow' - self._good_entry.dir = '/home/foo' - self._good_entry.shell = '/bin/bash' + self._good_entry.gecos = "How Now Brown Cow" + self._good_entry.dir = "/home/foo" + self._good_entry.shell = "/bin/bash" def testInit(self): """Construct an empty or seeded PasswdMap.""" - self.assertEqual(passwd.PasswdMap, - type(passwd.PasswdMap()), - msg='failed to create emtpy PasswdMap') + self.assertEqual( + passwd.PasswdMap, + type(passwd.PasswdMap()), + msg="failed to create emtpy PasswdMap", + ) pmap = passwd.PasswdMap([self._good_entry]) - self.assertEqual(self._good_entry, - pmap.PopItem(), - msg='failed to seed PasswdMap with list') - self.assertRaises(TypeError, passwd.PasswdMap, ['string']) + self.assertEqual( + self._good_entry, pmap.PopItem(), msg="failed to seed PasswdMap with list" + ) + self.assertRaises(TypeError, passwd.PasswdMap, ["string"]) def testAdd(self): """Add raises exceptions for objects it can't add or verify.""" pmap = passwd.PasswdMap() entry = self._good_entry - self.assertTrue(pmap.Add(entry), msg='failed to add new entry.') + self.assertTrue(pmap.Add(entry), msg="failed to add new entry.") - self.assertEqual(1, len(pmap), msg='unexpected size for Map.') + self.assertEqual(1, len(pmap), msg="unexpected size for Map.") ret_entry = pmap.PopItem() - self.assertEqual(ret_entry, entry, msg='failed to pop existing entry.') + self.assertEqual(ret_entry, entry, msg="failed to pop existing entry.") gentry = group.GroupMapEntry() - gentry.name = 'foo' + gentry.name = "foo" gentry.gid = 10 self.assertRaises(TypeError, pmap.Add, gentry) @@ -70,17 +72,17 @@ def testContains(self): """Verify __contains__ works, and does a deep compare.""" pentry_good = self._good_entry pentry_like_good = passwd.PasswdMapEntry() - pentry_like_good.name = 'foo' # same Key(), but rest of attributes differ + pentry_like_good.name = "foo" # same Key(), but rest of attributes differ pentry_bad = passwd.PasswdMapEntry() - pentry_bad.name = 'bar' + pentry_bad.name = "bar" pmap = passwd.PasswdMap([pentry_good]) - self.assertTrue(pentry_good in pmap, msg='expected entry to be in map') - self.assertFalse(pentry_bad in pmap, - msg='did not expect entry to be in map') - self.assertFalse(pentry_like_good in pmap, - msg='__contains__ not doing a deep compare') + self.assertTrue(pentry_good in pmap, msg="expected entry to be in map") + self.assertFalse(pentry_bad in pmap, msg="did not expect entry to be in map") + self.assertFalse( + pentry_like_good in pmap, msg="__contains__ not doing a deep compare" + ) def testIterate(self): """Check that we can iterate over PasswdMap.""" @@ -89,17 +91,17 @@ def testIterate(self): ret_entries = [] for entry in pmap: ret_entries.append(entry) - self.assertEqual(len(ret_entries), 1, msg='iterated over wrong count') - self.assertEqual(ret_entries[0], - self._good_entry, - msg='got the wrong entry back') + self.assertEqual(len(ret_entries), 1, msg="iterated over wrong count") + self.assertEqual( + ret_entries[0], self._good_entry, 
msg="got the wrong entry back" + ) def testLen(self): """Verify we have correctly overridden __len__ in MapEntry.""" pmap = passwd.PasswdMap() - self.assertEqual(len(pmap), 0, msg='expected len(pmap) to be 0') + self.assertEqual(len(pmap), 0, msg="expected len(pmap) to be 0") pmap.Add(self._good_entry) - self.assertEqual(len(pmap), 1, msg='expected len(pmap) to be 1') + self.assertEqual(len(pmap), 1, msg="expected len(pmap) to be 1") def testExists(self): """Verify Exists() checks for presence of MapEntry objects.""" @@ -115,31 +117,32 @@ def testMerge(self): # Setup some MapEntry objects with distinct Key()s pentry1 = self._good_entry pentry2 = passwd.PasswdMapEntry() - pentry2.name = 'john' + pentry2.name = "john" pentry3 = passwd.PasswdMapEntry() - pentry3.name = 'jane' + pentry3.name = "jane" # Setup some Map objects pmap_big = passwd.PasswdMap([pentry1, pentry2]) pmap_small = passwd.PasswdMap([pentry3]) # Merge small into big - self.assertTrue(pmap_big.Merge(pmap_small), - msg='Merging small into big failed!') - self.assertTrue(pmap_big.Exists(pentry1), - msg='pentry1 not found in Map') - self.assertTrue(pmap_big.Exists(pentry2), - msg='pentry1 not found in Map') - self.assertTrue(pmap_big.Exists(pentry3), - msg='pentry1 not found in Map') + self.assertTrue( + pmap_big.Merge(pmap_small), msg="Merging small into big failed!" + ) + self.assertTrue(pmap_big.Exists(pentry1), msg="pentry1 not found in Map") + self.assertTrue(pmap_big.Exists(pentry2), msg="pentry1 not found in Map") + self.assertTrue(pmap_big.Exists(pentry3), msg="pentry1 not found in Map") # A second merge should do nothing - self.assertFalse(pmap_big.Merge(pmap_small), - msg='Re-merging small into big succeeded.') + self.assertFalse( + pmap_big.Merge(pmap_small), msg="Re-merging small into big succeeded." 
+ ) # An empty merge should do nothing - self.assertFalse(pmap_big.Merge(passwd.PasswdMap()), - msg='Empty Merge should have done nothing.') + self.assertFalse( + pmap_big.Merge(passwd.PasswdMap()), + msg="Empty Merge should have done nothing.", + ) # Merge a GroupMap should throw TypeError gmap = group.GroupMap() @@ -177,98 +180,71 @@ class TestPasswdMapEntry(unittest.TestCase): def testInit(self): """Construct empty and seeded PasswdMapEntry.""" entry = passwd.PasswdMapEntry() - self.assertEqual(type(entry), - passwd.PasswdMapEntry, - msg='Could not create empty PasswdMapEntry') + self.assertEqual( + type(entry), + passwd.PasswdMapEntry, + msg="Could not create empty PasswdMapEntry", + ) seed = { - 'name': 'foo', - 'passwd': 'x', - 'uid': 10, - 'gid': 10, - 'gecos': '', - 'dir': '', - 'shell': '' + "name": "foo", + "passwd": "x", + "uid": 10, + "gid": 10, + "gecos": "", + "dir": "", + "shell": "", } entry = passwd.PasswdMapEntry(seed) - self.assertTrue(entry.Verify(), - msg='Could not verify seeded PasswdMapEntry') - self.assertEqual(entry.name, - 'foo', - msg='Entry returned wrong value for name') - self.assertEqual(entry.passwd, - 'x', - msg='Entry returned wrong value for passwd') - self.assertEqual(entry.uid, - 10, - msg='Entry returned wrong value for uid') - self.assertEqual(entry.gid, - 10, - msg='Entry returned wrong value for gid') - self.assertEqual(entry.gecos, - '', - msg='Entry returned wrong value for gecos') - self.assertEqual(entry.dir, - '', - msg='Entry returned wrong value for dir') - self.assertEqual(entry.shell, - '', - msg='Entry returned wrong value for shell') + self.assertTrue(entry.Verify(), msg="Could not verify seeded PasswdMapEntry") + self.assertEqual(entry.name, "foo", msg="Entry returned wrong value for name") + self.assertEqual(entry.passwd, "x", msg="Entry returned wrong value for passwd") + self.assertEqual(entry.uid, 10, msg="Entry returned wrong value for uid") + self.assertEqual(entry.gid, 10, msg="Entry returned wrong value for gid") + self.assertEqual(entry.gecos, "", msg="Entry returned wrong value for gecos") + self.assertEqual(entry.dir, "", msg="Entry returned wrong value for dir") + self.assertEqual(entry.shell, "", msg="Entry returned wrong value for shell") def testAttributes(self): """Test that we can get and set all expected attributes.""" entry = passwd.PasswdMapEntry() - entry.name = 'foo' - self.assertEqual(entry.name, 'foo', msg='Could not set attribute: name') - entry.passwd = 'x' - self.assertEqual(entry.passwd, - 'x', - msg='Could not set attribute: passwd') + entry.name = "foo" + self.assertEqual(entry.name, "foo", msg="Could not set attribute: name") + entry.passwd = "x" + self.assertEqual(entry.passwd, "x", msg="Could not set attribute: passwd") entry.uid = 10 - self.assertEqual(entry.uid, 10, msg='Could not set attribute: uid') + self.assertEqual(entry.uid, 10, msg="Could not set attribute: uid") entry.gid = 10 - self.assertEqual(entry.gid, 10, msg='Could not set attribute: gid') - entry.gecos = 'How Now Brown Cow' - self.assertEqual(entry.gecos, - 'How Now Brown Cow', - msg='Could not set attribute: gecos') - entry.dir = '/home/foo' - self.assertEqual(entry.dir, - '/home/foo', - msg='Could not set attribute: dir') - entry.shell = '/bin/bash' - self.assertEqual(entry.shell, - '/bin/bash', - msg='Could not set attribute: shell') + self.assertEqual(entry.gid, 10, msg="Could not set attribute: gid") + entry.gecos = "How Now Brown Cow" + self.assertEqual( + entry.gecos, "How Now Brown Cow", msg="Could not set attribute: gecos" + ) + 
entry.dir = "/home/foo" + self.assertEqual(entry.dir, "/home/foo", msg="Could not set attribute: dir") + entry.shell = "/bin/bash" + self.assertEqual(entry.shell, "/bin/bash", msg="Could not set attribute: shell") def testEq(self): """Verify we are doing a deep compare in __eq__.""" # Setup some things to compare - entry_good = passwd.PasswdMapEntry({ - 'name': 'foo', - 'uid': 10, - 'gid': 10 - }) - entry_same_as_good = passwd.PasswdMapEntry({ - 'name': 'foo', - 'uid': 10, - 'gid': 10 - }) + entry_good = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) + entry_same_as_good = passwd.PasswdMapEntry( + {"name": "foo", "uid": 10, "gid": 10} + ) entry_like_good = passwd.PasswdMapEntry() - entry_like_good.name = 'foo' # same Key(), but rest of attributes differ + entry_like_good.name = "foo" # same Key(), but rest of attributes differ entry_bad = passwd.PasswdMapEntry() - entry_bad.name = 'bar' - - self.assertEqual(entry_good, - entry_good, - msg='entry_good not equal to itself') - self.assertEqual(entry_good, - entry_same_as_good, - msg='__eq__ not doing deep compare') - self.assertNotEqual(entry_good, - entry_like_good, - msg='__eq__ not doing deep compare') - self.assertNotEqual(entry_good, entry_bad, msg='unexpected equality') + entry_bad.name = "bar" + + self.assertEqual(entry_good, entry_good, msg="entry_good not equal to itself") + self.assertEqual( + entry_good, entry_same_as_good, msg="__eq__ not doing deep compare" + ) + self.assertNotEqual( + entry_good, entry_like_good, msg="__eq__ not doing deep compare" + ) + self.assertNotEqual(entry_good, entry_bad, msg="unexpected equality") def testVerify(self): """Test that the object can verify it's attributes and itself.""" @@ -280,9 +256,9 @@ def testVerify(self): def testKey(self): """Key() should return the value of the 'name' attribute.""" entry = passwd.PasswdMapEntry() - entry.name = 'foo' + entry.name = "foo" self.assertEqual(entry.Key(), entry.name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/shadow.py b/nss_cache/maps/shadow.py index ebe1d834..360e7cbb 100644 --- a/nss_cache/maps/shadow.py +++ b/nss_cache/maps/shadow.py @@ -21,7 +21,7 @@ ShadowMapEntry: A shadow map entry based on the MapEntry class. """ -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" from nss_cache.maps import maps @@ -46,11 +46,30 @@ def Add(self, entry): class ShadowMapEntry(maps.MapEntry): """This class represents NSS shadow map entries.""" - __slots__ = ('name', 'passwd', 'lstchg', 'min', 'max', 'warn', 'inact', - 'expire', 'flag') - _KEY = 'name' - _ATTRS = ('name', 'passwd', 'lstchg', 'min', 'max', 'warn', 'inact', - 'expire', 'flag') + + __slots__ = ( + "name", + "passwd", + "lstchg", + "min", + "max", + "warn", + "inact", + "expire", + "flag", + ) + _KEY = "name" + _ATTRS = ( + "name", + "passwd", + "lstchg", + "min", + "max", + "warn", + "inact", + "expire", + "flag", + ) def __init__(self, data=None): """Construct a ShadowMapEntry, setting reasonable defaults.""" @@ -68,4 +87,4 @@ def __init__(self, data=None): # Seed data with defaults if needed if self.passwd is None: - self.passwd = '!!' + self.passwd = "!!" diff --git a/nss_cache/maps/shadow_test.py b/nss_cache/maps/shadow_test.py index acf7fc84..db04274e 100644 --- a/nss_cache/maps/shadow_test.py +++ b/nss_cache/maps/shadow_test.py @@ -20,7 +20,7 @@ subclass is required to test the abstract class functionality. 
""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import unittest @@ -35,7 +35,7 @@ def __init__(self, obj): """Set some default avalible data for testing.""" super(TestShadowMap, self).__init__(obj) self._good_entry = shadow.ShadowMapEntry() - self._good_entry.name = 'foo' + self._good_entry.name = "foo" self._good_entry.lstchg = None self._good_entry.min = None self._good_entry.max = None @@ -46,28 +46,30 @@ def __init__(self, obj): def testInit(self): """Construct an empty or seeded ShadowMap.""" - self.assertEqual(shadow.ShadowMap, - type(shadow.ShadowMap()), - msg='failed to create emtpy ShadowMap') + self.assertEqual( + shadow.ShadowMap, + type(shadow.ShadowMap()), + msg="failed to create emtpy ShadowMap", + ) smap = shadow.ShadowMap([self._good_entry]) - self.assertEqual(self._good_entry, - smap.PopItem(), - msg='failed to seed ShadowMap with list') - self.assertRaises(TypeError, shadow.ShadowMap, ['string']) + self.assertEqual( + self._good_entry, smap.PopItem(), msg="failed to seed ShadowMap with list" + ) + self.assertRaises(TypeError, shadow.ShadowMap, ["string"]) def testAdd(self): """Add throws an error for objects it can't verify.""" smap = shadow.ShadowMap() entry = self._good_entry - self.assertTrue(smap.Add(entry), msg='failed to append new entry.') + self.assertTrue(smap.Add(entry), msg="failed to append new entry.") - self.assertEqual(1, len(smap), msg='unexpected size for Map.') + self.assertEqual(1, len(smap), msg="unexpected size for Map.") ret_entry = smap.PopItem() - self.assertEqual(ret_entry, entry, msg='failed to pop existing entry.') + self.assertEqual(ret_entry, entry, msg="failed to pop existing entry.") pentry = passwd.PasswdMapEntry() - pentry.name = 'foo' + pentry.name = "foo" pentry.uid = 10 pentry.gid = 10 self.assertRaises(TypeError, smap.Add, pentry) @@ -78,63 +80,49 @@ class TestShadowMapEntry(unittest.TestCase): def testInit(self): """Construct empty and seeded ShadowMapEntry.""" - self.assertTrue(shadow.ShadowMapEntry(), - msg='Could not create empty ShadowMapEntry') - seed = {'name': 'foo'} + self.assertTrue( + shadow.ShadowMapEntry(), msg="Could not create empty ShadowMapEntry" + ) + seed = {"name": "foo"} entry = shadow.ShadowMapEntry(seed) - self.assertTrue(entry.Verify(), - msg='Could not verify seeded ShadowMapEntry') - self.assertEqual(entry.name, - 'foo', - msg='Entry returned wrong value for name') - self.assertEqual(entry.passwd, - '!!', - msg='Entry returned wrong value for passwd') - self.assertEqual(entry.lstchg, - None, - msg='Entry returned wrong value for lstchg') - self.assertEqual(entry.min, - None, - msg='Entry returned wrong value for min') - self.assertEqual(entry.max, - None, - msg='Entry returned wrong value for max') - self.assertEqual(entry.warn, - None, - msg='Entry returned wrong value for warn') - self.assertEqual(entry.inact, - None, - msg='Entry returned wrong value for inact') - self.assertEqual(entry.expire, - None, - msg='Entry returned wrong value for expire') - self.assertEqual(entry.flag, - None, - msg='Entry returned wrong value for flag') + self.assertTrue(entry.Verify(), msg="Could not verify seeded ShadowMapEntry") + self.assertEqual(entry.name, "foo", msg="Entry returned wrong value for name") + self.assertEqual( + entry.passwd, "!!", msg="Entry returned wrong value for passwd" + ) + self.assertEqual( + entry.lstchg, None, msg="Entry returned wrong value for lstchg" + ) + self.assertEqual(entry.min, None, msg="Entry returned wrong value for min") 
+ self.assertEqual(entry.max, None, msg="Entry returned wrong value for max") + self.assertEqual(entry.warn, None, msg="Entry returned wrong value for warn") + self.assertEqual(entry.inact, None, msg="Entry returned wrong value for inact") + self.assertEqual( + entry.expire, None, msg="Entry returned wrong value for expire" + ) + self.assertEqual(entry.flag, None, msg="Entry returned wrong value for flag") def testAttributes(self): """Test that we can get and set all expected attributes.""" entry = shadow.ShadowMapEntry() - entry.name = 'foo' - self.assertEqual(entry.name, 'foo', msg='Could not set attribute: name') - entry.passwd = 'seekret' - self.assertEqual(entry.passwd, - 'seekret', - msg='Could not set attribute: passwd') + entry.name = "foo" + self.assertEqual(entry.name, "foo", msg="Could not set attribute: name") + entry.passwd = "seekret" + self.assertEqual(entry.passwd, "seekret", msg="Could not set attribute: passwd") entry.lstchg = 0 - self.assertEqual(entry.lstchg, 0, msg='Could not set attribute: lstchg') + self.assertEqual(entry.lstchg, 0, msg="Could not set attribute: lstchg") entry.min = 0 - self.assertEqual(entry.min, 0, msg='Could not set attribute: min') + self.assertEqual(entry.min, 0, msg="Could not set attribute: min") entry.max = 0 - self.assertEqual(entry.max, 0, msg='Could not set attribute: max') + self.assertEqual(entry.max, 0, msg="Could not set attribute: max") entry.warn = 0 - self.assertEqual(entry.warn, 0, msg='Could not set attribute: warn') + self.assertEqual(entry.warn, 0, msg="Could not set attribute: warn") entry.inact = 0 - self.assertEqual(entry.inact, 0, msg='Could not set attribute: inact') + self.assertEqual(entry.inact, 0, msg="Could not set attribute: inact") entry.expire = 0 - self.assertEqual(entry.expire, 0, msg='Could not set attribute: expire') + self.assertEqual(entry.expire, 0, msg="Could not set attribute: expire") entry.flag = 0 - self.assertEqual(entry.flag, 0, msg='Could not set attribute: flag') + self.assertEqual(entry.flag, 0, msg="Could not set attribute: flag") def testVerify(self): """Test that the object can verify it's attributes and itself.""" @@ -146,9 +134,9 @@ def testVerify(self): def testKey(self): """Key() should return the value of the 'name' attribute.""" entry = shadow.ShadowMapEntry() - entry.name = 'foo' + entry.name = "foo" self.assertEqual(entry.Key(), entry.name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/maps/sshkey.py b/nss_cache/maps/sshkey.py index d761e787..c743a906 100644 --- a/nss_cache/maps/sshkey.py +++ b/nss_cache/maps/sshkey.py @@ -21,7 +21,7 @@ SshkeyMapEntry: A sshkey map entry based on the MapEntry class. """ -__author__ = 'mimianddaniel@gmail.com' +__author__ = "mimianddaniel@gmail.com" from nss_cache.maps import maps @@ -52,10 +52,11 @@ def Add(self, entry): class SshkeyMapEntry(maps.MapEntry): """This class represents NSS sshkey map entries.""" + # Using slots saves us over 2x memory on large maps. 
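# The "slots saves memory" comment above is straightforward to sanity-check;
# a rough, illustrative measurement (byte counts vary across Python builds):
import sys

class DictEntry:
    def __init__(self):
        self.name, self.sshkey = "foo", ""

class SlotsEntry:
    __slots__ = ("name", "sshkey")

    def __init__(self):
        self.name, self.sshkey = "foo", ""

d, s = DictEntry(), SlotsEntry()
print(sys.getsizeof(d) + sys.getsizeof(d.__dict__))  # instance plus per-instance dict
print(sys.getsizeof(s))  # slotted instance carries no __dict__ at all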
- __slots__ = ('name', 'sshkey') - _KEY = 'name' - _ATTRS = ('name', 'sshkey') + __slots__ = ("name", "sshkey") + _KEY = "name" + _ATTRS = ("name", "sshkey") def __init__(self, data=None): """Construct a SshkeyMapEntry, setting reasonable defaults.""" @@ -65,4 +66,4 @@ def __init__(self, data=None): super(SshkeyMapEntry, self).__init__(data) # Seed data with defaults if still empty if self.sshkey is None: - self.sshkey = '' + self.sshkey = "" diff --git a/nss_cache/nss.py b/nss_cache/nss.py index 6883cbe6..2e70503c 100644 --- a/nss_cache/nss.py +++ b/nss_cache/nss.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """NSS utility library.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import pwd import grp @@ -30,7 +30,7 @@ # TODO(v): this should be a config option someday, but it's as standard # as libc so at the moment we'll leave it be for simplicity. -GETENT = '/usr/bin/getent' +GETENT = "/usr/bin/getent" def GetMap(map_name): @@ -75,7 +75,7 @@ def GetGroupMap(): map_entry.gid = nss_entry[2] map_entry.members = nss_entry[3] if not map_entry.members: - map_entry.members = [''] + map_entry.members = [""] group_map.Add(map_entry) return group_map @@ -91,42 +91,42 @@ def GetShadowMap(): shadow_map = shadow.ShadowMap() for line in getent_stdout.split(): - line = line.decode('utf-8') - nss_entry = line.strip().split(':') + line = line.decode("utf-8") + nss_entry = line.strip().split(":") map_entry = shadow.ShadowMapEntry() map_entry.name = nss_entry[0] map_entry.passwd = nss_entry[1] - if nss_entry[2] != '': + if nss_entry[2] != "": map_entry.lstchg = int(nss_entry[2]) - if nss_entry[3] != '': + if nss_entry[3] != "": map_entry.min = int(nss_entry[3]) - if nss_entry[4] != '': + if nss_entry[4] != "": map_entry.max = int(nss_entry[4]) - if nss_entry[5] != '': + if nss_entry[5] != "": map_entry.warn = int(nss_entry[5]) - if nss_entry[6] != '': + if nss_entry[6] != "": map_entry.inact = int(nss_entry[6]) - if nss_entry[7] != '': + if nss_entry[7] != "": map_entry.expire = int(nss_entry[7]) - if nss_entry[8] != '': + if nss_entry[8] != "": map_entry.flag = int(nss_entry[8]) shadow_map.Add(map_entry) if getent_stderr: - logging.debug('captured error %s', getent_stderr) + logging.debug("captured error %s", getent_stderr) retval = getent.returncode if retval != 0: - logging.warning('%s returned error code: %d', GETENT, retval) + logging.warning("%s returned error code: %d", GETENT, retval) return shadow_map def _SpawnGetent(map_name): """Run 'getent map' in a subprocess for reading NSS data.""" - getent = subprocess.Popen([GETENT, map_name], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + getent = subprocess.Popen( + [GETENT, map_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) return getent diff --git a/nss_cache/nss_test.py b/nss_cache/nss_test.py index 36a90cea..b72d4e6e 100644 --- a/nss_cache/nss_test.py +++ b/nss_cache/nss_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
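# GetShadowMap() above splits each getent(1) line on ":" and converts only
# the non-empty numeric fields to int; a condensed sketch of that rule (the
# helper name is hypothetical, not part of this diff):
def parse_shadow_line(line: bytes):
    fields = line.decode("utf-8").strip().split(":")
    # name, passwd, then lstchg/min/max/warn/inact/expire/flag
    return (fields[0], fields[1], *[int(f) if f else None for f in fields[2:9]])

print(parse_shadow_line(b"foo:!!::::::::"))  # ('foo', '!!', None, ..., None)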
"""Unit tests for nss_cache/command.py.""" -__author__ = 'vasilios@google.com (Vasilios Hoffman)' +__author__ = "vasilios@google.com (Vasilios Hoffman)" import grp import pwd @@ -36,47 +36,47 @@ class TestNSS(mox.MoxTestBase): def testGetMap(self): """that GetMap is calling the right GetFooMap routines.""" - self.mox.StubOutWithMock(nss, 'GetPasswdMap') - nss.GetPasswdMap().AndReturn('TEST_PASSWORD') - self.mox.StubOutWithMock(nss, 'GetGroupMap') - nss.GetGroupMap().AndReturn('TEST_GROUP') - self.mox.StubOutWithMock(nss, 'GetShadowMap') - nss.GetShadowMap().AndReturn('TEST_SHADOW') + self.mox.StubOutWithMock(nss, "GetPasswdMap") + nss.GetPasswdMap().AndReturn("TEST_PASSWORD") + self.mox.StubOutWithMock(nss, "GetGroupMap") + nss.GetGroupMap().AndReturn("TEST_GROUP") + self.mox.StubOutWithMock(nss, "GetShadowMap") + nss.GetShadowMap().AndReturn("TEST_SHADOW") self.mox.ReplayAll() - self.assertEqual('TEST_PASSWORD', nss.GetMap(config.MAP_PASSWORD)) - self.assertEqual('TEST_GROUP', nss.GetMap(config.MAP_GROUP)) - self.assertEqual('TEST_SHADOW', nss.GetMap(config.MAP_SHADOW)) + self.assertEqual("TEST_PASSWORD", nss.GetMap(config.MAP_PASSWORD)) + self.assertEqual("TEST_GROUP", nss.GetMap(config.MAP_GROUP)) + self.assertEqual("TEST_SHADOW", nss.GetMap(config.MAP_SHADOW)) def testGetMapException(self): """GetMap throws error.UnsupportedMap for unsupported maps.""" - self.assertRaises(error.UnsupportedMap, nss.GetMap, 'ohio') + self.assertRaises(error.UnsupportedMap, nss.GetMap, "ohio") def testGetPasswdMap(self): """Verify we build a correct password map from nss calls.""" - foo = ('foo', 'x', 10, 10, 'foo bar', '/home/foo', '/bin/shell') - bar = ('bar', 'x', 20, 20, 'foo bar', '/home/monkeyboy', '/bin/shell') + foo = ("foo", "x", 10, 10, "foo bar", "/home/foo", "/bin/shell") + bar = ("bar", "x", 20, 20, "foo bar", "/home/monkeyboy", "/bin/shell") - self.mox.StubOutWithMock(pwd, 'getpwall') + self.mox.StubOutWithMock(pwd, "getpwall") pwd.getpwall().AndReturn([foo, bar]) entry1 = passwd.PasswdMapEntry() - entry1.name = 'foo' + entry1.name = "foo" entry1.uid = 10 entry1.gid = 10 - entry1.gecos = 'foo bar' - entry1.dir = '/home/foo' - entry1.shell = '/bin/shell' + entry1.gecos = "foo bar" + entry1.dir = "/home/foo" + entry1.shell = "/bin/shell" entry2 = passwd.PasswdMapEntry() - entry2.name = 'bar' + entry2.name = "bar" entry2.uid = 20 entry2.gid = 20 - entry2.gecos = 'foo bar' - entry2.dir = '/home/monkeyboy' - entry2.shell = '/bin/shell' + entry2.gecos = "foo bar" + entry2.dir = "/home/monkeyboy" + entry2.shell = "/bin/shell" self.mox.ReplayAll() @@ -90,23 +90,23 @@ def testGetPasswdMap(self): def testGetGroupMap(self): """Verify we build a correct group map from nss calls.""" - foo = ('foo', '*', 10, []) - bar = ('bar', '*', 20, ['foo', 'bar']) + foo = ("foo", "*", 10, []) + bar = ("bar", "*", 20, ["foo", "bar"]) - self.mox.StubOutWithMock(grp, 'getgrall') + self.mox.StubOutWithMock(grp, "getgrall") grp.getgrall().AndReturn([foo, bar]) entry1 = group.GroupMapEntry() - entry1.name = 'foo' - entry1.passwd = '*' + entry1.name = "foo" + entry1.passwd = "*" entry1.gid = 10 - entry1.members = [''] + entry1.members = [""] entry2 = group.GroupMapEntry() - entry2.name = 'bar' - entry2.passwd = '*' + entry2.name = "bar" + entry2.passwd = "*" entry2.gid = 20 - entry2.members = ['foo', 'bar'] + entry2.members = ["foo", "bar"] self.mox.ReplayAll() @@ -119,20 +119,20 @@ def testGetGroupMap(self): def testGetShadowMap(self): """Verify we build a correct shadow map from nss calls.""" - line1 = b'foo:!!::::::::' - 
line2 = b'bar:!!::::::::' + line1 = b"foo:!!::::::::" + line2 = b"bar:!!::::::::" lines = [line1, line2] mock_getent = self.mox.CreateMockAnything() - mock_getent.communicate().AndReturn([b'\n'.join(lines), b'']) + mock_getent.communicate().AndReturn([b"\n".join(lines), b""]) mock_getent.returncode = 0 entry1 = shadow.ShadowMapEntry() - entry1.name = 'foo' + entry1.name = "foo" entry2 = shadow.ShadowMapEntry() - entry2.name = 'bar' + entry2.name = "bar" - self.mox.StubOutWithMock(nss, '_SpawnGetent') + self.mox.StubOutWithMock(nss, "_SpawnGetent") nss._SpawnGetent(config.MAP_SHADOW).AndReturn(mock_getent) self.mox.ReplayAll() @@ -145,5 +145,5 @@ def testGetShadowMap(self): self.assertTrue(shadow_map.Exists(entry2)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/consulsource.py b/nss_cache/sources/consulsource.py index 100b9213..f277df77 100644 --- a/nss_cache/sources/consulsource.py +++ b/nss_cache/sources/consulsource.py @@ -1,6 +1,6 @@ """An implementation of a consul data source for nsscache.""" -__author__ = 'hexedpackets@gmail.com (William Huba)' +__author__ = "hexedpackets@gmail.com (William Huba)" import base64 import collections @@ -21,26 +21,26 @@ class ConsulFilesSource(httpsource.HttpFilesSource): """Source for data fetched via Consul.""" # Consul defaults - DATACENTER = 'dc1' - TOKEN = '' + DATACENTER = "dc1" + TOKEN = "" # for registration - name = 'consul' + name = "consul" def _SetDefaults(self, configuration): """Set defaults if necessary.""" super(ConsulFilesSource, self)._SetDefaults(configuration) - if 'token' not in configuration: - configuration['token'] = self.TOKEN - if 'datacenter' not in configuration: - configuration['datacenter'] = self.DATACENTER + if "token" not in configuration: + configuration["token"] = self.TOKEN + if "datacenter" not in configuration: + configuration["datacenter"] = self.DATACENTER - for url in ['passwd_url', 'group_url', 'shadow_url']: - configuration[url] = '{}?recurse&token={}&dc={}'.format( - configuration[url], configuration['token'], - configuration['datacenter']) + for url in ["passwd_url", "group_url", "shadow_url"]: + configuration[url] = "{}?recurse&token={}&dc={}".format( + configuration[url], configuration["token"], configuration["datacenter"] + ) def GetPasswdMap(self, since=None): """Return the passwd map from this source. @@ -52,8 +52,7 @@ def GetPasswdMap(self, since=None): Returns: instance of passwd.PasswdMap """ - return PasswdUpdateGetter().GetUpdates(self, self.conf['passwd_url'], - since) + return PasswdUpdateGetter().GetUpdates(self, self.conf["passwd_url"], since) def GetGroupMap(self, since=None): """Return the group map from this source. @@ -65,8 +64,7 @@ def GetGroupMap(self, since=None): Returns: instance of group.GroupMap """ - return GroupUpdateGetter().GetUpdates(self, self.conf['group_url'], - since) + return GroupUpdateGetter().GetUpdates(self, self.conf["group_url"], since) def GetShadowMap(self, since=None): """Return the shadow map from this source. 
@@ -78,8 +76,7 @@ def GetShadowMap(self, since=None): Returns: instance of shadow.ShadowMap """ - return ShadowUpdateGetter().GetUpdates(self, self.conf['shadow_url'], - since) + return ShadowUpdateGetter().GetUpdates(self, self.conf["shadow_url"], since) class PasswdUpdateGetter(httpsource.UpdateGetter): @@ -139,8 +136,8 @@ def GetMap(self, cache_info, data): entries = collections.defaultdict(dict) for line in json.loads(cache_info.read()): - key = line.get('Key', '').split('/') - value = line.get('Value', '') + key = line.get("Key", "").split("/") + value = line.get("Value", "") if not value or not key: continue value = base64.b64decode(value) @@ -152,13 +149,15 @@ def GetMap(self, cache_info, data): map_entry = self._ReadEntry(name, entry) if map_entry is None: self.log.warning( - 'Could not create entry from line %r in cache, skipping', - entry) + "Could not create entry from line %r in cache, skipping", entry + ) continue if not data.Add(map_entry): self.log.warning( - 'Could not add entry %r read from line %r in cache', - map_entry, entry) + "Could not add entry %r read from line %r in cache", + map_entry, + entry, + ) return data @@ -171,17 +170,17 @@ def _ReadEntry(self, name, entry): map_entry = passwd.PasswdMapEntry() # maps expect strict typing, so convert to int as appropriate. map_entry.name = name - map_entry.passwd = entry.get('passwd', 'x') + map_entry.passwd = entry.get("passwd", "x") try: - map_entry.uid = int(entry['uid']) - map_entry.gid = int(entry['gid']) + map_entry.uid = int(entry["uid"]) + map_entry.gid = int(entry["gid"]) except (ValueError, KeyError): return None - map_entry.gecos = entry.get('comment', '') - map_entry.dir = entry.get('home', '/home/{}'.format(name)) - map_entry.shell = entry.get('shell', '/bin/bash') + map_entry.gecos = entry.get("comment", "") + map_entry.dir = entry.get("home", "/home/{}".format(name)) + map_entry.shell = entry.get("shell", "/bin/bash") return map_entry @@ -195,17 +194,17 @@ def _ReadEntry(self, name, entry): map_entry = group.GroupMapEntry() # map entries expect strict typing, so convert as appropriate map_entry.name = name - map_entry.passwd = entry.get('passwd', 'x') + map_entry.passwd = entry.get("passwd", "x") try: - map_entry.gid = int(entry['gid']) + map_entry.gid = int(entry["gid"]) except (ValueError, KeyError): return None try: - members = entry.get('members', '').split('\n') + members = entry.get("members", "").split("\n") except (ValueError, TypeError): - members = [''] + members = [""] map_entry.members = members return map_entry @@ -219,11 +218,11 @@ def _ReadEntry(self, name, entry): map_entry = shadow.ShadowMapEntry() # maps expect strict typing, so convert to int as appropriate. 
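# The Consul parsers above receive KV values base64-encoded under keys like
# org/users/<name>/<attr>; the decoding step matches the test fixtures that
# appear later in this diff:
import base64

print(base64.b64decode("MTA="))          # b'10'       -> uid/gid
print(base64.b64decode("Zm9vCmJhcg=="))  # b'foo\nbar' -> newline-separated members
print(base64.b64decode("Kg=="))          # b'*'        -> shadow passwd field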
map_entry.name = name - map_entry.passwd = entry.get('passwd', '*') + map_entry.passwd = entry.get("passwd", "*") if isinstance(map_entry.passwd, bytes): - map_entry.passwd = map_entry.passwd.decode('ascii') + map_entry.passwd = map_entry.passwd.decode("ascii") - for attr in ['lstchg', 'min', 'max', 'warn', 'inact', 'expire']: + for attr in ["lstchg", "min", "max", "warn", "inact", "expire"]: try: setattr(map_entry, attr, int(entry[attr])) except (ValueError, KeyError): diff --git a/nss_cache/sources/consulsource_test.py b/nss_cache/sources/consulsource_test.py index e934f77a..ebdbf5cc 100644 --- a/nss_cache/sources/consulsource_test.py +++ b/nss_cache/sources/consulsource_test.py @@ -1,6 +1,6 @@ """An implementation of a mock consul data source for nsscache.""" -__author__ = 'hexedpackets@gmail.com (William Huba)' +__author__ = "hexedpackets@gmail.com (William Huba)" import unittest from io import StringIO @@ -12,137 +12,139 @@ class TestConsulSource(unittest.TestCase): - def setUp(self): """Initialize a basic config dict.""" super(TestConsulSource, self).setUp() self.config = { - 'passwd_url': 'PASSWD_URL', - 'group_url': 'GROUP_URL', - 'datacenter': 'TEST_DATACENTER', - 'token': 'TEST_TOKEN', + "passwd_url": "PASSWD_URL", + "group_url": "GROUP_URL", + "datacenter": "TEST_DATACENTER", + "token": "TEST_TOKEN", } def testDefaultConfiguration(self): source = consulsource.ConsulFilesSource({}) - self.assertEqual(source.conf['datacenter'], - consulsource.ConsulFilesSource.DATACENTER) - self.assertEqual(source.conf['token'], - consulsource.ConsulFilesSource.TOKEN) + self.assertEqual( + source.conf["datacenter"], consulsource.ConsulFilesSource.DATACENTER + ) + self.assertEqual(source.conf["token"], consulsource.ConsulFilesSource.TOKEN) def testOverrideDefaultConfiguration(self): source = consulsource.ConsulFilesSource(self.config) - self.assertEqual(source.conf['datacenter'], 'TEST_DATACENTER') - self.assertEqual(source.conf['token'], 'TEST_TOKEN') + self.assertEqual(source.conf["datacenter"], "TEST_DATACENTER") + self.assertEqual(source.conf["token"], "TEST_TOKEN") self.assertEqual( - source.conf['passwd_url'], - 'PASSWD_URL?recurse&token=TEST_TOKEN&dc=TEST_DATACENTER') + source.conf["passwd_url"], + "PASSWD_URL?recurse&token=TEST_TOKEN&dc=TEST_DATACENTER", + ) self.assertEqual( - source.conf['group_url'], - 'GROUP_URL?recurse&token=TEST_TOKEN&dc=TEST_DATACENTER') + source.conf["group_url"], + "GROUP_URL?recurse&token=TEST_TOKEN&dc=TEST_DATACENTER", + ) class TestPasswdMapParser(unittest.TestCase): - def setUp(self): """Set some default avalible data for testing.""" self.good_entry = passwd.PasswdMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = 'x' + self.good_entry.name = "foo" + self.good_entry.passwd = "x" self.good_entry.uid = 10 self.good_entry.gid = 10 - self.good_entry.gecos = b'How Now Brown Cow' - self.good_entry.dir = b'/home/foo' - self.good_entry.shell = b'/bin/bash' + self.good_entry.gecos = b"How Now Brown Cow" + self.good_entry.dir = b"/home/foo" + self.good_entry.shell = b"/bin/bash" self.parser = consulsource.ConsulPasswdMapParser() def testGetMap(self): passwd_map = passwd.PasswdMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ {"Key": "org/users/foo/uid", "Value": "MTA="}, {"Key": "org/users/foo/gid", "Value": "MTA="}, {"Key": "org/users/foo/home", "Value": "L2hvbWUvZm9v"}, {"Key": "org/users/foo/shell", "Value": "L2Jpbi9iYXNo"}, {"Key": "org/users/foo/comment", "Value": "SG93IE5vdyBCcm93biBDb3c="}, {"Key": 
"org/users/foo/subkey/irrelevant_key", "Value": "YmFjb24="} - ]''') + ]""" + ) self.parser.GetMap(cache_info, passwd_map) self.assertEqual(self.good_entry, passwd_map.PopItem()) def testReadEntry(self): data = { - 'uid': '10', - 'gid': '10', - 'comment': b'How Now Brown Cow', - 'shell': b'/bin/bash', - 'home': b'/home/foo', - 'passwd': 'x' + "uid": "10", + "gid": "10", + "comment": b"How Now Brown Cow", + "shell": b"/bin/bash", + "home": b"/home/foo", + "passwd": "x", } - entry = self.parser._ReadEntry('foo', data) + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultEntryValues(self): - data = {'uid': '10', 'gid': '10'} - entry = self.parser._ReadEntry('foo', data) - self.assertEqual(entry.shell, '/bin/bash') - self.assertEqual(entry.dir, '/home/foo') - self.assertEqual(entry.gecos, '') - self.assertEqual(entry.passwd, 'x') + data = {"uid": "10", "gid": "10"} + entry = self.parser._ReadEntry("foo", data) + self.assertEqual(entry.shell, "/bin/bash") + self.assertEqual(entry.dir, "/home/foo") + self.assertEqual(entry.gecos, "") + self.assertEqual(entry.passwd, "x") def testInvalidEntry(self): - data = {'irrelevant_key': 'bacon'} - entry = self.parser._ReadEntry('foo', data) + data = {"irrelevant_key": "bacon"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(entry, None) class TestConsulGroupMapParser(unittest.TestCase): - def setUp(self): self.good_entry = group.GroupMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = 'x' + self.good_entry.name = "foo" + self.good_entry.passwd = "x" self.good_entry.gid = 10 - self.good_entry.members = ['foo', 'bar'] + self.good_entry.members = ["foo", "bar"] self.parser = consulsource.ConsulGroupMapParser() - @unittest.skip('broken') + @unittest.skip("broken") def testGetMap(self): group_map = group.GroupMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ {"Key": "org/groups/foo/gid", "Value": "MTA="}, {"Key": "org/groups/foo/members", "Value": "Zm9vCmJhcg=="}, {"Key": "org/groups/foo/subkey/irrelevant_key", "Value": "YmFjb24="} - ]''') + ]""" + ) self.parser.GetMap(cache_info, group_map) self.assertEqual(self.good_entry, group_map.PopItem()) def testReadEntry(self): - data = {'passwd': 'x', 'gid': '10', 'members': 'foo\nbar'} - entry = self.parser._ReadEntry('foo', data) + data = {"passwd": "x", "gid": "10", "members": "foo\nbar"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultPasswd(self): - data = {'gid': '10', 'members': 'foo\nbar'} - entry = self.parser._ReadEntry('foo', data) + data = {"gid": "10", "members": "foo\nbar"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testNoMembers(self): - data = {'gid': '10', 'members': ''} - entry = self.parser._ReadEntry('foo', data) - self.assertEqual(entry.members, ['']) + data = {"gid": "10", "members": ""} + entry = self.parser._ReadEntry("foo", data) + self.assertEqual(entry.members, [""]) def testInvalidEntry(self): - data = {'irrelevant_key': 'bacon'} - entry = self.parser._ReadEntry('foo', data) + data = {"irrelevant_key": "bacon"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(entry, None) class TestConsulShadowMapParser(unittest.TestCase): - def setUp(self): self.good_entry = shadow.ShadowMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = '*' + self.good_entry.name = "foo" + self.good_entry.passwd = "*" self.good_entry.lstchg = 17246 self.good_entry.min = 0 self.good_entry.max = 
99999 @@ -151,32 +153,28 @@ def setUp(self): def testGetMap(self): shadow_map = shadow.ShadowMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ {"Key": "org/groups/foo/passwd", "Value": "Kg=="}, {"Key": "org/groups/foo/lstchg", "Value": "MTcyNDY="}, {"Key": "org/groups/foo/min", "Value": "MA=="}, {"Key": "org/groups/foo/max", "Value": "OTk5OTk="}, {"Key": "org/groups/foo/warn", "Value": "Nw=="} - ]''') + ]""" + ) self.parser.GetMap(cache_info, shadow_map) self.assertEqual(self.good_entry, shadow_map.PopItem()) def testReadEntry(self): - data = { - 'passwd': '*', - 'lstchg': 17246, - 'min': 0, - 'max': 99999, - 'warn': 7 - } - entry = self.parser._ReadEntry('foo', data) + data = {"passwd": "*", "lstchg": 17246, "min": 0, "max": 99999, "warn": 7} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultPasswd(self): - data = {'lstchg': 17246, 'min': 0, 'max': 99999, 'warn': 7} - entry = self.parser._ReadEntry('foo', data) + data = {"lstchg": 17246, "min": 0, "max": 99999, "warn": 7} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/gcssource.py b/nss_cache/sources/gcssource.py index be5621ca..41e80aa1 100644 --- a/nss_cache/sources/gcssource.py +++ b/nss_cache/sources/gcssource.py @@ -14,7 +14,8 @@ from nss_cache.util import timestamps warnings.filterwarnings( - "ignore", "Your application has authenticated using end user credentials") + "ignore", "Your application has authenticated using end user credentials" +) def RegisterImplementation(registration_callback): @@ -25,13 +26,13 @@ class GcsFilesSource(source.Source): """Source for data fetched from GCS.""" # GCS Defaults - BUCKET = '' - PASSWD_OBJECT = '' - GROUP_OBJECT = '' - SHADOW_OBJECT = '' + BUCKET = "" + PASSWD_OBJECT = "" + GROUP_OBJECT = "" + SHADOW_OBJECT = "" # for registration - name = 'gcs' + name = "gcs" def __init__(self, conf): """Initialize the GcsFilesSource object. @@ -54,14 +55,14 @@ def _GetClient(self): def _SetDefaults(self, configuration): """Set defaults if necessary.""" - if 'bucket' not in configuration: - configuration['bucket'] = self.BUCKET - if 'passwd_object' not in configuration: - configuration['passwd_object'] = self.PASSWD_OBJECT - if 'group_object' not in configuration: - configuration['group_object'] = self.GROUP_OBJECT - if 'shadow_object' not in configuration: - configuration['shadow_object'] = self.SHADOW_OBJECT + if "bucket" not in configuration: + configuration["bucket"] = self.BUCKET + if "passwd_object" not in configuration: + configuration["passwd_object"] = self.PASSWD_OBJECT + if "group_object" not in configuration: + configuration["group_object"] = self.GROUP_OBJECT + if "shadow_object" not in configuration: + configuration["shadow_object"] = self.SHADOW_OBJECT def GetPasswdMap(self, since=None): """Return the passwd map from this source. @@ -73,10 +74,9 @@ def GetPasswdMap(self, since=None): Returns: instance of passwd.PasswdMap """ - return PasswdUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['passwd_object'], - since) + return PasswdUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["passwd_object"], since + ) def GetGroupMap(self, since=None): """Return the group map from this source. 
@@ -88,9 +88,9 @@ def GetGroupMap(self, since=None): Returns: instance of group.GroupMap """ - return GroupUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['group_object'], since) + return GroupUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["group_object"], since + ) def GetShadowMap(self, since=None): """Return the shadow map from this source. @@ -102,10 +102,9 @@ def GetShadowMap(self, since=None): Returns: instance of shadow.ShadowMap """ - return ShadowUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['shadow_object'], - since) + return ShadowUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["shadow_object"], since + ) class GcsUpdateGetter(object): @@ -117,47 +116,46 @@ def __init__(self): def GetUpdates(self, gcs_client, bucket_name, obj, since): """Gets updates from a source. - Args: - gcs_client: initialized gcs client - bucket_name: gcs bucket name - obj: object with the data - since: a timestamp representing the last change (None to force-get) + Args: + gcs_client: initialized gcs client + bucket_name: gcs bucket name + obj: object with the data + since: a timestamp representing the last change (None to force-get) - Returns: - A tuple containing the map of updates and a maximum timestamp - """ + Returns: + A tuple containing the map of updates and a maximum timestamp + """ bucket = gcs_client.bucket(bucket_name) blob = bucket.get_blob(obj) # get_blob captures NotFound error and returns None: if blob is None: - self.log.error('GCS object gs://%s/%s not found', bucket_name, obj) - raise error.SourceUnavailable('unable to download object from GCS.') + self.log.error("GCS object gs://%s/%s not found", bucket_name, obj) + raise error.SourceUnavailable("unable to download object from GCS.") # GCS doesn't return HTTP 304 like HTTP or S3 sources, # so return if updated timestamp is before 'since': if since and timestamps.FromDateTimeToTimestamp(blob.updated) < since: return [] data_map = self.GetMap(cache_info=blob.open()) - data_map.SetModifyTimestamp( - timestamps.FromDateTimeToTimestamp(blob.updated)) + data_map.SetModifyTimestamp(timestamps.FromDateTimeToTimestamp(blob.updated)) return data_map def GetParser(self): """Return the approriate parser. - Must be implemented by child class. - """ + Must be implemented by child class. + """ raise NotImplementedError def GetMap(self, cache_info): """Creates a Map from the cache_info data. - Args: - cache_info: file-like object containing the data to parse + Args: + cache_info: file-like object containing the data to parse - Returns: - A child of Map containing the cache data. - """ + Returns: + A child of Map containing the cache data. 
+ """ return self.GetParser().GetMap(cache_info, self.CreateMap()) diff --git a/nss_cache/sources/gcssource_test.py b/nss_cache/sources/gcssource_test.py index 02449ce0..7401bf50 100644 --- a/nss_cache/sources/gcssource_test.py +++ b/nss_cache/sources/gcssource_test.py @@ -15,51 +15,51 @@ class TestGcsSource(unittest.TestCase): - def setUp(self): super(TestGcsSource, self).setUp() self.config = { - 'passwd_object': 'PASSWD_OBJ', - 'group_object': 'GROUP_OBJ', - 'bucket': 'TEST_BUCKET', + "passwd_object": "PASSWD_OBJ", + "group_object": "GROUP_OBJ", + "bucket": "TEST_BUCKET", } def testDefaultConfiguration(self): source = gcssource.GcsFilesSource({}) - self.assertEqual(source.conf['bucket'], gcssource.GcsFilesSource.BUCKET) - self.assertEqual(source.conf['passwd_object'], - gcssource.GcsFilesSource.PASSWD_OBJECT) + self.assertEqual(source.conf["bucket"], gcssource.GcsFilesSource.BUCKET) + self.assertEqual( + source.conf["passwd_object"], gcssource.GcsFilesSource.PASSWD_OBJECT + ) def testOverrideDefaultConfiguration(self): source = gcssource.GcsFilesSource(self.config) - self.assertEqual(source.conf['bucket'], 'TEST_BUCKET') - self.assertEqual(source.conf['passwd_object'], 'PASSWD_OBJ') - self.assertEqual(source.conf['group_object'], 'GROUP_OBJ') + self.assertEqual(source.conf["bucket"], "TEST_BUCKET") + self.assertEqual(source.conf["passwd_object"], "PASSWD_OBJ") + self.assertEqual(source.conf["group_object"], "GROUP_OBJ") class TestPasswdUpdateGetter(unittest.TestCase): - def setUp(self): super(TestPasswdUpdateGetter, self).setUp() self.updater = gcssource.PasswdUpdateGetter() def testGetParser(self): - self.assertIsInstance(self.updater.GetParser(), - file_formats.FilesPasswdMapParser) + self.assertIsInstance( + self.updater.GetParser(), file_formats.FilesPasswdMapParser + ) def testCreateMap(self): self.assertIsInstance(self.updater.CreateMap(), passwd.PasswdMap) class TestShadowUpdateGetter(mox.MoxTestBase, unittest.TestCase): - def setUp(self): super(TestShadowUpdateGetter, self).setUp() self.updater = gcssource.ShadowUpdateGetter() def testGetParser(self): - self.assertIsInstance(self.updater.GetParser(), - file_formats.FilesShadowMapParser) + self.assertIsInstance( + self.updater.GetParser(), file_formats.FilesShadowMapParser + ) def testCreateMap(self): self.assertIsInstance(self.updater.CreateMap(), shadow.ShadowMap) @@ -67,21 +67,23 @@ def testCreateMap(self): def testShadowGetUpdatesWithContent(self): mock_blob = self.mox.CreateMockAnything() mock_blob.open().AndReturn( - io.StringIO("""usera:x::::::: + io.StringIO( + """usera:x::::::: userb:x::::::: -""")) +""" + ) + ) mock_blob.updated = datetime.datetime.now() mock_bucket = self.mox.CreateMockAnything() - mock_bucket.get_blob('passwd').AndReturn(mock_blob) + mock_bucket.get_blob("passwd").AndReturn(mock_blob) mock_client = self.mox.CreateMockAnything() - mock_client.bucket('test-bucket').AndReturn(mock_bucket) + mock_client.bucket("test-bucket").AndReturn(mock_bucket) self.mox.ReplayAll() - result = self.updater.GetUpdates(mock_client, 'test-bucket', 'passwd', - None) + result = self.updater.GetUpdates(mock_client, "test-bucket", "passwd", None) self.assertEqual(len(result), 2) def testShadowGetUpdatesSinceAfterUpdatedTime(self): @@ -89,33 +91,37 @@ def testShadowGetUpdatesSinceAfterUpdatedTime(self): mock_blob.updated = datetime.datetime.now() mock_bucket = self.mox.CreateMockAnything() - mock_bucket.get_blob('passwd').AndReturn(mock_blob) + mock_bucket.get_blob("passwd").AndReturn(mock_blob) mock_client = 
self.mox.CreateMockAnything()
-        mock_client.bucket('test-bucket').AndReturn(mock_bucket)
+        mock_client.bucket("test-bucket").AndReturn(mock_bucket)

         self.mox.ReplayAll()

         result = self.updater.GetUpdates(
-            mock_client, 'test-bucket', 'passwd',
-            timestamps.FromDateTimeToTimestamp(mock_blob.updated +
-                                               datetime.timedelta(days=1)))
+            mock_client,
+            "test-bucket",
+            "passwd",
+            timestamps.FromDateTimeToTimestamp(
+                mock_blob.updated + datetime.timedelta(days=1)
+            ),
+        )
         self.assertEqual(len(result), 0)


 class TestGroupUpdateGetter(unittest.TestCase):
-
     def setUp(self):
         super(TestGroupUpdateGetter, self).setUp()
         self.updater = gcssource.GroupUpdateGetter()

     def testGetParser(self):
-        self.assertIsInstance(self.updater.GetParser(),
-                              file_formats.FilesGroupMapParser)
+        self.assertIsInstance(
+            self.updater.GetParser(), file_formats.FilesGroupMapParser
+        )

     def testCreateMap(self):
         self.assertIsInstance(self.updater.CreateMap(), group.GroupMap)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/nss_cache/sources/httpsource.py b/nss_cache/sources/httpsource.py
index 186587f9..c7be405a 100644
--- a/nss_cache/sources/httpsource.py
+++ b/nss_cache/sources/httpsource.py
@@ -15,7 +15,7 @@
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 """An implementation of an http data source for nsscache."""

-__author__ = ('blaedd@google.com (David MacKinnon',)
+__author__ = ("blaedd@google.com (David MacKinnon)",)

 import bz2
 import calendar
@@ -44,19 +44,20 @@ def RegisterImplementation(registration_callback):

 class HttpFilesSource(source.Source):
     """Source for data fetched via HTTP."""
+
     # HTTP defaults
-    PASSWD_URL = ''
-    SHADOW_URL = ''
-    GROUP_URL = ''
-    AUTOMOUNT_BASE_URL = ''
-    NETGROUP_URL = ''
-    SSHKEY_URL = ''
+    PASSWD_URL = ""
+    SHADOW_URL = ""
+    GROUP_URL = ""
+    AUTOMOUNT_BASE_URL = ""
+    NETGROUP_URL = ""
+    SSHKEY_URL = ""
     RETRY_DELAY = 5
     RETRY_MAX = 3
-    TLS_CACERTFILE = '/etc/ssl/certs/ca-certificates.crt'
+    TLS_CACERTFILE = "/etc/ssl/certs/ca-certificates.crt"

     # for registration
-    name = 'http'
+    name = "http"

     def __init__(self, conf, conn=None):
         """Initialise the HTTP Data Source.
@@ -73,34 +74,34 @@ def __init__(self, conf, conn=None):
             conn.setopt(pycurl.NOSIGNAL, 1)
             # Don't hang on to connections from broken servers indefinitely.
             conn.setopt(pycurl.TIMEOUT, 60)
-            conn.setopt(pycurl.USERAGENT, 'nsscache')
-            if self.conf['http_proxy']:
-                conn.setopt(pycurl.PROXY, self.conf['http_proxy'])
+            conn.setopt(pycurl.USERAGENT, "nsscache")
+            if self.conf["http_proxy"]:
+                conn.setopt(pycurl.PROXY, self.conf["http_proxy"])
         self.conn = conn
 
     def _SetDefaults(self, configuration):
         """Set defaults if necessary."""
-        if 'automount_base_url' not in configuration:
-            configuration['automount_base_url'] = self.AUTOMOUNT_BASE_URL
-        if 'passwd_url' not in configuration:
-            configuration['passwd_url'] = self.PASSWD_URL
-        if 'shadow_url' not in configuration:
-            configuration['shadow_url'] = self.SHADOW_URL
-        if 'group_url' not in configuration:
-            configuration['group_url'] = self.GROUP_URL
-        if 'netgroup_url' not in configuration:
-            configuration['netgroup_url'] = self.GROUP_URL
-        if 'sshkey_url' not in configuration:
-            configuration['sshkey_url'] = self.SSHKEY_URL
-        if 'retry_delay' not in configuration:
-            configuration['retry_delay'] = self.RETRY_DELAY
-        if 'retry_max' not in configuration:
-            configuration['retry_max'] = self.RETRY_MAX
-        if 'tls_cacertfile' not in configuration:
-            configuration['tls_cacertfile'] = self.TLS_CACERTFILE
-        if 'http_proxy' not in configuration:
-            configuration['http_proxy'] = None
+        if "automount_base_url" not in configuration:
+            configuration["automount_base_url"] = self.AUTOMOUNT_BASE_URL
+        if "passwd_url" not in configuration:
+            configuration["passwd_url"] = self.PASSWD_URL
+        if "shadow_url" not in configuration:
+            configuration["shadow_url"] = self.SHADOW_URL
+        if "group_url" not in configuration:
+            configuration["group_url"] = self.GROUP_URL
+        if "netgroup_url" not in configuration:
+            configuration["netgroup_url"] = self.NETGROUP_URL
+        if "sshkey_url" not in configuration:
+            configuration["sshkey_url"] = self.SSHKEY_URL
+        if "retry_delay" not in configuration:
+            configuration["retry_delay"] = self.RETRY_DELAY
+        if "retry_max" not in configuration:
+            configuration["retry_max"] = self.RETRY_MAX
+        if "tls_cacertfile" not in configuration:
+            configuration["tls_cacertfile"] = self.TLS_CACERTFILE
+        if "http_proxy" not in configuration:
+            configuration["http_proxy"] = None
 
     def GetPasswdMap(self, since=None):
         """Return the passwd map from this source.
@@ -112,8 +113,7 @@ def GetPasswdMap(self, since=None):
         Returns:
           instance of passwd.PasswdMap
         """
-        return PasswdUpdateGetter().GetUpdates(self, self.conf['passwd_url'],
-                                               since)
+        return PasswdUpdateGetter().GetUpdates(self, self.conf["passwd_url"], since)
 
     def GetShadowMap(self, since=None):
         """Return the shadow map from this source.
@@ -125,8 +125,7 @@ def GetShadowMap(self, since=None):
         Returns:
          instance of shadow.ShadowMap
         """
-        return ShadowUpdateGetter().GetUpdates(self, self.conf['shadow_url'],
-                                               since)
+        return ShadowUpdateGetter().GetUpdates(self, self.conf["shadow_url"], since)
 
     def GetGroupMap(self, since=None):
         """Return the group map from this source.
@@ -138,8 +137,7 @@ def GetGroupMap(self, since=None):
         Returns:
           instance of group.GroupMap
         """
-        return GroupUpdateGetter().GetUpdates(self, self.conf['group_url'],
-                                              since)
+        return GroupUpdateGetter().GetUpdates(self, self.conf["group_url"], since)
 
     def GetNetgroupMap(self, since=None):
         """Return the netgroup map from this source.
@@ -151,9 +149,7 @@ def GetNetgroupMap(self, since=None): Returns: instance of netgroup.NetgroupMap """ - return NetgroupUpdateGetter().GetUpdates(self, - self.conf['netgroup_url'], - since) + return NetgroupUpdateGetter().GetUpdates(self, self.conf["netgroup_url"], since) def GetAutomountMap(self, since=None, location=None): """Return an automount map from this source. @@ -175,10 +171,9 @@ def GetAutomountMap(self, since=None, location=None): EmptyMap: """ if location is None: - self.log.error( - 'A location is required to retrieve an automount map!') + self.log.error("A location is required to retrieve an automount map!") raise error.EmptyMap - automount_url = urljoin(self.conf['automount_base_url'], location) + automount_url = urljoin(self.conf["automount_base_url"], location) return AutomountUpdateGetter().GetUpdates(self, automount_url, since) def GetAutomountMasterMap(self): @@ -187,10 +182,10 @@ def GetAutomountMasterMap(self): Returns: an instance of automount.AutomountMap """ - master_map = self.GetAutomountMap(location='auto.master') + master_map = self.GetAutomountMap(location="auto.master") for map_entry in master_map: map_entry.location = os.path.split(map_entry.location)[1] - self.log.debug('master map has: %s' % map_entry.location) + self.log.debug("master map has: %s" % map_entry.location) return master_map def GetSshkeyMap(self, since=None): @@ -203,8 +198,7 @@ def GetSshkeyMap(self, since=None): Returns: instance of sshkey.SshkeyMap """ - return SshkeyUpdateGetter().GetUpdates(self, self.conf['sshkey_url'], - since) + return SshkeyUpdateGetter().GetUpdates(self, self.conf["sshkey_url"], since) class UpdateGetter(object): @@ -222,7 +216,7 @@ def FromTimestampToHttp(self, ts): HTTP format timestamp string """ ts = time.gmtime(ts) - return time.strftime('%a, %d %b %Y %H:%M:%S GMT', ts) + return time.strftime("%a, %d %b %Y %H:%M:%S GMT", ts) def FromHttpToTimestamp(self, http_ts_string): """Converts HTTP timestamp string to internal nss_cache timestamp. @@ -232,7 +226,7 @@ def FromHttpToTimestamp(self, http_ts_string): Returns: number of seconds since epoch """ - t = time.strptime(http_ts_string, '%a, %d %b %Y %H:%M:%S GMT') + t = time.strptime(http_ts_string, "%a, %d %b %Y %H:%M:%S GMT") return int(calendar.timegm(t)) def GetUpdates(self, source, url, since): @@ -250,27 +244,26 @@ def GetUpdates(self, source, url, since): ValueError: an object in the source map is malformed ConfigurationError: """ - proto = url.split(':')[0] + proto = url.split(":")[0] # Newer libcurl allow you to disable protocols there. Unfortunately # it's not in dapper or hardy. 
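         # (Validating the scheme by hand keeps the check independent of which
         # libcurl protocol options are available at build time.)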
- if proto not in ('http', 'https'): - raise error.ConfigurationError('Unsupported protocol %s' % proto) + if proto not in ("http", "https"): + raise error.ConfigurationError("Unsupported protocol %s" % proto) conn = source.conn conn.setopt(pycurl.OPT_FILETIME, 1) - conn.setopt(pycurl.ENCODING, 'bzip2, gzip') + conn.setopt(pycurl.ENCODING, "bzip2, gzip") if since is not None: conn.setopt(pycurl.TIMEVALUE, int(since)) conn.setopt(pycurl.TIMECONDITION, pycurl.TIMECONDITION_IFMODSINCE) retry_count = 0 resp_code = 500 - while retry_count < source.conf['retry_max']: + while retry_count < source.conf["retry_max"]: try: - source.log.debug('fetching %s', url) - (resp_code, headers, - body_bytes) = curl.CurlFetch(url, conn, self.log) - self.log.debug('response code: %s', resp_code) + source.log.debug("fetching %s", url) + (resp_code, headers, body_bytes) = curl.CurlFetch(url, conn, self.log) + self.log.debug("response code: %s", resp_code) finally: if resp_code < 400: # Not modified-since @@ -279,43 +272,43 @@ def GetUpdates(self, source, url, since): if resp_code == 200: break retry_count += 1 - self.log.warning('Failed connection: attempt #%s.', retry_count) - if retry_count == source.conf['retry_max']: - self.log.debug('max retries hit') - raise error.SourceUnavailable('Max retries exceeded.') - time.sleep(source.conf['retry_delay']) + self.log.warning("Failed connection: attempt #%s.", retry_count) + if retry_count == source.conf["retry_max"]: + self.log.debug("max retries hit") + raise error.SourceUnavailable("Max retries exceeded.") + time.sleep(source.conf["retry_delay"]) - headers = headers.split('\r\n') + headers = headers.split("\r\n") last_modified = conn.getinfo(pycurl.INFO_FILETIME) - self.log.debug('last modified: %s', last_modified) + self.log.debug("last modified: %s", last_modified) if last_modified == -1: for header in headers: - if header.lower().startswith('last-modified'): - self.log.debug('%s', header) - http_ts_string = header[header.find(':') + 1:].strip() + if header.lower().startswith("last-modified"): + self.log.debug("%s", header) + http_ts_string = header[header.find(":") + 1 :].strip() last_modified = self.FromHttpToTimestamp(http_ts_string) break else: - http_ts_string = '' + http_ts_string = "" else: http_ts_string = self.FromTimestampToHttp(last_modified) - self.log.debug('Last-modified is: %s', http_ts_string) + self.log.debug("Last-modified is: %s", http_ts_string) # curl (on Ubuntu hardy at least) will handle gzip, but not bzip2 try: body_bytes = bz2.decompress(body_bytes) - self.log.debug('bzip encoding found') + self.log.debug("bzip encoding found") except IOError: - self.log.debug('bzip encoding not found') + self.log.debug("bzip encoding not found") # Wrap in a stringIO so that it can be looped on by newlines in the parser - response = StringIO(body_bytes.decode('utf-8')) + response = StringIO(body_bytes.decode("utf-8")) data_map = self.GetMap(cache_info=response) if http_ts_string: http_ts = self.FromHttpToTimestamp(http_ts_string) - self.log.debug('setting last modified to: %s', http_ts) + self.log.debug("setting last modified to: %s", http_ts) data_map.SetModifyTimestamp(http_ts) return data_map diff --git a/nss_cache/sources/httpsource_test.py b/nss_cache/sources/httpsource_test.py index e1f81e62..95e86986 100644 --- a/nss_cache/sources/httpsource_test.py +++ b/nss_cache/sources/httpsource_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""An implementation of a mock http data source for nsscache.""" -__author__ = 'blaedd@google.com (David MacKinnon)' +__author__ = "blaedd@google.com (David MacKinnon)" import base64 import time @@ -38,64 +38,67 @@ class TestHttpSource(unittest.TestCase): - def setUp(self): """Initialize a basic config dict.""" super(TestHttpSource, self).setUp() self.config = { - 'passwd_url': 'PASSWD_URL', - 'shadow_url': 'SHADOW_URL', - 'group_url': 'GROUP_URL', - 'sshkey_url': 'SSHKEY_URL', - 'retry_delay': 'TEST_RETRY_DELAY', - 'retry_max': 'TEST_RETRY_MAX', - 'tls_cacertfile': 'TEST_TLS_CACERTFILE', - 'http_proxy': 'HTTP_PROXY', + "passwd_url": "PASSWD_URL", + "shadow_url": "SHADOW_URL", + "group_url": "GROUP_URL", + "sshkey_url": "SSHKEY_URL", + "retry_delay": "TEST_RETRY_DELAY", + "retry_max": "TEST_RETRY_MAX", + "tls_cacertfile": "TEST_TLS_CACERTFILE", + "http_proxy": "HTTP_PROXY", } def testDefaultConfiguration(self): source = httpsource.HttpFilesSource({}) - self.assertEqual(source.conf['passwd_url'], - httpsource.HttpFilesSource.PASSWD_URL) - self.assertEqual(source.conf['shadow_url'], - httpsource.HttpFilesSource.SHADOW_URL) - self.assertEqual(source.conf['group_url'], - httpsource.HttpFilesSource.GROUP_URL) - self.assertEqual(source.conf['sshkey_url'], - httpsource.HttpFilesSource.SSHKEY_URL) - self.assertEqual(source.conf['retry_max'], - httpsource.HttpFilesSource.RETRY_MAX) - self.assertEqual(source.conf['retry_delay'], - httpsource.HttpFilesSource.RETRY_DELAY) - self.assertEqual(source.conf['tls_cacertfile'], - httpsource.HttpFilesSource.TLS_CACERTFILE) - self.assertEqual(source.conf['http_proxy'], None) + self.assertEqual( + source.conf["passwd_url"], httpsource.HttpFilesSource.PASSWD_URL + ) + self.assertEqual( + source.conf["shadow_url"], httpsource.HttpFilesSource.SHADOW_URL + ) + self.assertEqual(source.conf["group_url"], httpsource.HttpFilesSource.GROUP_URL) + self.assertEqual( + source.conf["sshkey_url"], httpsource.HttpFilesSource.SSHKEY_URL + ) + self.assertEqual(source.conf["retry_max"], httpsource.HttpFilesSource.RETRY_MAX) + self.assertEqual( + source.conf["retry_delay"], httpsource.HttpFilesSource.RETRY_DELAY + ) + self.assertEqual( + source.conf["tls_cacertfile"], httpsource.HttpFilesSource.TLS_CACERTFILE + ) + self.assertEqual(source.conf["http_proxy"], None) def testOverrideDefaultConfiguration(self): source = httpsource.HttpFilesSource(self.config) - self.assertEqual(source.conf['passwd_url'], 'PASSWD_URL') - self.assertEqual(source.conf['group_url'], 'GROUP_URL') - self.assertEqual(source.conf['shadow_url'], 'SHADOW_URL') - self.assertEqual(source.conf['sshkey_url'], 'SSHKEY_URL') - self.assertEqual(source.conf['retry_delay'], 'TEST_RETRY_DELAY') - self.assertEqual(source.conf['retry_max'], 'TEST_RETRY_MAX') - self.assertEqual(source.conf['tls_cacertfile'], 'TEST_TLS_CACERTFILE') - self.assertEqual(source.conf['http_proxy'], 'HTTP_PROXY') + self.assertEqual(source.conf["passwd_url"], "PASSWD_URL") + self.assertEqual(source.conf["group_url"], "GROUP_URL") + self.assertEqual(source.conf["shadow_url"], "SHADOW_URL") + self.assertEqual(source.conf["sshkey_url"], "SSHKEY_URL") + self.assertEqual(source.conf["retry_delay"], "TEST_RETRY_DELAY") + self.assertEqual(source.conf["retry_max"], "TEST_RETRY_MAX") + self.assertEqual(source.conf["tls_cacertfile"], "TEST_TLS_CACERTFILE") + self.assertEqual(source.conf["http_proxy"], "HTTP_PROXY") class TestHttpUpdateGetter(mox.MoxTestBase): - def testFromTimestampToHttp(self): ts = 1259641025 - expected_http_ts = 'Tue, 01 Dec 2009 
04:17:05 GMT' - self.assertEqual(expected_http_ts, - httpsource.UpdateGetter().FromTimestampToHttp(ts)) + expected_http_ts = "Tue, 01 Dec 2009 04:17:05 GMT" + self.assertEqual( + expected_http_ts, httpsource.UpdateGetter().FromTimestampToHttp(ts) + ) def testFromHttpToTimestamp(self): expected_ts = 1259641025 - http_ts = 'Tue, 01 Dec 2009 04:17:05 GMT' - self.assertEqual(expected_ts, - httpsource.UpdateGetter().FromHttpToTimestamp(http_ts)) + http_ts = "Tue, 01 Dec 2009 04:17:05 GMT" + self.assertEqual( + expected_ts, httpsource.UpdateGetter().FromHttpToTimestamp(http_ts) + ) def testAcceptHttpProtocol(self): mock_conn = self.mox.CreateMockAnything() @@ -104,14 +107,13 @@ def testAcceptHttpProtocol(self): # We use code 304 since it basically shortcuts to the end of the method. mock_conn.getinfo(pycurl.RESPONSE_CODE).AndReturn(304) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = httpsource.UpdateGetter().GetUpdates(source, 'http://TEST_URL', - None) + result = httpsource.UpdateGetter().GetUpdates(source, "http://TEST_URL", None) self.assertEqual([], result) def testAcceptHttpsProtocol(self): @@ -121,14 +123,13 @@ def testAcceptHttpsProtocol(self): # We use code 304 since it basically shortcuts to the end of the method. mock_conn.getinfo(pycurl.RESPONSE_CODE).AndReturn(304) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = httpsource.UpdateGetter().GetUpdates(source, - 'https://TEST_URL', None) + result = httpsource.UpdateGetter().GetUpdates(source, "https://TEST_URL", None) self.assertEqual([], result) def testRaiseConfigurationErrorOnUnsupportedProtocol(self): @@ -136,14 +137,18 @@ def testRaiseConfigurationErrorOnUnsupportedProtocol(self): mock_conn = self.mox.CreateMockAnything() mock_conn.setopt(mox.IgnoreArg(), mox.IgnoreArg()).MultipleTimes() - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) self.mox.ReplayAll() source = httpsource.HttpFilesSource({}) - self.assertRaises(error.ConfigurationError, - httpsource.UpdateGetter().GetUpdates, source, - 'ftp://test_url', None) + self.assertRaises( + error.ConfigurationError, + httpsource.UpdateGetter().GetUpdates, + source, + "ftp://test_url", + None, + ) def testNoUpdatesForTemporaryFailure(self): mock_conn = self.mox.CreateMockAnything() @@ -151,14 +156,13 @@ def testNoUpdatesForTemporaryFailure(self): mock_conn.perform() mock_conn.getinfo(pycurl.RESPONSE_CODE).AndReturn(304) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = httpsource.UpdateGetter().GetUpdates(source, - 'https://TEST_URL', 37) + result = httpsource.UpdateGetter().GetUpdates(source, "https://TEST_URL", 37) self.assertEqual(result, []) def testGetUpdatesIfTimestampNotMatch(self): @@ -170,20 +174,20 @@ def testGetUpdatesIfTimestampNotMatch(self): mock_conn.getinfo(pycurl.RESPONSE_CODE).AndReturn(200) mock_conn.getinfo(pycurl.INFO_FILETIME).AndReturn(ts) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) mock_map = self.mox.CreateMockAnything() 
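        # The getter should stamp the parsed map with the filetime reported by
        # curl; the expectation recorded on the mock below verifies that.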
mock_map.SetModifyTimestamp(ts) getter = httpsource.UpdateGetter() - self.mox.StubOutWithMock(getter, 'GetMap') + self.mox.StubOutWithMock(getter, "GetMap") getter.GetMap(cache_info=mox.IgnoreArg()).AndReturn(mock_map) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = getter.GetUpdates(source, 'https://TEST_URL', 1) + result = getter.GetUpdates(source, "https://TEST_URL", 1) self.assertEqual(mock_map, result) def testGetUpdatesWithoutTimestamp(self): @@ -193,47 +197,48 @@ def testGetUpdatesWithoutTimestamp(self): mock_conn.getinfo(pycurl.RESPONSE_CODE).AndReturn(200) mock_conn.getinfo(pycurl.INFO_FILETIME).AndReturn(-1) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) mock_map = self.mox.CreateMockAnything() getter = httpsource.UpdateGetter() - self.mox.StubOutWithMock(getter, 'GetMap') + self.mox.StubOutWithMock(getter, "GetMap") getter.GetMap(cache_info=mox.IgnoreArg()).AndReturn(mock_map) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = getter.GetUpdates(source, 'https://TEST_URL', 1) + result = getter.GetUpdates(source, "https://TEST_URL", 1) self.assertEqual(mock_map, result) def testRetryOnErrorCodeResponse(self): - config = {'retry_delay': 5, 'retry_max': 3} + config = {"retry_delay": 5, "retry_max": 3} mock_conn = self.mox.CreateMockAnything() mock_conn.setopt(mox.IgnoreArg(), mox.IgnoreArg()).MultipleTimes() mock_conn.perform().MultipleTimes() mock_conn.getinfo(pycurl.RESPONSE_CODE).MultipleTimes().AndReturn(400) - self.mox.StubOutWithMock(time, 'sleep') + self.mox.StubOutWithMock(time, "sleep") time.sleep(5) time.sleep(5) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) self.mox.ReplayAll() source = httpsource.HttpFilesSource(config) - self.assertRaises(error.SourceUnavailable, - httpsource.UpdateGetter().GetUpdates, - source, - url='https://TEST_URL', - since=None) + self.assertRaises( + error.SourceUnavailable, + httpsource.UpdateGetter().GetUpdates, + source, + url="https://TEST_URL", + since=None, + ) class TestPasswdUpdateGetter(unittest.TestCase): - def setUp(self): super(TestPasswdUpdateGetter, self).setUp() self.updater = httpsource.PasswdUpdateGetter() @@ -241,15 +246,14 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesPasswdMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesPasswdMapParser) + ) def testCreateMap(self): self.assertTrue(isinstance(self.updater.CreateMap(), passwd.PasswdMap)) class TestShadowUpdateGetter(mox.MoxTestBase): - def setUp(self): super(TestShadowUpdateGetter, self).setUp() self.updater = httpsource.ShadowUpdateGetter() @@ -257,8 +261,8 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesShadowMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesShadowMapParser) + ) def testCreateMap(self): self.assertTrue(isinstance(self.updater.CreateMap(), shadow.ShadowMap)) @@ -268,23 +272,27 @@ def testShadowGetUpdatesWithContent(self): mock_conn.setopt(mox.IgnoreArg(), mox.IgnoreArg()).MultipleTimes() mock_conn.getinfo(pycurl.INFO_FILETIME).AndReturn(-1) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) - 
self.mox.StubOutWithMock(curl, 'CurlFetch') + self.mox.StubOutWithMock(curl, "CurlFetch") - curl.CurlFetch('https://TEST_URL', mock_conn, - self.updater.log).AndReturn([ - 200, "", - BytesIO(b"""usera:x::::::: + curl.CurlFetch("https://TEST_URL", mock_conn, self.updater.log).AndReturn( + [ + 200, + "", + BytesIO( + b"""usera:x::::::: userb:x::::::: -""").getvalue() - ]) +""" + ).getvalue(), + ] + ) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = self.updater.GetUpdates(source, 'https://TEST_URL', 1) + result = self.updater.GetUpdates(source, "https://TEST_URL", 1) print(result) self.assertEqual(len(result), 2) @@ -293,31 +301,32 @@ def testShadowGetUpdatesWithBz2Content(self): mock_conn.setopt(mox.IgnoreArg(), mox.IgnoreArg()).MultipleTimes() mock_conn.getinfo(pycurl.INFO_FILETIME).AndReturn(-1) - self.mox.StubOutWithMock(pycurl, 'Curl') + self.mox.StubOutWithMock(pycurl, "Curl") pycurl.Curl().AndReturn(mock_conn) - self.mox.StubOutWithMock(curl, 'CurlFetch') + self.mox.StubOutWithMock(curl, "CurlFetch") - curl.CurlFetch( - 'https://TEST_URL', mock_conn, self.updater.log - ).AndReturn([ - 200, "", - BytesIO( - base64.b64decode( - "QlpoOTFBWSZTWfm+rXYAAAvJgAgQABAyABpAIAAhKm1GMoQAwRSpHIXejGQgz4u5IpwoSHzfVrsA" - )).getvalue() - ]) + curl.CurlFetch("https://TEST_URL", mock_conn, self.updater.log).AndReturn( + [ + 200, + "", + BytesIO( + base64.b64decode( + "QlpoOTFBWSZTWfm+rXYAAAvJgAgQABAyABpAIAAhKm1GMoQAwRSpHIXejGQgz4u5IpwoSHzfVrsA" + ) + ).getvalue(), + ] + ) self.mox.ReplayAll() config = {} source = httpsource.HttpFilesSource(config) - result = self.updater.GetUpdates(source, 'https://TEST_URL', 1) + result = self.updater.GetUpdates(source, "https://TEST_URL", 1) print(result) self.assertEqual(len(result), 2) class TestGroupUpdateGetter(unittest.TestCase): - def setUp(self): super(TestGroupUpdateGetter, self).setUp() self.updater = httpsource.GroupUpdateGetter() @@ -325,15 +334,14 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesGroupMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesGroupMapParser) + ) def testCreateMap(self): self.assertTrue(isinstance(self.updater.CreateMap(), group.GroupMap)) class TestNetgroupUpdateGetter(unittest.TestCase): - def setUp(self): super(TestNetgroupUpdateGetter, self).setUp() self.updater = httpsource.NetgroupUpdateGetter() @@ -341,16 +349,14 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesNetgroupMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesNetgroupMapParser) + ) def testCreateMap(self): - self.assertTrue( - isinstance(self.updater.CreateMap(), netgroup.NetgroupMap)) + self.assertTrue(isinstance(self.updater.CreateMap(), netgroup.NetgroupMap)) class TestAutomountUpdateGetter(unittest.TestCase): - def setUp(self): super(TestAutomountUpdateGetter, self).setUp() self.updater = httpsource.AutomountUpdateGetter() @@ -358,16 +364,14 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesAutomountMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesAutomountMapParser) + ) def testCreateMap(self): - self.assertTrue( - isinstance(self.updater.CreateMap(), automount.AutomountMap)) + self.assertTrue(isinstance(self.updater.CreateMap(), automount.AutomountMap)) 
class TestSshkeyUpdateGetter(unittest.TestCase): - def setUp(self): super(TestSshkeyUpdateGetter, self).setUp() self.updater = httpsource.SshkeyUpdateGetter() @@ -375,12 +379,12 @@ def setUp(self): def testGetParser(self): parser = self.updater.GetParser() self.assertTrue( - isinstance(self.updater.GetParser(), - file_formats.FilesSshkeyMapParser)) + isinstance(self.updater.GetParser(), file_formats.FilesSshkeyMapParser) + ) def testCreateMap(self): self.assertTrue(isinstance(self.updater.CreateMap(), sshkey.SshkeyMap)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/ldapsource.py b/nss_cache/sources/ldapsource.py index d689451b..8a722dc1 100644 --- a/nss_cache/sources/ldapsource.py +++ b/nss_cache/sources/ldapsource.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """An implementation of an ldap data source for nsscache.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import calendar import logging @@ -37,10 +39,10 @@ from nss_cache.maps import sshkey from nss_cache.sources import source -IS_LDAP24_OR_NEWER = version.parse(ldap.__version__) >= version.parse('2.4') +IS_LDAP24_OR_NEWER = version.parse(ldap.__version__) >= version.parse("2.4") # ldap.LDAP_CONTROL_PAGE_OID is unavailable on some systems, so we define it here -LDAP_CONTROL_PAGE_OID = '1.2.840.113556.1.4.319' +LDAP_CONTROL_PAGE_OID = "1.2.840.113556.1.4.319" def RegisterImplementation(registration_callback): @@ -51,12 +53,11 @@ def makeSimplePagedResultsControl(page_size): # The API for this is different on older versions of python-ldap, so we need # to handle this case. 
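     # python-ldap >= 2.4 takes keyword arguments and a criticality flag, while
     # older releases expect the control OID plus a (size, cookie) tuple.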
if IS_LDAP24_OR_NEWER: - return ldap.controls.SimplePagedResultsControl(True, - size=page_size, - cookie='') + return ldap.controls.SimplePagedResultsControl(True, size=page_size, cookie="") else: - return ldap.controls.SimplePagedResultsControl(LDAP_CONTROL_PAGE_OID, - True, (page_size, '')) + return ldap.controls.SimplePagedResultsControl( + LDAP_CONTROL_PAGE_OID, True, (page_size, "") + ) def getCookieFromControl(pctrl): @@ -87,7 +88,7 @@ def sidToStr(sid): https://ldap3.readthedocs.io/_modules/ldap3/protocol/formatters/formatters.html#format_sid """ try: - if sid.startswith(b'S-1') or sid.startswith('S-1'): + if sid.startswith(b"S-1") or sid.startswith("S-1"): return sid except Exception: pass @@ -95,16 +96,20 @@ def sidToStr(sid): if str is not bytes: revision = int(sid[0]) sub_authorities = int(sid[1]) - identifier_authority = int.from_bytes(sid[2:8], byteorder='big') + identifier_authority = int.from_bytes(sid[2:8], byteorder="big") if identifier_authority >= 2**32: identifier_authority = hex(identifier_authority) - sub_authority = '-' + '-'.join([ - str( - int.from_bytes(sid[8 + (i * 4):12 + (i * 4)], - byteorder='little')) - for i in range(sub_authorities) - ]) + sub_authority = "-" + "-".join( + [ + str( + int.from_bytes( + sid[8 + (i * 4) : 12 + (i * 4)], byteorder="little" + ) + ) + for i in range(sub_authorities) + ] + ) else: revision = int(b2a_hex(sid[0])) sub_authorities = int(b2a_hex(sid[1])) @@ -112,12 +117,15 @@ def sidToStr(sid): if identifier_authority >= 2**32: identifier_authority = hex(identifier_authority) - sub_authority = '-' + '-'.join([ - str(int(b2a_hex(sid[11 + (i * 4):7 + (i * 4):-1]), 16)) - for i in range(sub_authorities) - ]) - objectSid = 'S-' + str(revision) + '-' + str( - identifier_authority) + sub_authority + sub_authority = "-" + "-".join( + [ + str(int(b2a_hex(sid[11 + (i * 4) : 7 + (i * 4) : -1]), 16)) + for i in range(sub_authorities) + ] + ) + objectSid = ( + "S-" + str(revision) + "-" + str(identifier_authority) + sub_authority + ) return objectSid except Exception: @@ -136,17 +144,18 @@ class LdapSource(source.Source): 'objects' in this sense means some structured blob of data, not a Python object. """ + # ldap defaults - BIND_DN = '' - BIND_PASSWORD = '' + BIND_DN = "" + BIND_PASSWORD = "" RETRY_DELAY = 5 RETRY_MAX = 3 - SCOPE = 'one' + SCOPE = "one" TIMELIMIT = -1 - TLS_REQUIRE_CERT = 'demand' # one of never, hard, demand, allow, try + TLS_REQUIRE_CERT = "demand" # one of never, hard, demand, allow, try # for registration - name = 'ldap' + name = "ldap" # Page size for paged LDAP requests # Value chosen based on default Active Directory MaxPageSize @@ -173,13 +182,15 @@ def __init__(self, conf, conn=None): # ReconnectLDAPObject should handle interrupted ldap transactions. # also, ugh rlo = ldap.ldapobject.ReconnectLDAPObject - self.conn = rlo(uri=conf['uri'], - retry_max=conf['retry_max'], - retry_delay=conf['retry_delay']) - if conf['tls_starttls'] == 1: + self.conn = rlo( + uri=conf["uri"], + retry_max=conf["retry_max"], + retry_delay=conf["retry_delay"], + ) + if conf["tls_starttls"] == 1: self.conn.start_tls_s() - if 'ldap_debug' in conf: - self.conn.set_option(ldap.OPT_DEBUG_LEVEL, conf['ldap_debug']) + if "ldap_debug" in conf: + self.conn.set_option(ldap.OPT_DEBUG_LEVEL, conf["ldap_debug"]) else: self.conn = conn @@ -190,69 +201,63 @@ def __init__(self, conf, conn=None): def _SetDefaults(self, configuration): """Set defaults if necessary.""" # LDAPI URLs must be url escaped socket filenames; rewrite if necessary. 
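         # e.g. ldapi:///var/run/slapd/ldapi is rewritten to
         # ldapi://%2Fvar%2Frun%2Fslapd%2Fldapi before it reaches libldap.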
- if 'uri' in configuration: - if configuration['uri'].startswith('ldapi://'): - configuration['uri'] = 'ldapi://' + quote( - configuration['uri'][8:], '') - if 'bind_dn' not in configuration: - configuration['bind_dn'] = self.BIND_DN - if 'bind_password' not in configuration: - configuration['bind_password'] = self.BIND_PASSWORD - if 'retry_delay' not in configuration: - configuration['retry_delay'] = self.RETRY_DELAY - if 'retry_max' not in configuration: - configuration['retry_max'] = self.RETRY_MAX - if 'scope' not in configuration: - configuration['scope'] = self.SCOPE - if 'timelimit' not in configuration: - configuration['timelimit'] = self.TIMELIMIT + if "uri" in configuration: + if configuration["uri"].startswith("ldapi://"): + configuration["uri"] = "ldapi://" + quote(configuration["uri"][8:], "") + if "bind_dn" not in configuration: + configuration["bind_dn"] = self.BIND_DN + if "bind_password" not in configuration: + configuration["bind_password"] = self.BIND_PASSWORD + if "retry_delay" not in configuration: + configuration["retry_delay"] = self.RETRY_DELAY + if "retry_max" not in configuration: + configuration["retry_max"] = self.RETRY_MAX + if "scope" not in configuration: + configuration["scope"] = self.SCOPE + if "timelimit" not in configuration: + configuration["timelimit"] = self.TIMELIMIT # TODO(jaq): XXX EVIL. ldap client libraries change behaviour if we use # polling, and it's nasty. So don't let the user poll. - if configuration['timelimit'] == 0: - configuration['timelimit'] = -1 - if 'tls_require_cert' not in configuration: - configuration['tls_require_cert'] = self.TLS_REQUIRE_CERT - if 'tls_starttls' not in configuration: - configuration['tls_starttls'] = 0 + if configuration["timelimit"] == 0: + configuration["timelimit"] = -1 + if "tls_require_cert" not in configuration: + configuration["tls_require_cert"] = self.TLS_REQUIRE_CERT + if "tls_starttls" not in configuration: + configuration["tls_starttls"] = 0 # Translate tls_require into appropriate constant, if necessary. - if configuration['tls_require_cert'] == 'never': - configuration['tls_require_cert'] = ldap.OPT_X_TLS_NEVER - elif configuration['tls_require_cert'] == 'hard': - configuration['tls_require_cert'] = ldap.OPT_X_TLS_HARD - elif configuration['tls_require_cert'] == 'demand': - configuration['tls_require_cert'] = ldap.OPT_X_TLS_DEMAND - elif configuration['tls_require_cert'] == 'allow': - configuration['tls_require_cert'] = ldap.OPT_X_TLS_ALLOW - elif configuration['tls_require_cert'] == 'try': - configuration['tls_require_cert'] = ldap.OPT_X_TLS_TRY - - if 'sasl_authzid' not in configuration: - configuration['sasl_authzid'] = '' + if configuration["tls_require_cert"] == "never": + configuration["tls_require_cert"] = ldap.OPT_X_TLS_NEVER + elif configuration["tls_require_cert"] == "hard": + configuration["tls_require_cert"] = ldap.OPT_X_TLS_HARD + elif configuration["tls_require_cert"] == "demand": + configuration["tls_require_cert"] = ldap.OPT_X_TLS_DEMAND + elif configuration["tls_require_cert"] == "allow": + configuration["tls_require_cert"] = ldap.OPT_X_TLS_ALLOW + elif configuration["tls_require_cert"] == "try": + configuration["tls_require_cert"] = ldap.OPT_X_TLS_TRY + + if "sasl_authzid" not in configuration: + configuration["sasl_authzid"] = "" # Should we issue STARTTLS? 
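         # Accept 1, "1", "on", "yes" and "true" as true; any other value
         # disables STARTTLS.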
- if configuration['tls_starttls'] in (1, '1', 'on', 'yes', 'true'): - configuration['tls_starttls'] = 1 + if configuration["tls_starttls"] in (1, "1", "on", "yes", "true"): + configuration["tls_starttls"] = 1 # if not configuration['tls_starttls']: else: - configuration['tls_starttls'] = 0 + configuration["tls_starttls"] = 0 # Setting global ldap defaults. - ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, - configuration['tls_require_cert']) + ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, configuration["tls_require_cert"]) ldap.set_option(ldap.OPT_REFERRALS, 0) - if 'tls_cacertdir' in configuration: - ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, - configuration['tls_cacertdir']) - if 'tls_cacertfile' in configuration: - ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, - configuration['tls_cacertfile']) - if 'tls_certfile' in configuration: - ldap.set_option(ldap.OPT_X_TLS_CERTFILE, - configuration['tls_certfile']) - if 'tls_keyfile' in configuration: - ldap.set_option(ldap.OPT_X_TLS_KEYFILE, - configuration['tls_keyfile']) + if "tls_cacertdir" in configuration: + ldap.set_option(ldap.OPT_X_TLS_CACERTDIR, configuration["tls_cacertdir"]) + if "tls_cacertfile" in configuration: + ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, configuration["tls_cacertfile"]) + if "tls_certfile" in configuration: + ldap.set_option(ldap.OPT_X_TLS_CERTFILE, configuration["tls_certfile"]) + if "tls_keyfile" in configuration: + ldap.set_option(ldap.OPT_X_TLS_KEYFILE, configuration["tls_keyfile"]) ldap.version = ldap.VERSION3 # this is hard-coded, we only support V3 def _SetCookie(self, cookie): @@ -263,37 +268,38 @@ def Bind(self, configuration): # If the server is unavailable, we are going to find out now, as this # actually initiates the network connection. retry_count = 0 - while retry_count < configuration['retry_max']: - self.log.debug('opening ldap connection and binding to %s', - configuration['uri']) + while retry_count < configuration["retry_max"]: + self.log.debug( + "opening ldap connection and binding to %s", configuration["uri"] + ) try: - if 'use_sasl' in configuration and configuration['use_sasl']: - if ('sasl_mech' in configuration and - configuration['sasl_mech'] and - configuration['sasl_mech'].lower() == 'gssapi'): - sasl = ldap.sasl.gssapi(configuration['sasl_authzid']) + if "use_sasl" in configuration and configuration["use_sasl"]: + if ( + "sasl_mech" in configuration + and configuration["sasl_mech"] + and configuration["sasl_mech"].lower() == "gssapi" + ): + sasl = ldap.sasl.gssapi(configuration["sasl_authzid"]) # TODO: Add other sasl mechs else: - raise error.ConfigurationError( - 'SASL mechanism not supported') + raise error.ConfigurationError("SASL mechanism not supported") - self.conn.sasl_interactive_bind_s('', sasl) + self.conn.sasl_interactive_bind_s("", sasl) else: - self.conn.simple_bind_s(who=configuration['bind_dn'], - cred=str( - configuration['bind_password'])) + self.conn.simple_bind_s( + who=configuration["bind_dn"], + cred=str(configuration["bind_password"]), + ) break except ldap.SERVER_DOWN as e: retry_count += 1 - self.log.warning('Failed LDAP connection: attempt #%s.', - retry_count) - self.log.debug('ldap error is %r', e) - if retry_count == configuration['retry_max']: - self.log.debug('max retries hit') + self.log.warning("Failed LDAP connection: attempt #%s.", retry_count) + self.log.debug("ldap error is %r", e) + if retry_count == configuration["retry_max"]: + self.log.debug("max retries hit") raise error.SourceUnavailable(e) - self.log.debug('sleeping %d seconds', - 
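         # The bind retry loop below honours the same retry_max/retry_delay
         # settings used by the other sources.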
configuration['retry_delay']) - time.sleep(configuration['retry_delay']) + self.log.debug("sleeping %d seconds", configuration["retry_delay"]) + time.sleep(configuration["retry_delay"]) def _ReSearch(self): """Performs self.Search again with the previously used parameters. @@ -318,18 +324,24 @@ def Search(self, search_base, search_filter, search_scope, attrs): Returns: nothing. """ - self._last_search_params = (search_base, search_filter, search_scope, - attrs) - - self.log.debug('searching for base=%r, filter=%r, scope=%r, attrs=%r', - search_base, search_filter, search_scope, attrs) - if 'dn' in attrs: # special cased attribute + self._last_search_params = (search_base, search_filter, search_scope, attrs) + + self.log.debug( + "searching for base=%r, filter=%r, scope=%r, attrs=%r", + search_base, + search_filter, + search_scope, + attrs, + ) + if "dn" in attrs: # special cased attribute self._dn_requested = True - self.message_id = self.conn.search_ext(base=search_base, - filterstr=search_filter, - scope=search_scope, - attrlist=attrs, - serverctrls=[self.ldap_controls]) + self.message_id = self.conn.search_ext( + base=search_base, + filterstr=search_filter, + scope=search_scope, + attrlist=attrs, + serverctrls=[self.ldap_controls], + ) def __iter__(self): """Iterate over the data from the last search. @@ -344,10 +356,11 @@ def __iter__(self): result_type, data = None, None timeout_retries = 0 - while timeout_retries < int(self._conf['retry_max']): + while timeout_retries < int(self._conf["retry_max"]): try: result_type, data, _, serverctrls = self.conn.result3( - self.message_id, all=0, timeout=self.conf['timelimit']) + self.message_id, all=0, timeout=self.conf["timelimit"] + ) # we need to filter out AD referrals if data and not data[0][0]: continue @@ -357,13 +370,15 @@ def __iter__(self): if len(serverctrls) > 0: # Search for appropriate control simple_paged_results_controls = [ - control for control in serverctrls + control + for control in serverctrls if control.controlType == LDAP_CONTROL_PAGE_OID ] if simple_paged_results_controls: # We only expect one control; just take the first in the list. cookie = getCookieFromControl( - simple_paged_results_controls[0]) + simple_paged_results_controls[0] + ) if len(cookie) > 0: # If cookie is non-empty, call search_ext and result3 again @@ -372,40 +387,43 @@ def __iter__(self): result_type, data, _, serverctrls = self.conn.result3( self.message_id, all=0, - timeout=self.conf['timelimit']) + timeout=self.conf["timelimit"], + ) # else: An empty cookie means we are done. # break loop once result3 doesn't time out and reset cookie - setCookieOnControl(self.ldap_controls, '', self.PAGE_SIZE) + setCookieOnControl(self.ldap_controls, "", self.PAGE_SIZE) break except ldap.SIZELIMIT_EXCEEDED: self.log.warning( - 'LDAP server size limit exceeded; using page size {0}.'. 
- format(self.PAGE_SIZE)) + "LDAP server size limit exceeded; using page size {0}.".format( + self.PAGE_SIZE + ) + ) return except ldap.NO_SUCH_OBJECT: - self.log.debug('Returning due to ldap.NO_SUCH_OBJECT') + self.log.debug("Returning due to ldap.NO_SUCH_OBJECT") return except ldap.TIMELIMIT_EXCEEDED: timeout_retries += 1 - self.log.warning('Timeout on LDAP results, attempt #%s.', - timeout_retries) - if timeout_retries >= self._conf['retry_max']: - self.log.debug('max retries hit, returning') + self.log.warning( + "Timeout on LDAP results, attempt #%s.", timeout_retries + ) + if timeout_retries >= self._conf["retry_max"]: + self.log.debug("max retries hit, returning") return - self.log.debug('sleeping %d seconds', - self._conf['retry_delay']) - time.sleep(self.conf['retry_delay']) + self.log.debug("sleeping %d seconds", self._conf["retry_delay"]) + time.sleep(self.conf["retry_delay"]) if result_type == ldap.RES_SEARCH_RESULT: - self.log.debug('Returning due to RES_SEARCH_RESULT') + self.log.debug("Returning due to RES_SEARCH_RESULT") return if result_type != ldap.RES_SEARCH_ENTRY: - self.log.info('Unknown result type %r, ignoring.', result_type) + self.log.info("Unknown result type %r, ignoring.", result_type) if not data: - self.log.debug('Returning due to len(data) == 0') + self.log.debug("Returning due to len(data) == 0") return for record in data: @@ -413,12 +431,11 @@ def __iter__(self): # otherwise ignore it. for key in record[1]: for i in range(len(record[1][key])): - if isinstance(record[1][key][i], - bytes) and key != 'objectSid': - value = record[1][key][i].decode('utf-8') + if isinstance(record[1][key][i], bytes) and key != "objectSid": + value = record[1][key][i].decode("utf-8") record[1][key][i] = value if self._dn_requested: - merged_records = {'dn': record[0]} + merged_records = {"dn": record[0]} merged_records.update(record[1]) yield merged_records else: @@ -436,10 +453,11 @@ def GetSshkeyMap(self, since=None): """ return SshkeyUpdateGetter(self.conf).GetUpdates( source=self, - search_base=self.conf['base'], - search_filter=self.conf['filter'], - search_scope=self.conf['scope'], - since=since) + search_base=self.conf["base"], + search_filter=self.conf["filter"], + search_scope=self.conf["scope"], + since=since, + ) def GetPasswdMap(self, since=None): """Return the passwd map from this source. @@ -453,10 +471,11 @@ def GetPasswdMap(self, since=None): """ return PasswdUpdateGetter(self.conf).GetUpdates( source=self, - search_base=self.conf['base'], - search_filter=self.conf['filter'], - search_scope=self.conf['scope'], - since=since) + search_base=self.conf["base"], + search_filter=self.conf["filter"], + search_scope=self.conf["scope"], + since=since, + ) def GetGroupMap(self, since=None): """Return the group map from this source. @@ -470,10 +489,11 @@ def GetGroupMap(self, since=None): """ return GroupUpdateGetter(self.conf).GetUpdates( source=self, - search_base=self.conf['base'], - search_filter=self.conf['filter'], - search_scope=self.conf['scope'], - since=since) + search_base=self.conf["base"], + search_filter=self.conf["filter"], + search_scope=self.conf["scope"], + since=since, + ) def GetShadowMap(self, since=None): """Return the shadow map from this source. 
@@ -487,10 +507,11 @@ def GetShadowMap(self, since=None): """ return ShadowUpdateGetter(self.conf).GetUpdates( source=self, - search_base=self.conf['base'], - search_filter=self.conf['filter'], - search_scope=self.conf['scope'], - since=since) + search_base=self.conf["base"], + search_filter=self.conf["filter"], + search_scope=self.conf["scope"], + since=since, + ) def GetNetgroupMap(self, since=None): """Return the netgroup map from this source. @@ -504,10 +525,11 @@ def GetNetgroupMap(self, since=None): """ return NetgroupUpdateGetter(self.conf).GetUpdates( source=self, - search_base=self.conf['base'], - search_filter=self.conf['filter'], - search_scope=self.conf['scope'], - since=since) + search_base=self.conf["base"], + search_filter=self.conf["filter"], + search_scope=self.conf["scope"], + since=since, + ) def GetAutomountMap(self, since=None, location=None): """Return an automount map from this source. @@ -526,17 +548,17 @@ def GetAutomountMap(self, since=None, location=None): instance of AutomountMap """ if location is None: - self.log.error( - 'A location is required to retrieve an automount map!') + self.log.error("A location is required to retrieve an automount map!") raise error.EmptyMap - autofs_filter = '(objectclass=automount)' + autofs_filter = "(objectclass=automount)" return AutomountUpdateGetter(self.conf).GetUpdates( source=self, search_base=location, search_filter=autofs_filter, - search_scope='one', - since=since) + search_scope="one", + since=since, + ) def GetAutomountMasterMap(self): """Return the autmount master map from this source. @@ -549,27 +571,29 @@ def GetAutomountMasterMap(self): Returns: an instance of maps.AutomountMap """ - search_base = self.conf['base'] + search_base = self.conf["base"] search_scope = ldap.SCOPE_SUBTREE # auto.master is stored under ou=auto.master with objectclass=automountMap - search_filter = '(&(objectclass=automountMap)(ou=auto.master))' - self.log.debug('retrieving automount master map.') - self.Search(search_base=search_base, - search_filter=search_filter, - search_scope=search_scope, - attrs=['dn']) + search_filter = "(&(objectclass=automountMap)(ou=auto.master))" + self.log.debug("retrieving automount master map.") + self.Search( + search_base=search_base, + search_filter=search_filter, + search_scope=search_scope, + attrs=["dn"], + ) search_base = None for obj in self: # the dn of the matched object is our search base - search_base = obj['dn'] + search_base = obj["dn"] if search_base is None: - self.log.critical('Could not find automount master map!') + self.log.critical("Could not find automount master map!") raise error.EmptyMap - self.log.debug('found ou=auto.master at %s', search_base) + self.log.debug("found ou=auto.master at %s", search_base) master_map = self.GetAutomountMap(location=search_base) # fix our location attribute to contain the data we @@ -577,10 +601,10 @@ def GetAutomountMasterMap(self): for map_entry in master_map: # we currently ignore hostname and just look for the dn which will # be the search_base for this map. third field, colon delimited. 
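             # e.g. "ldap:ldap.example.com:ou=auto.home,dc=example,dc=com -rw"
             # yields "ou=auto.home,dc=example,dc=com" after the two splits.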
- map_entry.location = map_entry.location.split(':')[2] + map_entry.location = map_entry.location.split(":")[2] # and strip the space seperated options - map_entry.location = map_entry.location.split(' ')[0] - self.log.debug('master map has: %s' % map_entry.location) + map_entry.location = map_entry.location.split(" ")[0] + self.log.debug("master map has: %s" % map_entry.location) return master_map @@ -615,22 +639,22 @@ def FromLdapToTimestamp(self, ldap_ts_string): number of seconds since epoch. """ if isinstance(ldap_ts_string, bytes): - ldap_ts_string = ldap_ts_string.decode('utf-8') + ldap_ts_string = ldap_ts_string.decode("utf-8") try: - if self.conf.get('ad'): + if self.conf.get("ad"): # AD timestamp has different format - t = time.strptime(ldap_ts_string, '%Y%m%d%H%M%S.0Z') + t = time.strptime(ldap_ts_string, "%Y%m%d%H%M%S.0Z") else: - t = time.strptime(ldap_ts_string, '%Y%m%d%H%M%SZ') + t = time.strptime(ldap_ts_string, "%Y%m%d%H%M%SZ") except ValueError: # Some systems add a decimal component; try to filter it: - m = re.match('([0-9]*)(\.[0-9]*)?(Z)', ldap_ts_string) + m = re.match(r"([0-9]*)(\.[0-9]*)?(Z)", ldap_ts_string) if m: ldap_ts_string = m.group(1) + m.group(3) - if self.conf.get('ad'): - t = time.strptime(ldap_ts_string, '%Y%m%d%H%M%S.0Z') + if self.conf.get("ad"): + t = time.strptime(ldap_ts_string, "%Y%m%d%H%M%S.0Z") else: - t = time.strptime(ldap_ts_string, '%Y%m%d%H%M%SZ') + t = time.strptime(ldap_ts_string, "%Y%m%d%H%M%SZ") return int(calendar.timegm(t)) def FromTimestampToLdap(self, ts): @@ -642,14 +666,13 @@ def FromTimestampToLdap(self, ts): Returns: LDAP format timestamp string. """ - if self.conf.get('ad'): - t = time.strftime('%Y%m%d%H%M%S.0Z', time.gmtime(ts)) + if self.conf.get("ad"): + t = time.strftime("%Y%m%d%H%M%S.0Z", time.gmtime(ts)) else: - t = time.strftime('%Y%m%d%H%M%SZ', time.gmtime(ts)) + t = time.strftime("%Y%m%d%H%M%SZ", time.gmtime(ts)) return t - def GetUpdates(self, source, search_base, search_filter, search_scope, - since): + def GetUpdates(self, source, search_base, search_filter, search_scope, since): """Get updates from a source. Args: @@ -666,39 +689,40 @@ def GetUpdates(self, source, search_base, search_filter, search_scope, error.ConfigurationError: scope is invalid ValueError: an object in the source map is malformed """ - if self.conf.get('ad'): + if self.conf.get("ad"): # AD attribute for modifyTimestamp is whenChanged - self.attrs.append('whenChanged') + self.attrs.append("whenChanged") else: - self.attrs.append('modifyTimestamp') + self.attrs.append("modifyTimestamp") if since is not None: ts = self.FromTimestampToLdap(since) # since openldap disallows modifyTimestamp "greater than" we have to # increment by one second. 
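             # e.g. if 20091201041705Z was the last timestamp seen, the filter
             # asks for modifyTimestamp>=20091201041706Z so nothing already
             # synced is fetched again.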
-            if self.conf.get('ad'):
-                ts = int(ts.rstrip('.0Z')) + 1
-                ts = '%s.0Z' % ts
-                search_filter = ('(&%s(whenChanged>=%s))' % (search_filter, ts))
+            if self.conf.get("ad"):
+                ts = int(ts.rstrip(".0Z")) + 1
+                ts = "%s.0Z" % ts
+                search_filter = "(&%s(whenChanged>=%s))" % (search_filter, ts)
             else:
-                ts = int(ts.rstrip('Z')) + 1
-                ts = '%sZ' % ts
-                search_filter = ('(&%s(modifyTimestamp>=%s))' %
-                                 (search_filter, ts))
+                ts = int(ts.rstrip("Z")) + 1
+                ts = "%sZ" % ts
+                search_filter = "(&%s(modifyTimestamp>=%s))" % (search_filter, ts)
 
-        if search_scope == 'base':
+        if search_scope == "base":
             search_scope = ldap.SCOPE_BASE
-        elif search_scope in ['one', 'onelevel']:
+        elif search_scope in ["one", "onelevel"]:
             search_scope = ldap.SCOPE_ONELEVEL
-        elif search_scope in ['sub', 'subtree']:
+        elif search_scope in ["sub", "subtree"]:
             search_scope = ldap.SCOPE_SUBTREE
         else:
-            raise error.ConfigurationError('Invalid scope: %s' % search_scope)
+            raise error.ConfigurationError("Invalid scope: %s" % search_scope)
 
-        source.Search(search_base=search_base,
-                      search_filter=search_filter,
-                      search_scope=search_scope,
-                      attrs=self.attrs)
+        source.Search(
+            search_base=search_base,
+            search_filter=search_filter,
+            search_scope=search_scope,
+            attrs=self.attrs,
+        )
 
         # Don't initialize with since, because we really want to get the
         # latest timestamp read, and if somehow a larger 'since' slips through
@@ -710,27 +734,25 @@ def GetUpdates(self, source, search_base, search_filter, search_scope,
         for obj in source:
             for field in self.essential_fields:
                 if field not in obj:
-                    logging.warn('invalid object passed: %r not in %r', field,
-                                 obj)
-                    raise ValueError('Invalid object passed: %r', obj)
+                    logging.warning("invalid object passed: %r not in %r", field, obj)
+                    raise ValueError("Invalid object passed: %r" % obj)
 
-            if self.conf.get('ad'):
-                obj_ts = self.FromLdapToTimestamp(obj['whenChanged'][0])
+            if self.conf.get("ad"):
+                obj_ts = self.FromLdapToTimestamp(obj["whenChanged"][0])
             else:
                 try:
-                    obj_ts = self.FromLdapToTimestamp(obj['modifyTimestamp'][0])
+                    obj_ts = self.FromLdapToTimestamp(obj["modifyTimestamp"][0])
                 except KeyError:
-                    obj_ts = self.FromLdapToTimestamp(obj['modifyTimeStamp'][0])
+                    obj_ts = self.FromLdapToTimestamp(obj["modifyTimeStamp"][0])
 
             if max_ts is None or obj_ts > max_ts:
                 max_ts = obj_ts
 
             try:
                 if not data_map.Add(self.Transform(obj)):
-                    logging.info('could not add obj: %r', obj)
+                    logging.info("could not add obj: %r", obj)
             except AttributeError as e:
-                logging.warning('error %r, discarding malformed obj: %r',
-                                str(e), obj)
+                logging.warning("error %r, discarding malformed obj: %r", str(e), obj)
 
         # Perform some post processing on the data_map.
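         # (e.g. resolving uniqueMember DNs to uids or expanding nested AD
         # groups, where the getter subclass implements it).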
self.PostProcess(data_map, source, search_filter, search_scope) @@ -748,26 +770,36 @@ class PasswdUpdateGetter(UpdateGetter): def __init__(self, conf): super(PasswdUpdateGetter, self).__init__(conf) - if self.conf.get('ad'): + if self.conf.get("ad"): # attributes of AD user to be returned self.attrs = [ - 'sAMAccountName', 'objectSid', 'displayName', - 'unixHomeDirectory', 'pwdLastSet', 'loginShell' + "sAMAccountName", + "objectSid", + "displayName", + "unixHomeDirectory", + "pwdLastSet", + "loginShell", ] - self.essential_fields = ['sAMAccountName', 'objectSid'] + self.essential_fields = ["sAMAccountName", "objectSid"] else: self.attrs = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'loginShell', 'fullName' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "loginShell", + "fullName", ] - if 'uidattr' in self.conf: - self.attrs.append(self.conf['uidattr']) - if 'uidregex' in self.conf: - self.uidregex = re.compile(self.conf['uidregex']) - self.essential_fields = ['uid', 'uidNumber', 'gidNumber'] - if self.conf.get('use_rid'): - self.attrs.append('sambaSID') - self.essential_fields.append('sambaSID') + if "uidattr" in self.conf: + self.attrs.append(self.conf["uidattr"]) + if "uidregex" in self.conf: + self.uidregex = re.compile(self.conf["uidregex"]) + self.essential_fields = ["uid", "uidNumber", "gidNumber"] + if self.conf.get("use_rid"): + self.attrs.append("sambaSID") + self.essential_fields.append("sambaSID") self.log = logging.getLogger(self.__class__.__name__) def CreateMap(self): @@ -781,68 +813,68 @@ def Transform(self, obj): pw = passwd.PasswdMapEntry() - if self.conf.get('ad'): - if 'displayName' in obj: - pw.gecos = obj['displayName'][0] - elif 'gecos' in obj: - pw.gecos = obj['gecos'][0] - elif 'cn' in obj: - pw.gecos = obj['cn'][0] - elif 'fullName' in obj: - pw.gecos = obj['fullName'][0] + if self.conf.get("ad"): + if "displayName" in obj: + pw.gecos = obj["displayName"][0] + elif "gecos" in obj: + pw.gecos = obj["gecos"][0] + elif "cn" in obj: + pw.gecos = obj["cn"][0] + elif "fullName" in obj: + pw.gecos = obj["fullName"][0] else: - raise ValueError('Neither gecos nor cn found') + raise ValueError("Neither gecos nor cn found") - pw.gecos = pw.gecos.replace('\n', '') + pw.gecos = pw.gecos.replace("\n", "") - if self.conf.get('ad'): - pw.name = obj['sAMAccountName'][0] - elif 'uidattr' in self.conf: - pw.name = obj[self.conf['uidattr']][0] + if self.conf.get("ad"): + pw.name = obj["sAMAccountName"][0] + elif "uidattr" in self.conf: + pw.name = obj[self.conf["uidattr"]][0] else: - pw.name = obj['uid'][0] + pw.name = obj["uid"][0] - if hasattr(self, 'uidregex'): - pw.name = ''.join([x for x in self.uidregex.findall(pw.name)]) + if hasattr(self, "uidregex"): + pw.name = "".join([x for x in self.uidregex.findall(pw.name)]) - if 'override_shell' in self.conf: - pw.shell = self.conf['override_shell'] - elif 'loginShell' in obj: - pw.shell = obj['loginShell'][0] + if "override_shell" in self.conf: + pw.shell = self.conf["override_shell"] + elif "loginShell" in obj: + pw.shell = obj["loginShell"][0] else: - pw.shell = '' + pw.shell = "" - if self.conf.get('ad'): + if self.conf.get("ad"): # use the user's RID for uid and gid to have # the correspondant group with the same name - pw.uid = int(sidToStr(obj['objectSid'][0]).split('-')[-1]) - pw.gid = int(sidToStr(obj['objectSid'][0]).split('-')[-1]) - elif self.conf.get('use_rid'): + pw.uid = int(sidToStr(obj["objectSid"][0]).split("-")[-1]) + pw.gid = 
int(sidToStr(obj["objectSid"][0]).split("-")[-1]) + elif self.conf.get("use_rid"): # use the user's RID for uid and gid to have # the correspondant group with the same name - pw.uid = int(sidToStr(obj['sambaSID'][0]).split('-')[-1]) - pw.gid = int(sidToStr(obj['sambaSID'][0]).split('-')[-1]) + pw.uid = int(sidToStr(obj["sambaSID"][0]).split("-")[-1]) + pw.gid = int(sidToStr(obj["sambaSID"][0]).split("-")[-1]) else: - pw.uid = int(obj['uidNumber'][0]) - pw.gid = int(obj['gidNumber'][0]) + pw.uid = int(obj["uidNumber"][0]) + pw.gid = int(obj["gidNumber"][0]) - if 'offset' in self.conf: + if "offset" in self.conf: # map uid and gid to higher number # to avoid conflict with local accounts - pw.uid = int(pw.uid + self.conf['offset']) - pw.gid = int(pw.gid + self.conf['offset']) - - if self.conf.get('home_dir'): - pw.dir = '/home/%s' % pw.name - elif 'unixHomeDirectory' in obj: - pw.dir = obj['unixHomeDirectory'][0] - elif 'homeDirectory' in obj: - pw.dir = obj['homeDirectory'][0] + pw.uid = int(pw.uid + self.conf["offset"]) + pw.gid = int(pw.gid + self.conf["offset"]) + + if self.conf.get("home_dir"): + pw.dir = "/home/%s" % pw.name + elif "unixHomeDirectory" in obj: + pw.dir = obj["unixHomeDirectory"][0] + elif "homeDirectory" in obj: + pw.dir = obj["homeDirectory"][0] else: - pw.dir = '' + pw.dir = "" # hack - pw.passwd = 'x' + pw.passwd = "x" return pw @@ -853,23 +885,23 @@ class GroupUpdateGetter(UpdateGetter): def __init__(self, conf): super(GroupUpdateGetter, self).__init__(conf) # TODO: Merge multiple rcf2307bis[_alt] options into a single option. - if self.conf.get('ad'): + if self.conf.get("ad"): # attributes of AD group to be returned - self.attrs = ['sAMAccountName', 'member', 'objectSid'] - self.essential_fields = ['sAMAccountName', 'objectSid'] + self.attrs = ["sAMAccountName", "member", "objectSid"] + self.essential_fields = ["sAMAccountName", "objectSid"] else: - if conf.get('rfc2307bis'): - self.attrs = ['cn', 'gidNumber', 'member', 'uid'] - elif conf.get('rfc2307bis_alt'): - self.attrs = ['cn', 'gidNumber', 'uniqueMember', 'uid'] + if conf.get("rfc2307bis"): + self.attrs = ["cn", "gidNumber", "member", "uid"] + elif conf.get("rfc2307bis_alt"): + self.attrs = ["cn", "gidNumber", "uniqueMember", "uid"] else: - self.attrs = ['cn', 'gidNumber', 'memberUid', 'uid'] - if 'groupregex' in conf: - self.groupregex = re.compile(self.conf['groupregex']) - self.essential_fields = ['cn'] - if conf.get('use_rid'): - self.attrs.append('sambaSID') - self.essential_fields.append('sambaSID') + self.attrs = ["cn", "gidNumber", "memberUid", "uid"] + if "groupregex" in conf: + self.groupregex = re.compile(self.conf["groupregex"]) + self.essential_fields = ["cn"] + if conf.get("use_rid"): + self.attrs.append("sambaSID") + self.essential_fields.append("sambaSID") self.log = logging.getLogger(__name__) @@ -882,52 +914,54 @@ def Transform(self, obj): gr = group.GroupMapEntry() - if self.conf.get('ad'): - gr.name = obj['sAMAccountName'][0] + if self.conf.get("ad"): + gr.name = obj["sAMAccountName"][0] # hack to map the users as the corresponding group with the same name - elif 'uidattr' in self.conf and self.conf['uidattr'] in obj: - gr.name = obj[self.conf['uidattr']][0] - elif 'uid' in obj: - gr.name = obj['uid'][0] + elif "uidattr" in self.conf and self.conf["uidattr"] in obj: + gr.name = obj[self.conf["uidattr"]][0] + elif "uid" in obj: + gr.name = obj["uid"][0] else: - gr.name = obj['cn'][0] + gr.name = obj["cn"][0] # group passwords are deferred to gshadow - gr.passwd = '*' + gr.passwd = "*" base = 
self.conf.get("base") members = [] group_members = [] - if 'memberUid' in obj: - if hasattr(self, 'groupregex'): - members.extend(''.join( - [x for x in self.groupregex.findall(obj['memberUid'])])) + if "memberUid" in obj: + if hasattr(self, "groupregex"): + members.extend( + "".join([x for x in self.groupregex.findall(obj["memberUid"])]) + ) else: - members.extend(obj['memberUid']) - elif 'member' in obj: - for member_dn in obj['member']: - member_uid = member_dn.split(',')[0].split('=')[1] + members.extend(obj["memberUid"]) + elif "member" in obj: + for member_dn in obj["member"]: + member_uid = member_dn.split(",")[0].split("=")[1] # Note that there is not currently a way to consistently distinguish # a group from a person group_members.append(member_uid) - if hasattr(self, 'groupregex'): - members.append(''.join( - [x for x in self.groupregex.findall(member_uid)])) + if hasattr(self, "groupregex"): + members.append( + "".join([x for x in self.groupregex.findall(member_uid)]) + ) else: members.append(member_uid) - elif 'uniqueMember' in obj: + elif "uniqueMember" in obj: """This contains a DN and is processed in PostProcess in GetUpdates.""" - members.extend(obj['uniqueMember']) + members.extend(obj["uniqueMember"]) members.sort() - if self.conf.get('ad'): - gr.gid = int(sidToStr(obj['objectSid'][0]).split('-')[-1]) - elif self.conf.get('use_rid'): - gr.gid = int(sidToStr(obj['sambaSID'][0]).split('-')[-1]) + if self.conf.get("ad"): + gr.gid = int(sidToStr(obj["objectSid"][0]).split("-")[-1]) + elif self.conf.get("use_rid"): + gr.gid = int(sidToStr(obj["sambaSID"][0]).split("-")[-1]) else: - gr.gid = int(obj['gidNumber'][0]) + gr.gid = int(obj["gidNumber"][0]) - if 'offset' in self.conf: - gr.gid = int(gr.gid + self.conf['offset']) + if "offset" in self.conf: + gr.gid = int(gr.gid + self.conf["offset"]) gr.members = members gr.groupmembers = group_members @@ -936,17 +970,19 @@ def Transform(self, obj): def PostProcess(self, data_map, source, search_filter, search_scope): """Perform some post-process of the data.""" - if 'uniqueMember' in self.attrs: + if "uniqueMember" in self.attrs: for gr in data_map: uidmembers = [] for member in gr.members: - source.Search(search_base=member, - search_filter='(objectClass=*)', - search_scope=ldap.SCOPE_BASE, - attrs=['uid']) + source.Search( + search_base=member, + search_filter="(objectClass=*)", + search_scope=ldap.SCOPE_BASE, + attrs=["uid"], + ) for obj in source: - if 'uid' in obj: - uidmembers.extend(obj['uid']) + if "uid" in obj: + uidmembers.extend(obj["uid"]) del gr.members[:] gr.members.extend(uidmembers) @@ -961,7 +997,10 @@ def _expand_members(obj, visited=None): if member not in obj.members: obj.members.append(member) for submember_name in gmember.groupmembers: - if submember_name in _group_map and submember_name not in visited: + if ( + submember_name in _group_map + and submember_name not in visited + ): visited.append(submember_name) _expand_members(_group_map[submember_name], visited) @@ -977,20 +1016,26 @@ class ShadowUpdateGetter(UpdateGetter): def __init__(self, conf): super(ShadowUpdateGetter, self).__init__(conf) self.attrs = [ - 'uid', 'shadowLastChange', 'shadowMin', 'shadowMax', - 'shadowWarning', 'shadowInactive', 'shadowExpire', 'shadowFlag', - 'userPassword' + "uid", + "shadowLastChange", + "shadowMin", + "shadowMax", + "shadowWarning", + "shadowInactive", + "shadowExpire", + "shadowFlag", + "userPassword", ] - if self.conf.get('ad'): + if self.conf.get("ad"): # attributes of AD user to be returned for shadow - 
self.attrs.extend(('sAMAccountName', 'pwdLastSet'))
-            self.essential_fields = ['sAMAccountName', 'pwdLastSet']
+            self.attrs.extend(("sAMAccountName", "pwdLastSet"))
+            self.essential_fields = ["sAMAccountName", "pwdLastSet"]
         else:
-            if 'uidattr' in self.conf:
-                self.attrs.append(self.conf['uidattr'])
-            if 'uidregex' in self.conf:
-                self.uidregex = re.compile(self.conf['uidregex'])
-            self.essential_fields = ['uid']
+            if "uidattr" in self.conf:
+                self.attrs.append(self.conf["uidattr"])
+            if "uidregex" in self.conf:
+                self.uidregex = re.compile(self.conf["uidregex"])
+            self.essential_fields = ["uid"]

         self.log = logging.getLogger(self.__class__.__name__)

     def CreateMap(self):
@@ -1002,50 +1047,52 @@ def Transform(self, obj):

         shadow_ent = shadow.ShadowMapEntry()

-        if self.conf.get('ad'):
-            shadow_ent.name = obj['sAMAccountName'][0]
-        elif 'uidattr' in self.conf:
-            shadow_ent.name = obj[self.conf['uidattr']][0]
+        if self.conf.get("ad"):
+            shadow_ent.name = obj["sAMAccountName"][0]
+        elif "uidattr" in self.conf:
+            shadow_ent.name = obj[self.conf["uidattr"]][0]
         else:
-            shadow_ent.name = obj['uid'][0]
+            shadow_ent.name = obj["uid"][0]

-        if hasattr(self, 'uidregex'):
-            shadow_ent.name = ''.join(
-                [x for x in self.uidregex.findall(shadow_end.name)])
+        if hasattr(self, "uidregex"):
+            shadow_ent.name = "".join(
+                [x for x in self.uidregex.findall(shadow_ent.name)]
+            )

         # TODO(jaq): does nss_ldap check the contents of the userPassword
         # attribute?
-        shadow_ent.passwd = '*'
-        if self.conf.get('ad'):
+        shadow_ent.passwd = "*"
+        if self.conf.get("ad"):
             # Time attributes of AD objects use the interval date/time format: the
             # number of 100-nanosecond intervals since January 1, 1601. To convert,
             # take the difference between 1970-01-01 and 1601-01-01 in seconds,
             # which is 11644473600, subtract it from the pwdLastSet value in
             # seconds, then divide by 86400 to get the days since Jan 1, 1970
             # that the password was changed.
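             # For example, with the pwdLastSet fixture value used in the tests
             # in this change (an illustration, not part of the original code):
             #   132161071270000000 // 10**7 = 13216107127  seconds since 1601-01-01
             #   13216107127 - 11644473600   = 1571633527   seconds since 1970-01-01
             #   1571633527 // 86400         = 18190        days -> shadow_ent.lstchg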
shadow_ent.lstchg = int( - (int(obj['pwdLastSet'][0]) / 10000000 - 11644473600) / 86400) - elif 'shadowLastChange' in obj: - shadow_ent.lstchg = int(obj['shadowLastChange'][0]) - if 'shadowMin' in obj: - shadow_ent.min = int(obj['shadowMin'][0]) - if 'shadowMax' in obj: - shadow_ent.max = int(obj['shadowMax'][0]) - if 'shadowWarning' in obj: - shadow_ent.warn = int(obj['shadowWarning'][0]) - if 'shadowInactive' in obj: - shadow_ent.inact = int(obj['shadowInactive'][0]) - if 'shadowExpire' in obj: - shadow_ent.expire = int(obj['shadowExpire'][0]) - if 'shadowFlag' in obj: - shadow_ent.flag = int(obj['shadowFlag'][0]) + (int(obj["pwdLastSet"][0]) / 10000000 - 11644473600) / 86400 + ) + elif "shadowLastChange" in obj: + shadow_ent.lstchg = int(obj["shadowLastChange"][0]) + if "shadowMin" in obj: + shadow_ent.min = int(obj["shadowMin"][0]) + if "shadowMax" in obj: + shadow_ent.max = int(obj["shadowMax"][0]) + if "shadowWarning" in obj: + shadow_ent.warn = int(obj["shadowWarning"][0]) + if "shadowInactive" in obj: + shadow_ent.inact = int(obj["shadowInactive"][0]) + if "shadowExpire" in obj: + shadow_ent.expire = int(obj["shadowExpire"][0]) + if "shadowFlag" in obj: + shadow_ent.flag = int(obj["shadowFlag"][0]) if shadow_ent.flag is None: shadow_ent.flag = 0 - if 'userPassword' in obj: - passwd = obj['userPassword'][0] - if passwd[:7].lower() == '{crypt}': + if "userPassword" in obj: + passwd = obj["userPassword"][0] + if passwd[:7].lower() == "{crypt}": shadow_ent.passwd = passwd[7:] else: - logging.info('Ignored password that was not in crypt format') + logging.info("Ignored password that was not in crypt format") return shadow_ent @@ -1054,8 +1101,8 @@ class NetgroupUpdateGetter(UpdateGetter): def __init__(self, conf): super(NetgroupUpdateGetter, self).__init__(conf) - self.attrs = ['cn', 'memberNisNetgroup', 'nisNetgroupTriple'] - self.essential_fields = ['cn'] + self.attrs = ["cn", "memberNisNetgroup", "nisNetgroupTriple"] + self.essential_fields = ["cn"] def CreateMap(self): """Return a NetgroupMap instance.""" @@ -1064,16 +1111,16 @@ def CreateMap(self): def Transform(self, obj): """Transforms an LDAP nisNetgroup object into a netgroup(5) entry.""" netgroup_ent = netgroup.NetgroupMapEntry() - netgroup_ent.name = obj['cn'][0] + netgroup_ent.name = obj["cn"][0] entries = set() - if 'memberNisNetgroup' in obj: - entries.update(obj['memberNisNetgroup']) - if 'nisNetgroupTriple' in obj: - entries.update(obj['nisNetgroupTriple']) + if "memberNisNetgroup" in obj: + entries.update(obj["memberNisNetgroup"]) + if "nisNetgroupTriple" in obj: + entries.update(obj["nisNetgroupTriple"]) # final data is stored as a string in the object - netgroup_ent.entries = ' '.join(sorted(entries)) + netgroup_ent.entries = " ".join(sorted(entries)) return netgroup_ent @@ -1083,8 +1130,8 @@ class AutomountUpdateGetter(UpdateGetter): def __init__(self, conf): super(AutomountUpdateGetter, self).__init__(conf) - self.attrs = ['cn', 'automountInformation'] - self.essential_fields = ['cn'] + self.attrs = ["cn", "automountInformation"] + self.essential_fields = ["cn"] def CreateMap(self): """Return a AutomountMap instance.""" @@ -1093,17 +1140,17 @@ def CreateMap(self): def Transform(self, obj): """Transforms an LDAP automount object into an autofs(5) entry.""" automount_ent = automount.AutomountMapEntry() - automount_ent.key = obj['cn'][0] + automount_ent.key = obj["cn"][0] - automount_information = obj['automountInformation'][0] + automount_information = obj["automountInformation"][0] - if 
automount_information.startswith('ldap'):
+        if automount_information.startswith("ldap"):
             # we are creating an automount master map, pointing to other maps in LDAP
             automount_ent.location = automount_information
         else:
             # we are creating normal automount maps, with filesystems and options
-            automount_ent.options = automount_information.split(' ')[0]
-            automount_ent.location = automount_information.split(' ')[1]
+            automount_ent.options = automount_information.split(" ")[0]
+            automount_ent.location = automount_information.split(" ")[1]

         return automount_ent

@@ -1113,12 +1160,12 @@ class SshkeyUpdateGetter(UpdateGetter):

     def __init__(self, conf):
         super(SshkeyUpdateGetter, self).__init__(conf)
-        self.attrs = ['uid', 'sshPublicKey']
-        if 'uidattr' in self.conf:
-            self.attrs.append(self.conf['uidattr'])
-        if 'uidregex' in self.conf:
-            self.uidregex = re.compile(self.conf['uidregex'])
-        self.essential_fields = ['uid']
+        self.attrs = ["uid", "sshPublicKey"]
+        if "uidattr" in self.conf:
+            self.attrs.append(self.conf["uidattr"])
+        if "uidregex" in self.conf:
+            self.uidregex = re.compile(self.conf["uidregex"])
+        self.essential_fields = ["uid"]

     def CreateMap(self):
         """Returns a new SshkeyMap instance to have SshkeyMapEntries added to
@@ -1131,17 +1178,17 @@ def Transform(self, obj):

         skey = sshkey.SshkeyMapEntry()

-        if 'uidattr' in self.conf:
-            skey.name = obj[self.conf['uidattr']][0]
+        if "uidattr" in self.conf:
+            skey.name = obj[self.conf["uidattr"]][0]
         else:
-            skey.name = obj['uid'][0]
+            skey.name = obj["uid"][0]

-        if hasattr(self, 'uidregex'):
-            skey.name = ''.join([x for x in self.uidregex.findall(pw.name)])
+        if hasattr(self, "uidregex"):
+            skey.name = "".join([x for x in self.uidregex.findall(skey.name)])

-        if 'sshPublicKey' in obj:
-            skey.sshkey = obj['sshPublicKey']
+        if "sshPublicKey" in obj:
+            skey.sshkey = obj["sshPublicKey"]
         else:
-            skey.sshkey = ''
+            skey.sshkey = ""

         return skey
diff --git a/nss_cache/sources/ldapsource_test.py b/nss_cache/sources/ldapsource_test.py
index 85f7c162..376cc68d 100644
--- a/nss_cache/sources/ldapsource_test.py
+++ b/nss_cache/sources/ldapsource_test.py
@@ -15,8 +15,10 @@
 #  Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
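 # The tests below follow mox's record/replay pattern: ldap.ldapobject is
 # stubbed out with StubOutWithMock, the expected ReconnectLDAPObject calls
 # are recorded against a mock, and ReplayAll() arms the expectations before
 # LdapSource is exercised against the mock.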
"""An implementation of a mock ldap data source for nsscache.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import time import unittest @@ -32,30 +34,28 @@ TEST_RETRY_MAX = 1 TEST_RETRY_DELAY = 0 -TEST_URI = 'TEST_URI' +TEST_URI = "TEST_URI" class TestLdapSource(mox.MoxTestBase): - def setUp(self): """Initialize a basic config dict.""" super(TestLdapSource, self).setUp() self.config = { - 'uri': 'TEST_URI', - 'base': 'TEST_BASE', - 'filter': 'TEST_FILTER', - 'bind_dn': 'TEST_BIND_DN', - 'bind_password': 'TEST_BIND_PASSWORD', - 'retry_delay': TEST_RETRY_DELAY, - 'retry_max': TEST_RETRY_MAX, - 'timelimit': 'TEST_TIMELIMIT', - 'tls_require_cert': 0, - 'tls_cacertdir': 'TEST_TLS_CACERTDIR', - 'tls_cacertfile': 'TEST_TLS_CACERTFILE', + "uri": "TEST_URI", + "base": "TEST_BASE", + "filter": "TEST_FILTER", + "bind_dn": "TEST_BIND_DN", + "bind_password": "TEST_BIND_PASSWORD", + "retry_delay": TEST_RETRY_DELAY, + "retry_max": TEST_RETRY_MAX, + "timelimit": "TEST_TIMELIMIT", + "tls_require_cert": 0, + "tls_cacertdir": "TEST_TLS_CACERTDIR", + "tls_cacertfile": "TEST_TLS_CACERTFILE", } - def compareSPRC(self, expected_value=''): - + def compareSPRC(self, expected_value=""): def comparator(param): if not isinstance(param, list): return False @@ -70,120 +70,122 @@ def comparator(param): return comparator def testDefaultConfiguration(self): - config = {'uri': 'ldap://foo'} + config = {"uri": "ldap://foo"} mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='', who='') - self.mox.StubOutWithMock(ldap, 'ldapobject') - ldap.ldapobject.ReconnectLDAPObject(uri='ldap://foo', - retry_max=3, - retry_delay=5).AndReturn(mock_rlo) + mock_rlo.simple_bind_s(cred="", who="") + self.mox.StubOutWithMock(ldap, "ldapobject") + ldap.ldapobject.ReconnectLDAPObject( + uri="ldap://foo", retry_max=3, retry_delay=5 + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) - self.assertEqual(source.conf['bind_dn'], ldapsource.LdapSource.BIND_DN) - self.assertEqual(source.conf['bind_password'], - ldapsource.LdapSource.BIND_PASSWORD) - self.assertEqual(source.conf['retry_max'], - ldapsource.LdapSource.RETRY_MAX) - self.assertEqual(source.conf['retry_delay'], - ldapsource.LdapSource.RETRY_DELAY) - self.assertEqual(source.conf['scope'], ldapsource.LdapSource.SCOPE) - self.assertEqual(source.conf['timelimit'], - ldapsource.LdapSource.TIMELIMIT) - self.assertEqual(source.conf['tls_require_cert'], ldap.OPT_X_TLS_DEMAND) + self.assertEqual(source.conf["bind_dn"], ldapsource.LdapSource.BIND_DN) + self.assertEqual( + source.conf["bind_password"], ldapsource.LdapSource.BIND_PASSWORD + ) + self.assertEqual(source.conf["retry_max"], ldapsource.LdapSource.RETRY_MAX) + self.assertEqual(source.conf["retry_delay"], ldapsource.LdapSource.RETRY_DELAY) + self.assertEqual(source.conf["scope"], ldapsource.LdapSource.SCOPE) + self.assertEqual(source.conf["timelimit"], ldapsource.LdapSource.TIMELIMIT) + self.assertEqual(source.conf["tls_require_cert"], ldap.OPT_X_TLS_DEMAND) def testOverrideDefaultConfiguration(self): config = dict(self.config) - config['scope'] = ldap.SCOPE_BASE + config["scope"] = ldap.SCOPE_BASE mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - self.mox.StubOutWithMock(ldap, 'ldapobject') - 
ldap.ldapobject.ReconnectLDAPObject(retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY, - uri='TEST_URI').AndReturn(mock_rlo) + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + self.mox.StubOutWithMock(ldap, "ldapobject") + ldap.ldapobject.ReconnectLDAPObject( + retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY, uri="TEST_URI" + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) - self.assertEqual(source.conf['scope'], ldap.SCOPE_BASE) - self.assertEqual(source.conf['bind_dn'], 'TEST_BIND_DN') - self.assertEqual(source.conf['bind_password'], 'TEST_BIND_PASSWORD') - self.assertEqual(source.conf['retry_delay'], TEST_RETRY_DELAY) - self.assertEqual(source.conf['retry_max'], TEST_RETRY_MAX) - self.assertEqual(source.conf['timelimit'], 'TEST_TIMELIMIT') - self.assertEqual(source.conf['tls_require_cert'], 0) - self.assertEqual(source.conf['tls_cacertdir'], 'TEST_TLS_CACERTDIR') - self.assertEqual(source.conf['tls_cacertfile'], 'TEST_TLS_CACERTFILE') + self.assertEqual(source.conf["scope"], ldap.SCOPE_BASE) + self.assertEqual(source.conf["bind_dn"], "TEST_BIND_DN") + self.assertEqual(source.conf["bind_password"], "TEST_BIND_PASSWORD") + self.assertEqual(source.conf["retry_delay"], TEST_RETRY_DELAY) + self.assertEqual(source.conf["retry_max"], TEST_RETRY_MAX) + self.assertEqual(source.conf["timelimit"], "TEST_TIMELIMIT") + self.assertEqual(source.conf["tls_require_cert"], 0) + self.assertEqual(source.conf["tls_cacertdir"], "TEST_TLS_CACERTDIR") + self.assertEqual(source.conf["tls_cacertfile"], "TEST_TLS_CACERTFILE") def testDebugLevelSet(self): config = dict(self.config) - config['ldap_debug'] = 3 + config["ldap_debug"] = 3 mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) mock_rlo.set_option(ldap.OPT_DEBUG_LEVEL, 3) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - self.mox.StubOutWithMock(ldap, 'ldapobject') - ldap.ldapobject.ReconnectLDAPObject(retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY, - uri='TEST_URI').AndReturn(mock_rlo) + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + self.mox.StubOutWithMock(ldap, "ldapobject") + ldap.ldapobject.ReconnectLDAPObject( + retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY, uri="TEST_URI" + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) def testTrapServerDownAndRetry(self): config = dict(self.config) - config['bind_dn'] = '' - config['bind_password'] = '' - config['retry_delay'] = 5 - config['retry_max'] = 3 + config["bind_dn"] = "" + config["bind_password"] = "" + config["retry_delay"] = 5 + config["retry_max"] = 3 mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='', who='').MultipleTimes().AndRaise( - ldap.SERVER_DOWN) + mock_rlo.simple_bind_s(cred="", who="").MultipleTimes().AndRaise( + ldap.SERVER_DOWN + ) - self.mox.StubOutWithMock(ldap, 'ldapobject') + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', retry_max=3, - retry_delay=5).MultipleTimes().AndReturn(mock_rlo) + uri="TEST_URI", retry_max=3, retry_delay=5 + ).MultipleTimes().AndReturn(mock_rlo) - self.mox.StubOutWithMock(time, 'sleep') + self.mox.StubOutWithMock(time, "sleep") time.sleep(5) time.sleep(5) self.mox.ReplayAll() - self.assertRaises(error.SourceUnavailable, ldapsource.LdapSource, - config) + self.assertRaises(error.SourceUnavailable, ldapsource.LdapSource, config) def testIterationOverLdapDataSource(self): 
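         # Iterating the source yields one entry per RES_SEARCH_ENTRY message
         # returned by result3(), until a RES_SEARCH_RESULT message ends the
         # result stream.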
config = dict(self.config) mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base=config['base'], - filterstr='TEST_FILTER', - scope='TEST_SCOPE', - attrlist='TEST_ATTRLIST', - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - dataset = [('dn', {'uid': [0]})] - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, dataset, None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base=config["base"], + filterstr="TEST_FILTER", + scope="TEST_SCOPE", + attrlist="TEST_ATTRLIST", + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + dataset = [("dn", {"uid": [0]})] + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, dataset, None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) - source.Search(search_base=config['base'], - search_filter='TEST_FILTER', - search_scope='TEST_SCOPE', - attrs='TEST_ATTRLIST') + source.Search( + search_base=config["base"], + search_filter="TEST_FILTER", + search_scope="TEST_SCOPE", + attrs="TEST_ATTRLIST", + ) count = 0 for r in source: @@ -194,39 +196,42 @@ def testIterationOverLdapDataSource(self): def testIterationTimeout(self): config = dict(self.config) - config['retry_delay'] = 5 - config['retry_max'] = 3 + config["retry_delay"] = 5 + config["retry_max"] = 3 mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base=config['base'], - filterstr='TEST_FILTER', - scope='TEST_SCOPE', - attrlist='TEST_ATTRLIST', - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - dataset = [('dn', {'uid': [0]})] - mock_rlo.result3('TEST_RES', all=0, - timeout='TEST_TIMELIMIT').MultipleTimes().AndRaise( - ldap.TIMELIMIT_EXCEEDED) - - self.mox.StubOutWithMock(time, 'sleep') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base=config["base"], + filterstr="TEST_FILTER", + scope="TEST_SCOPE", + attrlist="TEST_ATTRLIST", + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + dataset = [("dn", {"uid": [0]})] + mock_rlo.result3( + "TEST_RES", all=0, timeout="TEST_TIMELIMIT" + ).MultipleTimes().AndRaise(ldap.TIMELIMIT_EXCEEDED) + + self.mox.StubOutWithMock(time, "sleep") time.sleep(5) time.sleep(5) - self.mox.StubOutWithMock(ldap, 'ldapobject') - ldap.ldapobject.ReconnectLDAPObject(uri='TEST_URI', - retry_max=3, - retry_delay=5).AndReturn(mock_rlo) + self.mox.StubOutWithMock(ldap, "ldapobject") + ldap.ldapobject.ReconnectLDAPObject( + uri="TEST_URI", retry_max=3, retry_delay=5 + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) - source.Search(search_base=config['base'], - 
search_filter='TEST_FILTER', - search_scope='TEST_SCOPE', - attrs='TEST_ATTRLIST') + source.Search( + search_base=config["base"], + search_filter="TEST_FILTER", + search_scope="TEST_SCOPE", + attrs="TEST_ATTRLIST", + ) count = 0 for r in source: @@ -235,43 +240,55 @@ def testIterationTimeout(self): self.assertEqual(0, count) def testGetPasswdMap(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'uidNumber': ['1000'], - 'gidNumber': ['1000'], - 'uid': ['test'], - 'cn': ['Testguy McTest'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "uidNumber": ["1000"], + "gidNumber": ["1000"], + "uid": ["test"], + "cn": ["Testguy McTest"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'loginShell', 'modifyTimestamp' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "loginShell", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -282,48 +299,61 @@ def testGetPasswdMap(self): first = data.PopItem() - self.assertEqual('test', first.name) + self.assertEqual("test", first.name) def testGetPasswdMapWithUidAttr(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'uidNumber': [1000], - 'gidNumber': [1000], - 'uid': ['test'], - 'name': ['test'], - 'cn': ['Testguy McTest'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "sambaSID": 
["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "uidNumber": [1000], + "gidNumber": [1000], + "uid": ["test"], + "name": ["test"], + "cn": ["Testguy McTest"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['uidattr'] = 'name' + config["uidattr"] = "name" attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'name', 'loginShell', 'modifyTimestamp' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "name", + "loginShell", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -334,47 +364,59 @@ def testGetPasswdMapWithUidAttr(self): first = data.PopItem() - self.assertEqual('test', first.name) + self.assertEqual("test", first.name) def testGetPasswdMapWithShellOverride(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'uidNumber': [1000], - 'gidNumber': [1000], - 'uid': ['test'], - 'cn': ['Testguy McTest'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "uidNumber": [1000], + "gidNumber": [1000], + "uid": ["test"], + "cn": ["Testguy McTest"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['override_shell'] = '/bin/false' + config["override_shell"] = "/bin/false" attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'loginShell', 'modifyTimestamp' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "loginShell", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - 
mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -385,47 +427,60 @@ def testGetPasswdMapWithShellOverride(self): first = data.PopItem() - self.assertEqual('/bin/false', first.shell) + self.assertEqual("/bin/false", first.shell) def testGetPasswdMapWithUseRid(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'uidNumber': [1000], - 'gidNumber': [1000], - 'uid': ['test'], - 'cn': ['Testguy McTest'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "uidNumber": [1000], + "gidNumber": [1000], + "uid": ["test"], + "cn": ["Testguy McTest"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['use_rid'] = '1' + config["use_rid"] = "1" attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'sambaSID', 'loginShell', 'modifyTimestamp' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "sambaSID", + "loginShell", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + 
scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -436,47 +491,57 @@ def testGetPasswdMapWithUseRid(self): first = data.PopItem() - self.assertEqual('test', first.name) + self.assertEqual("test", first.name) def testGetPasswdMapAD(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'objectSid': [ - b'\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00' - ], - 'sAMAccountName': ['test'], - 'displayName': ['Testguy McTest'], - 'unixHomeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'pwdLastSet': ['132161071270000000'], - 'whenChanged': ['20070227012807.0Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "objectSid": [ + b"\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00" + ], + "sAMAccountName": ["test"], + "displayName": ["Testguy McTest"], + "unixHomeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "pwdLastSet": ["132161071270000000"], + "whenChanged": ["20070227012807.0Z"], + }, + ) config = dict(self.config) - config['ad'] = '1' + config["ad"] = "1" attrlist = [ - 'sAMAccountName', 'pwdLastSet', 'loginShell', 'objectSid', - 'displayName', 'whenChanged', 'unixHomeDirectory' + "sAMAccountName", + "pwdLastSet", + "loginShell", + "objectSid", + "displayName", + "whenChanged", + "unixHomeDirectory", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -487,48 +552,58 @@ def 
testGetPasswdMapAD(self): first = data.PopItem() - self.assertEqual('test', first.name) + self.assertEqual("test", first.name) def testGetPasswdMapADWithOffeset(self): - test_posix_account = ('cn=test,ou=People,dc=example,dc=com', { - 'objectSid': [ - b'\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00' - ], - 'sAMAccountName': ['test'], - 'displayName': ['Testguy McTest'], - 'unixHomeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'pwdLastSet': ['132161071270000000'], - 'whenChanged': ['20070227012807.0Z'] - }) + test_posix_account = ( + "cn=test,ou=People,dc=example,dc=com", + { + "objectSid": [ + b"\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00" + ], + "sAMAccountName": ["test"], + "displayName": ["Testguy McTest"], + "unixHomeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "pwdLastSet": ["132161071270000000"], + "whenChanged": ["20070227012807.0Z"], + }, + ) config = dict(self.config) - config['ad'] = '1' - config['offset'] = 10000 + config["ad"] = "1" + config["offset"] = 10000 attrlist = [ - 'sAMAccountName', 'pwdLastSet', 'loginShell', 'objectSid', - 'displayName', 'whenChanged', 'unixHomeDirectory' + "sAMAccountName", + "pwdLastSet", + "loginShell", + "objectSid", + "displayName", + "whenChanged", + "unixHomeDirectory", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -539,39 +614,44 @@ def testGetPasswdMapADWithOffeset(self): first = data.PopItem() - self.assertEqual('test', first.name) + self.assertEqual("test", first.name) def testGetGroupMap(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'memberUid': ['testguy', 'fooguy', 'barguy'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "memberUid": ["testguy", "fooguy", "barguy"], + 
"modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - attrlist = ['cn', 'uid', 'gidNumber', 'memberUid', 'modifyTimestamp'] + attrlist = ["cn", "uid", "gidNumber", "memberUid", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -582,42 +662,52 @@ def testGetGroupMap(self): ent = data.PopItem() - self.assertEqual('testgroup', ent.name) + self.assertEqual("testgroup", ent.name) def testGetGroupMapWithUseRid(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'memberUid': ['testguy', 'fooguy', 'barguy'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "memberUid": ["testguy", "fooguy", "barguy"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['use_rid'] = '1' + config["use_rid"] = "1" attrlist = [ - 'cn', 'uid', 'gidNumber', 'memberUid', 'sambaSID', 'modifyTimestamp' + "cn", + "uid", + "gidNumber", + "memberUid", + "sambaSID", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + 
attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -628,46 +718,56 @@ def testGetGroupMapWithUseRid(self): ent = data.PopItem() - self.assertEqual('testgroup', ent.name) + self.assertEqual("testgroup", ent.name) def testGetGroupMapAsUser(self): - test_posix_group = ('cn=test,ou=People,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'uidNumber': [1000], - 'gidNumber': [1000], - 'uid': ['test'], - 'cn': ['Testguy McTest'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=People,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "uidNumber": [1000], + "gidNumber": [1000], + "uid": ["test"], + "cn": ["Testguy McTest"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['use_rid'] = '1' + config["use_rid"] = "1" attrlist = [ - 'cn', 'uid', 'gidNumber', 'memberUid', 'sambaSID', 'modifyTimestamp' + "cn", + "uid", + "gidNumber", + "memberUid", + "sambaSID", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -678,46 +778,51 @@ def testGetGroupMapAsUser(self): ent = data.PopItem() - self.assertEqual('test', ent.name) + self.assertEqual("test", ent.name) def testGetGroupMapAD(self): - test_posix_group = 
('cn=test,ou=Group,dc=example,dc=com', { - 'objectSid': [ - b'\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00' - ], - 'sAMAccountName': ['testgroup'], - 'cn': ['testgroup'], - 'member': [ - 'cn=testguy,ou=People,dc=example,dc=com', - 'cn=fooguy,ou=People,dc=example,dc=com', - 'cn=barguy,ou=People,dc=example,dc=com' - ], - 'whenChanged': ['20070227012807.0Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "objectSid": [ + b"\x01\x05\x00\x00\x00\x00\x00\x05\x15\x00\x00\x00\xa0e\xcf~xK\x9b_\xe7|\x87p\t\x1c\x01\x00" + ], + "sAMAccountName": ["testgroup"], + "cn": ["testgroup"], + "member": [ + "cn=testguy,ou=People,dc=example,dc=com", + "cn=fooguy,ou=People,dc=example,dc=com", + "cn=barguy,ou=People,dc=example,dc=com", + ], + "whenChanged": ["20070227012807.0Z"], + }, + ) config = dict(self.config) - config['ad'] = '1' - attrlist = ['sAMAccountName', 'objectSid', 'member', 'whenChanged'] + config["ad"] = "1" + attrlist = ["sAMAccountName", "objectSid", "member", "whenChanged"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -728,44 +833,49 @@ def testGetGroupMapAD(self): ent = data.PopItem() - self.assertEqual('testgroup', ent.name) + self.assertEqual("testgroup", ent.name) def testGetGroupMapBis(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'member': [ - 'cn=testguy,ou=People,dc=example,dc=com', - 'cn=fooguy,ou=People,dc=example,dc=com', - 'cn=barguy,ou=People,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "member": [ + "cn=testguy,ou=People,dc=example,dc=com", + "cn=fooguy,ou=People,dc=example,dc=com", + "cn=barguy,ou=People,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['rfc2307bis'] = 1 - attrlist = 
['cn', 'uid', 'gidNumber', 'member', 'modifyTimestamp'] + config["rfc2307bis"] = 1 + attrlist = ["cn", "uid", "gidNumber", "member", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -776,58 +886,65 @@ def testGetGroupMapBis(self): ent = data.PopItem() - self.assertEqual('testgroup', ent.name) + self.assertEqual("testgroup", ent.name) self.assertEqual(3, len(ent.members)) def testGetGroupNestedNotConfigured(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'member': [ - 'cn=testguy,ou=People,dc=example,dc=com', - 'cn=fooguy,ou=People,dc=example,dc=com', - 'cn=barguy,ou=People,dc=example,dc=com', - 'cn=child,ou=Group,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) - test_child_group = ('cn=child,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72714'], - 'gidNumber': [1001], - 'cn': ['child'], - 'member': [ - 'cn=newperson,ou=People,dc=example,dc=com', - 'cn=fooperson,ou=People,dc=example,dc=com', - 'cn=barperson,ou=People,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "member": [ + "cn=testguy,ou=People,dc=example,dc=com", + "cn=fooguy,ou=People,dc=example,dc=com", + "cn=barguy,ou=People,dc=example,dc=com", + "cn=child,ou=Group,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) + test_child_group = ( + "cn=child,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72714"], + "gidNumber": [1001], + "cn": ["child"], + "member": [ + "cn=newperson,ou=People,dc=example,dc=com", + "cn=fooperson,ou=People,dc=example,dc=com", + "cn=barperson,ou=People,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['rfc2307bis'] = 1 - attrlist = ['cn', 'uid', 
'gidNumber', 'member', 'modifyTimestamp'] + config["rfc2307bis"] = 1 + attrlist = ["cn", "uid", "gidNumber", "member", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group, - test_child_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group, test_child_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) @@ -842,58 +959,63 @@ def testGetGroupNestedNotConfigured(self): self.assertNotIn("newperson", datadict["testgroup"].members) def testGetGroupNested(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'member': [ - 'cn=testguy,ou=People,dc=example,dc=com', - 'cn=fooguy,ou=People,dc=example,dc=com', - 'cn=barguy,ou=People,dc=example,dc=com', - 'cn=child,ou=Group,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) - test_child_group = ('cn=child,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72714'], - 'gidNumber': [1001], - 'cn': ['child'], - 'member': [ - 'cn=newperson,ou=People,dc=example,dc=com', - 'cn=fooperson,ou=People,dc=example,dc=com', - 'cn=barperson,ou=People,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "member": [ + "cn=testguy,ou=People,dc=example,dc=com", + "cn=fooguy,ou=People,dc=example,dc=com", + "cn=barguy,ou=People,dc=example,dc=com", + "cn=child,ou=Group,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) + test_child_group = ( + "cn=child,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72714"], + "gidNumber": [1001], + "cn": ["child"], + "member": [ + "cn=newperson,ou=People,dc=example,dc=com", + "cn=fooperson,ou=People,dc=example,dc=com", + "cn=barperson,ou=People,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['rfc2307bis'] = 1 + config["rfc2307bis"] = 1 
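         # nested_groups turns on transitive expansion of group-in-group
         # membership (the _expand_members pass in ldapsource), so members
         # of the child group should appear in testgroup's members below.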
config["nested_groups"] = 1 - config['use_rid'] = 1 - attrlist = [ - 'cn', 'uid', 'gidNumber', 'member', 'sambaSID', 'modifyTimestamp' - ] + config["use_rid"] = 1 + attrlist = ["cn", "uid", "gidNumber", "member", "sambaSID", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group, - test_child_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group, test_child_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) @@ -908,68 +1030,81 @@ def testGetGroupNested(self): self.assertIn("newperson", datadict["testgroup"].members) def testGetGroupLoop(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'member': [ - 'cn=testguy,ou=People,dc=example,dc=com', - 'cn=fooguy,ou=People,dc=example,dc=com', - 'cn=barguy,ou=People,dc=example,dc=com', - 'cn=child,ou=Group,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) - test_child_group = ('cn=child,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72714'], - 'gidNumber': [1001], - 'cn': ['child'], - 'member': [ - 'cn=newperson,ou=People,dc=example,dc=com', - 'cn=fooperson,ou=People,dc=example,dc=com', - 'cn=barperson,ou=People,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) - test_loop_group = ('cn=loop,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72715'], - 'gidNumber': [1002], - 'cn': ['loop'], - 'member': [ - 'cn=loopperson,ou=People,dc=example,dc=com', - 'cn=testgroup,ou=Group,dc=example,dc=com' - ], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], + "cn": ["testgroup"], + "member": [ + "cn=testguy,ou=People,dc=example,dc=com", + "cn=fooguy,ou=People,dc=example,dc=com", + "cn=barguy,ou=People,dc=example,dc=com", + "cn=child,ou=Group,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) + test_child_group = ( + "cn=child,ou=Group,dc=example,dc=com", + { + 
"sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72714"], + "gidNumber": [1001], + "cn": ["child"], + "member": [ + "cn=newperson,ou=People,dc=example,dc=com", + "cn=fooperson,ou=People,dc=example,dc=com", + "cn=barperson,ou=People,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) + test_loop_group = ( + "cn=loop,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72715"], + "gidNumber": [1002], + "cn": ["loop"], + "member": [ + "cn=loopperson,ou=People,dc=example,dc=com", + "cn=testgroup,ou=Group,dc=example,dc=com", + ], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['rfc2307bis'] = 1 + config["rfc2307bis"] = 1 config["nested_groups"] = 1 - config['use_rid'] = 1 - attrlist = [ - 'cn', 'uid', 'gidNumber', 'member', 'sambaSID', 'modifyTimestamp' - ] + config["use_rid"] = 1 + attrlist = ["cn", "uid", "gidNumber", "member", "sambaSID", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, - [test_posix_group, test_child_group, test_loop_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + ( + ldap.RES_SEARCH_ENTRY, + [test_posix_group, test_child_group, test_loop_group], + None, + [], + ) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) @@ -984,64 +1119,79 @@ def testGetGroupLoop(self): self.assertIn("newperson", datadict["testgroup"].members) def testGetGroupMapBisAlt(self): - test_posix_group = ('cn=test,ou=Group,dc=example,dc=com', { - 'sambaSID': ['S-1-5-21-2127521184-1604012920-1887927527-72713'], - 'gidNumber': [1000], - 'cn': ['testgroup'], - 'uniqueMember': ['cn=testguy,ou=People,dc=example,dc=com'], - 'modifyTimestamp': ['20070227012807Z'] - }) - dn_user = 'cn=testguy,ou=People,dc=example,dc=com' - test_posix_account = (dn_user, { - 'sambaSID': ['S-1-5-21-2562418665-3218585558-1813906818-1576'], - 'uidNumber': [1000], - 'gidNumber': [1000], - 'uid': ['test'], - 'cn': ['testguy'], - 'homeDirectory': ['/home/test'], - 'loginShell': ['/bin/sh'], - 'userPassword': ['p4ssw0rd'], - 'modifyTimestamp': ['20070227012807Z'] - }) + test_posix_group = ( + "cn=test,ou=Group,dc=example,dc=com", + { + "sambaSID": ["S-1-5-21-2127521184-1604012920-1887927527-72713"], + "gidNumber": [1000], 
+ "cn": ["testgroup"], + "uniqueMember": ["cn=testguy,ou=People,dc=example,dc=com"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) + dn_user = "cn=testguy,ou=People,dc=example,dc=com" + test_posix_account = ( + dn_user, + { + "sambaSID": ["S-1-5-21-2562418665-3218585558-1813906818-1576"], + "uidNumber": [1000], + "gidNumber": [1000], + "uid": ["test"], + "cn": ["testguy"], + "homeDirectory": ["/home/test"], + "loginShell": ["/bin/sh"], + "userPassword": ["p4ssw0rd"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - config['rfc2307bis_alt'] = 1 - config['use_rid'] = 1 + config["rfc2307bis_alt"] = 1 + config["use_rid"] = 1 attrlist = [ - 'cn', 'gidNumber', 'uniqueMember', 'uid', 'sambaSID', - 'modifyTimestamp' + "cn", + "gidNumber", + "uniqueMember", + "uid", + "sambaSID", + "modifyTimestamp", ] - uidattr = ['uid'] + uidattr = ["uid"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - mock_rlo.search_ext(base=dn_user, - filterstr='(objectClass=*)', - scope=ldap.SCOPE_BASE, - attrlist=mox.SameElementsAs(uidattr), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_group], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + mock_rlo.search_ext( + base=dn_user, + filterstr="(objectClass=*)", + scope=ldap.SCOPE_BASE, + attrlist=mox.SameElementsAs(uidattr), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_account], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1052,47 +1202,59 @@ def testGetGroupMapBisAlt(self): ent = data.PopItem() - self.assertEqual('testgroup', ent.name) + self.assertEqual("testgroup", ent.name) self.assertEqual(1, len(ent.members)) def testGetShadowMap(self): - test_shadow = 
('cn=test,ou=People,dc=example,dc=com', { - 'uid': ['test'], - 'shadowLastChange': ['11296'], - 'shadowMax': ['99999'], - 'shadowWarning': ['7'], - 'shadowInactive': ['-1'], - 'shadowExpire': ['-1'], - 'shadowFlag': ['134537556'], - 'modifyTimestamp': ['20070227012807Z'], - 'userPassword': ['{CRYPT}p4ssw0rd'] - }) + test_shadow = ( + "cn=test,ou=People,dc=example,dc=com", + { + "uid": ["test"], + "shadowLastChange": ["11296"], + "shadowMax": ["99999"], + "shadowWarning": ["7"], + "shadowInactive": ["-1"], + "shadowExpire": ["-1"], + "shadowFlag": ["134537556"], + "modifyTimestamp": ["20070227012807Z"], + "userPassword": ["{CRYPT}p4ssw0rd"], + }, + ) config = dict(self.config) attrlist = [ - 'uid', 'shadowLastChange', 'shadowMin', 'shadowMax', - 'shadowWarning', 'shadowInactive', 'shadowExpire', 'shadowFlag', - 'userPassword', 'modifyTimestamp' + "uid", + "shadowLastChange", + "shadowMin", + "shadowMax", + "shadowWarning", + "shadowInactive", + "shadowExpire", + "shadowFlag", + "userPassword", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_shadow], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_shadow], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1103,49 +1265,62 @@ def testGetShadowMap(self): ent = data.PopItem() - self.assertEqual('test', ent.name) - self.assertEqual('p4ssw0rd', ent.passwd) + self.assertEqual("test", ent.name) + self.assertEqual("p4ssw0rd", ent.passwd) def testGetShadowMapWithUidAttr(self): - test_shadow = ('cn=test,ou=People,dc=example,dc=com', { - 'uid': ['test'], - 'name': ['test'], - 'shadowLastChange': ['11296'], - 'shadowMax': ['99999'], - 'shadowWarning': ['7'], - 'shadowInactive': ['-1'], - 'shadowExpire': ['-1'], - 'shadowFlag': ['134537556'], - 'modifyTimestamp': ['20070227012807Z'], - 'userPassword': ['{CRYPT}p4ssw0rd'] - }) + test_shadow = ( + "cn=test,ou=People,dc=example,dc=com", + { + "uid": ["test"], + "name": ["test"], + "shadowLastChange": ["11296"], + "shadowMax": ["99999"], + "shadowWarning": ["7"], + "shadowInactive": ["-1"], + "shadowExpire": ["-1"], + "shadowFlag": ["134537556"], + "modifyTimestamp": ["20070227012807Z"], + "userPassword": ["{CRYPT}p4ssw0rd"], + }, + ) config = dict(self.config) - config['uidattr'] = 'name' 
+ config["uidattr"] = "name" attrlist = [ - 'uid', 'shadowLastChange', 'shadowMin', 'shadowMax', 'name', - 'shadowWarning', 'shadowInactive', 'shadowExpire', 'shadowFlag', - 'userPassword', 'modifyTimestamp' + "uid", + "shadowLastChange", + "shadowMin", + "shadowMax", + "name", + "shadowWarning", + "shadowInactive", + "shadowExpire", + "shadowFlag", + "userPassword", + "modifyTimestamp", ] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_shadow], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_shadow], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1156,40 +1331,43 @@ def testGetShadowMapWithUidAttr(self): ent = data.PopItem() - self.assertEqual('test', ent.name) - self.assertEqual('p4ssw0rd', ent.passwd) + self.assertEqual("test", ent.name) + self.assertEqual("p4ssw0rd", ent.passwd) def testGetNetgroupMap(self): - test_posix_netgroup = ('cn=test,ou=netgroup,dc=example,dc=com', { - 'cn': ['test'], - 'memberNisNetgroup': ['admins'], - 'nisNetgroupTriple': ['(-,hax0r,)'], - 'modifyTimestamp': ['20070227012807Z'], - }) + test_posix_netgroup = ( + "cn=test,ou=netgroup,dc=example,dc=com", + { + "cn": ["test"], + "memberNisNetgroup": ["admins"], + "nisNetgroupTriple": ["(-,hax0r,)"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - attrlist = [ - 'cn', 'memberNisNetgroup', 'nisNetgroupTriple', 'modifyTimestamp' - ] + attrlist = ["cn", "memberNisNetgroup", "nisNetgroupTriple", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_netgroup], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + 
filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_netgroup], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1200,40 +1378,43 @@ def testGetNetgroupMap(self): ent = data.PopItem() - self.assertEqual('test', ent.name) - self.assertEqual('(-,hax0r,) admins', ent.entries) + self.assertEqual("test", ent.name) + self.assertEqual("(-,hax0r,) admins", ent.entries) def testGetNetgroupMapWithDupes(self): - test_posix_netgroup = ('cn=test,ou=netgroup,dc=example,dc=com', { - 'cn': ['test'], - 'memberNisNetgroup': ['(-,hax0r,)'], - 'nisNetgroupTriple': ['(-,hax0r,)'], - 'modifyTimestamp': ['20070227012807Z'], - }) + test_posix_netgroup = ( + "cn=test,ou=netgroup,dc=example,dc=com", + { + "cn": ["test"], + "memberNisNetgroup": ["(-,hax0r,)"], + "nisNetgroupTriple": ["(-,hax0r,)"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - attrlist = [ - 'cn', 'memberNisNetgroup', 'nisNetgroupTriple', 'modifyTimestamp' - ] + attrlist = ["cn", "memberNisNetgroup", "nisNetgroupTriple", "modifyTimestamp"] mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr='TEST_FILTER', - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_posix_netgroup], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr="TEST_FILTER", + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_posix_netgroup], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1244,109 +1425,119 @@ def testGetNetgroupMapWithDupes(self): ent = data.PopItem() - self.assertEqual('test', ent.name) - self.assertEqual('(-,hax0r,)', ent.entries) + self.assertEqual("test", ent.name) + self.assertEqual("(-,hax0r,)", ent.entries) def testGetAutomountMap(self): test_automount = ( - 'cn=user,ou=auto.home,ou=automounts,dc=example,dc=com', { - 'cn': ['user'], - 'automountInformation': 
['-tcp,rw home:/home/user'], - 'modifyTimestamp': ['20070227012807Z'], - }) + "cn=user,ou=auto.home,ou=automounts,dc=example,dc=com", + { + "cn": ["user"], + "automountInformation": ["-tcp,rw home:/home/user"], + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) - attrlist = ['cn', 'automountInformation', 'modifyTimestamp'] - filterstr = '(objectclass=automount)' + attrlist = ["cn", "automountInformation", "modifyTimestamp"] + filterstr = "(objectclass=automount)" mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr=filterstr, - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_automount], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr=filterstr, + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_automount], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(config) - data = source.GetAutomountMap(location='TEST_BASE') + data = source.GetAutomountMap(location="TEST_BASE") self.assertEqual(1, len(data)) ent = data.PopItem() - self.assertEqual('user', ent.key) - self.assertEqual('-tcp,rw', ent.options) - self.assertEqual('home:/home/user', ent.location) + self.assertEqual("user", ent.key) + self.assertEqual("-tcp,rw", ent.options) + self.assertEqual("home:/home/user", ent.location) def testGetAutomountMasterMap(self): - test_master_ou = ('ou=auto.master,ou=automounts,dc=example,dc=com', { - 'ou': ['auto.master'] - }) + test_master_ou = ( + "ou=auto.master,ou=automounts,dc=example,dc=com", + {"ou": ["auto.master"]}, + ) test_automount = ( - 'cn=/home,ou=auto.master,ou=automounts,dc=example,dc=com', { - 'cn': ['/home'], - 'automountInformation': [ - 'ldap:ldap:ou=auto.home,' - 'ou=automounts,dc=example,' - 'dc=com' + "cn=/home,ou=auto.master,ou=automounts,dc=example,dc=com", + { + "cn": ["/home"], + "automountInformation": [ + "ldap:ldap:ou=auto.home," "ou=automounts,dc=example," "dc=com" ], - 'modifyTimestamp': ['20070227012807Z'] - }) + "modifyTimestamp": ["20070227012807Z"], + }, + ) config = dict(self.config) # first search for the dn of ou=auto.master - attrlist = ['dn'] - filterstr = '(&(objectclass=automountMap)(ou=auto.master))' + attrlist = ["dn"] + filterstr = "(&(objectclass=automountMap)(ou=auto.master))" scope = ldap.SCOPE_SUBTREE mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') 
- mock_rlo.search_ext(base='TEST_BASE', - filterstr=filterstr, - scope=scope, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_master_ou], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr=filterstr, + scope=scope, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_master_ou], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) # then search for the entries under ou=auto.master - attrlist = ['cn', 'automountInformation', 'modifyTimestamp'] - filterstr = '(objectclass=automount)' + attrlist = ["cn", "automountInformation", "modifyTimestamp"] + filterstr = "(objectclass=automount)" scope = ldap.SCOPE_ONELEVEL - base = 'ou=auto.master,ou=automounts,dc=example,dc=com' - - mock_rlo.search_ext(base=base, - filterstr=filterstr, - scope=scope, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_ENTRY, [test_automount], None, [])) - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - - self.mox.StubOutWithMock(ldap, 'ldapobject') + base = "ou=auto.master,ou=automounts,dc=example,dc=com" + + mock_rlo.search_ext( + base=base, + filterstr=filterstr, + scope=scope, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_ENTRY, [test_automount], None, []) + ) + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() @@ -1355,34 +1546,41 @@ def testGetAutomountMasterMap(self): self.assertEqual(1, len(data)) ent = data.PopItem() - self.assertEqual('/home', ent.key) - self.assertEqual('ou=auto.home,ou=automounts,dc=example,dc=com', - ent.location) + self.assertEqual("/home", ent.key) + self.assertEqual("ou=auto.home,ou=automounts,dc=example,dc=com", ent.location) self.assertEqual(None, ent.options) def testVerify(self): attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'loginShell', 'modifyTimestamp' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "loginShell", + "modifyTimestamp", ] - filterstr = '(&TEST_FILTER(modifyTimestamp>=19700101000001Z))' + filterstr = "(&TEST_FILTER(modifyTimestamp>=19700101000001Z))" mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - 
mock_rlo.search_ext(base='TEST_BASE', - filterstr=filterstr, - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr=filterstr, + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) self.mox.ReplayAll() source = ldapsource.LdapSource(self.config) @@ -1390,37 +1588,45 @@ def testVerify(self): def testVerifyRID(self): attrlist = [ - 'uid', 'uidNumber', 'gidNumber', 'gecos', 'cn', 'homeDirectory', - 'fullName', 'loginShell', 'modifyTimestamp', 'sambaSID' + "uid", + "uidNumber", + "gidNumber", + "gecos", + "cn", + "homeDirectory", + "fullName", + "loginShell", + "modifyTimestamp", + "sambaSID", ] - filterstr = '(&TEST_FILTER(modifyTimestamp>=19700101000001Z))' + filterstr = "(&TEST_FILTER(modifyTimestamp>=19700101000001Z))" mock_rlo = self.mox.CreateMock(ldap.ldapobject.ReconnectLDAPObject) - mock_rlo.simple_bind_s(cred='TEST_BIND_PASSWORD', who='TEST_BIND_DN') - mock_rlo.search_ext(base='TEST_BASE', - filterstr=filterstr, - scope=ldap.SCOPE_ONELEVEL, - attrlist=mox.SameElementsAs(attrlist), - serverctrls=mox.Func( - self.compareSPRC())).AndReturn('TEST_RES') - - mock_rlo.result3('TEST_RES', all=0, timeout='TEST_TIMELIMIT').AndReturn( - (ldap.RES_SEARCH_RESULT, None, None, [])) - self.mox.StubOutWithMock(ldap, 'ldapobject') + mock_rlo.simple_bind_s(cred="TEST_BIND_PASSWORD", who="TEST_BIND_DN") + mock_rlo.search_ext( + base="TEST_BASE", + filterstr=filterstr, + scope=ldap.SCOPE_ONELEVEL, + attrlist=mox.SameElementsAs(attrlist), + serverctrls=mox.Func(self.compareSPRC()), + ).AndReturn("TEST_RES") + + mock_rlo.result3("TEST_RES", all=0, timeout="TEST_TIMELIMIT").AndReturn( + (ldap.RES_SEARCH_RESULT, None, None, []) + ) + self.mox.StubOutWithMock(ldap, "ldapobject") ldap.ldapobject.ReconnectLDAPObject( - uri='TEST_URI', - retry_max=TEST_RETRY_MAX, - retry_delay=TEST_RETRY_DELAY).AndReturn(mock_rlo) + uri="TEST_URI", retry_max=TEST_RETRY_MAX, retry_delay=TEST_RETRY_DELAY + ).AndReturn(mock_rlo) config = dict(self.config) - config['use_rid'] = 1 + config["use_rid"] = 1 self.mox.ReplayAll() source = ldapsource.LdapSource(config) self.assertEqual(0, source.Verify(0)) class TestUpdateGetter(unittest.TestCase): - def setUp(self): """Create a dummy source object.""" super(TestUpdateGetter, self).setUp() @@ -1438,23 +1644,23 @@ def Search(self, search_base, search_filter, search_scope, attrs): def testFromTimestampToLdap(self): ts = 1259641025 - expected_ldap_ts = '20091201041705Z' - self.assertEqual(expected_ldap_ts, - ldapsource.UpdateGetter({}).FromTimestampToLdap(ts)) + expected_ldap_ts = "20091201041705Z" + self.assertEqual( + expected_ldap_ts, ldapsource.UpdateGetter({}).FromTimestampToLdap(ts) + ) def testFromLdapToTimestamp(self): expected_ts = 
1259641025 - ldap_ts = '20091201041705Z' + ldap_ts = "20091201041705Z" self.assertEqual( - expected_ts, - ldapsource.UpdateGetter({}).FromLdapToTimestamp(ldap_ts)) + expected_ts, ldapsource.UpdateGetter({}).FromLdapToTimestamp(ldap_ts) + ) def testPasswdEmptySourceGetUpdates(self): """Test that getUpdates on the PasswdUpdateGetter works.""" getter = ldapsource.PasswdUpdateGetter({}) - data = getter.GetUpdates(self.source, 'TEST_BASE', 'TEST_FILTER', - 'base', None) + data = getter.GetUpdates(self.source, "TEST_BASE", "TEST_FILTER", "base", None) self.assertEqual(passwd.PasswdMap, type(data)) @@ -1462,8 +1668,7 @@ def testGroupEmptySourceGetUpdates(self): """Test that getUpdates on the GroupUpdateGetter works.""" getter = ldapsource.GroupUpdateGetter({}) - data = getter.GetUpdates(self.source, 'TEST_BASE', 'TEST_FILTER', - 'base', None) + data = getter.GetUpdates(self.source, "TEST_BASE", "TEST_FILTER", "base", None) self.assertEqual(group.GroupMap, type(data)) @@ -1471,8 +1676,7 @@ def testShadowEmptySourceGetUpdates(self): """Test that getUpdates on the ShadowUpdateGetter works.""" getter = ldapsource.ShadowUpdateGetter({}) - data = getter.GetUpdates(self.source, 'TEST_BASE', 'TEST_FILTER', - 'base', None) + data = getter.GetUpdates(self.source, "TEST_BASE", "TEST_FILTER", "base", None) self.assertEqual(shadow.ShadowMap, type(data)) @@ -1480,8 +1684,7 @@ def testAutomountEmptySourceGetsUpdates(self): """Test that getUpdates on the AutomountUpdateGetter works.""" getter = ldapsource.AutomountUpdateGetter({}) - data = getter.GetUpdates(self.source, 'TEST_BASE', 'TEST_FILTER', - 'base', None) + data = getter.GetUpdates(self.source, "TEST_BASE", "TEST_FILTER", "base", None) self.assertEqual(automount.AutomountMap, type(data)) @@ -1491,10 +1694,16 @@ def testBadScopeException(self): # exception-raising code. getter = ldapsource.PasswdUpdateGetter({}) - self.assertRaises(error.ConfigurationError, getter.GetUpdates, - self.source, 'TEST_BASE', 'TEST_FILTER', 'BAD_SCOPE', - None) + self.assertRaises( + error.ConfigurationError, + getter.GetUpdates, + self.source, + "TEST_BASE", + "TEST_FILTER", + "BAD_SCOPE", + None, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/s3source.py b/nss_cache/sources/s3source.py index 4dc3cb06..dfff76e2 100644 --- a/nss_cache/sources/s3source.py +++ b/nss_cache/sources/s3source.py @@ -1,6 +1,6 @@ """An implementation of a S3 data source for nsscache.""" -__author__ = 'alexey.pikin@gmail.com' +__author__ = "alexey.pikin@gmail.com" import base64 import collections @@ -27,14 +27,14 @@ class S3FilesSource(source.Source): """Source for data fetched from S3.""" # S3 defaults - BUCKET = '' - PASSWD_OBJECT = '' - GROUP_OBJECT = '' - SHADOW_OBJECT = '' - SSH_OBJECT = '' + BUCKET = "" + PASSWD_OBJECT = "" + GROUP_OBJECT = "" + SHADOW_OBJECT = "" + SSH_OBJECT = "" # for registration - name = 's3' + name = "s3" def __init__(self, conf): """Initialise the S3FilesSource object. 
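# A minimal usage sketch for the source this file defines, assuming direct
# construction rather than source_factory.Create(); the bucket and object
# names are invented for illustration, and any unset keys are filled in by
# _SetDefaults() (next hunk).
from nss_cache.sources import s3source

conf = {
    "name": "s3",  # the registration name used by the source factory
    "bucket": "example-nsscache-bucket",  # hypothetical bucket
    "passwd_object": "passwd.json",  # hypothetical object key
}
files_source = s3source.S3FilesSource(conf)
passwd_map = files_source.GetPasswdMap(since=None)  # None means unconditional fetch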
@@ -51,22 +51,22 @@ def __init__(self, conf): def _GetClient(self): if self.s3_client is None: - self.s3_client = boto3.client('s3') + self.s3_client = boto3.client("s3") return self.s3_client def _SetDefaults(self, configuration): """Set defaults if necessary.""" - if 'bucket' not in configuration: - configuration['bucket'] = self.BUCKET - if 'passwd_object' not in configuration: - configuration['passwd_object'] = self.PASSWD_OBJECT - if 'group_object' not in configuration: - configuration['group_object'] = self.GROUP_OBJECT - if 'shadow_object' not in configuration: - configuration['shadow_object'] = self.SHADOW_OBJECT - if 'sshkey_object' not in configuration: - configuration['sshkey_object'] = self.SSH_OBJECT + if "bucket" not in configuration: + configuration["bucket"] = self.BUCKET + if "passwd_object" not in configuration: + configuration["passwd_object"] = self.PASSWD_OBJECT + if "group_object" not in configuration: + configuration["group_object"] = self.GROUP_OBJECT + if "shadow_object" not in configuration: + configuration["shadow_object"] = self.SHADOW_OBJECT + if "sshkey_object" not in configuration: + configuration["sshkey_object"] = self.SSH_OBJECT def GetPasswdMap(self, since=None): """Return the passwd map from this source. @@ -78,10 +78,9 @@ def GetPasswdMap(self, since=None): Returns: instance of passwd.PasswdMap """ - return PasswdUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['passwd_object'], - since) + return PasswdUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["passwd_object"], since + ) def GetGroupMap(self, since=None): """Return the group map from this source. @@ -93,9 +92,9 @@ def GetGroupMap(self, since=None): Returns: instance of group.GroupMap """ - return GroupUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['group_object'], since) + return GroupUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["group_object"], since + ) def GetShadowMap(self, since=None): """Return the shadow map from this source. @@ -107,10 +106,9 @@ def GetShadowMap(self, since=None): Returns: instance of shadow.ShadowMap """ - return ShadowUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['shadow_object'], - since) + return ShadowUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["shadow_object"], since + ) def GetSshkeyMap(self, since=None): """Return the ssh map from this source. 
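# A sketch of the conditional fetch that backs each Get*Map call above when a
# `since` timestamp is supplied (simplified and standalone; bucket and key are
# invented). S3 answers an IfModifiedSince request for an unchanged object
# with HTTP 304, which botocore raises as a ClientError carrying Error.Code
# "304", the same signal S3UpdateGetter.GetUpdates() below maps to an empty
# update.
import boto3
from botocore.exceptions import ClientError


def fetch_if_modified(bucket, key, since_dt):
    """Return the object body, or None if it is unchanged since since_dt."""
    s3 = boto3.client("s3")
    try:
        response = s3.get_object(Bucket=bucket, Key=key, IfModifiedSince=since_dt)
    except ClientError as e:
        if int(e.response["Error"]["Code"]) == 304:
            return None  # not modified; the existing cache is still current
        raise
    return response["Body"].read()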
@@ -122,10 +120,9 @@ def GetSshkeyMap(self, since=None): Returns: instance of shadow.SSHMap """ - return SshkeyUpdateGetter().GetUpdates(self._GetClient(), - self.conf['bucket'], - self.conf['sshkey_object'], - since) + return SshkeyUpdateGetter().GetUpdates( + self._GetClient(), self.conf["bucket"], self.conf["sshkey_object"], since + ) class S3UpdateGetter(object): @@ -155,18 +152,20 @@ def GetUpdates(self, s3_client, bucket, obj, since): response = s3_client.get_object( Bucket=bucket, IfModifiedSince=timestamps.FromTimestampToDateTime(since), - Key=obj) + Key=obj, + ) else: response = s3_client.get_object(Bucket=bucket, Key=obj) - body = response['Body'] + body = response["Body"] last_modified_ts = timestamps.FromDateTimeToTimestamp( - response['LastModified']) + response["LastModified"] + ) except ClientError as e: - error_code = int(e.response['Error']['Code']) + error_code = int(e.response["Error"]["Code"]) if error_code == 304: return [] - self.log.error('error getting S3 object ({}): {}'.format(obj, e)) - raise error.SourceUnavailable('unable to download object from S3') + self.log.error("error getting S3 object ({}): {}".format(obj, e)) + raise error.SourceUnavailable("unable to download object from S3") data_map = self.GetMap(cache_info=body) data_map.SetModifyTimestamp(last_modified_ts) @@ -259,20 +258,22 @@ def GetMap(self, cache_info, data): A child of Map containing the cache data. """ for obj in json.loads(cache_info.read()): - key = obj.get('Key', '') - value = obj.get('Value', '') + key = obj.get("Key", "") + value = obj.get("Value", "") if not value or not key: continue map_entry = self._ReadEntry(key, value) if map_entry is None: self.log.warning( - 'Could not create entry from line %r in cache, skipping', - value) + "Could not create entry from line %r in cache, skipping", value + ) continue if not data.Add(map_entry): self.log.warning( - 'Could not add entry %r read from line %r in cache', - map_entry, value) + "Could not add entry %r read from line %r in cache", + map_entry, + value, + ) return data @@ -285,17 +286,17 @@ def _ReadEntry(self, name, entry): map_entry = passwd.PasswdMapEntry() # maps expect strict typing, so convert to int as appropriate. map_entry.name = name - map_entry.passwd = entry.get('passwd', 'x') + map_entry.passwd = entry.get("passwd", "x") try: - map_entry.uid = int(entry['uid']) - map_entry.gid = int(entry['gid']) + map_entry.uid = int(entry["uid"]) + map_entry.gid = int(entry["gid"]) except (ValueError, KeyError): return None - map_entry.gecos = entry.get('comment', '') - map_entry.dir = entry.get('home', '/home/{}'.format(name)) - map_entry.shell = entry.get('shell', '/bin/bash') + map_entry.gecos = entry.get("comment", "") + map_entry.dir = entry.get("home", "/home/{}".format(name)) + map_entry.shell = entry.get("shell", "/bin/bash") return map_entry @@ -309,7 +310,7 @@ def _ReadEntry(self, name, entry): map_entry = sshkey.SshkeyMapEntry() # maps expect strict typing, so convert to int as appropriate. 
map_entry.name = name - map_entry.sshkey = entry.get('sshPublicKey', '') + map_entry.sshkey = entry.get("sshPublicKey", "") return map_entry @@ -323,17 +324,17 @@ def _ReadEntry(self, name, entry): map_entry = group.GroupMapEntry() # map entries expect strict typing, so convert as appropriate map_entry.name = name - map_entry.passwd = entry.get('passwd', 'x') + map_entry.passwd = entry.get("passwd", "x") try: - map_entry.gid = int(entry['gid']) + map_entry.gid = int(entry["gid"]) except (ValueError, KeyError): return None try: - members = entry.get('members', '').split('\n') + members = entry.get("members", "").split("\n") except (ValueError, TypeError): - members = [''] + members = [""] map_entry.members = members return map_entry @@ -347,9 +348,9 @@ def _ReadEntry(self, name, entry): map_entry = shadow.ShadowMapEntry() # maps expect strict typing, so convert to int as appropriate. map_entry.name = name - map_entry.passwd = entry.get('passwd', '*') + map_entry.passwd = entry.get("passwd", "*") - for attr in ['lstchg', 'min', 'max', 'warn', 'inact', 'expire']: + for attr in ["lstchg", "min", "max", "warn", "inact", "expire"]: try: setattr(map_entry, attr, int(entry[attr])) except (ValueError, KeyError): diff --git a/nss_cache/sources/s3source_test.py b/nss_cache/sources/s3source_test.py index 77197f3d..6c678f11 100644 --- a/nss_cache/sources/s3source_test.py +++ b/nss_cache/sources/s3source_test.py @@ -1,6 +1,6 @@ """An implementation of a mock S3 data source for nsscache.""" -__author__ = 'alexey.pikin@gmail.com' +__author__ = "alexey.pikin@gmail.com" import unittest from io import StringIO @@ -12,46 +12,46 @@ class TestS3Source(unittest.TestCase): - def setUp(self): """Initialize a basic config dict.""" super(TestS3Source, self).setUp() self.config = { - 'passwd_object': 'PASSWD_OBJ', - 'group_object': 'GROUP_OBJ', - 'bucket': 'TEST_BUCKET' + "passwd_object": "PASSWD_OBJ", + "group_object": "GROUP_OBJ", + "bucket": "TEST_BUCKET", } def testDefaultConfiguration(self): source = s3source.S3FilesSource({}) - self.assertEqual(source.conf['bucket'], s3source.S3FilesSource.BUCKET) - self.assertEqual(source.conf['passwd_object'], - s3source.S3FilesSource.PASSWD_OBJECT) + self.assertEqual(source.conf["bucket"], s3source.S3FilesSource.BUCKET) + self.assertEqual( + source.conf["passwd_object"], s3source.S3FilesSource.PASSWD_OBJECT + ) def testOverrideDefaultConfiguration(self): source = s3source.S3FilesSource(self.config) - self.assertEqual(source.conf['bucket'], 'TEST_BUCKET') - self.assertEqual(source.conf['passwd_object'], 'PASSWD_OBJ') - self.assertEqual(source.conf['group_object'], 'GROUP_OBJ') + self.assertEqual(source.conf["bucket"], "TEST_BUCKET") + self.assertEqual(source.conf["passwd_object"], "PASSWD_OBJ") + self.assertEqual(source.conf["group_object"], "GROUP_OBJ") class TestPasswdMapParser(unittest.TestCase): - def setUp(self): """Set some default avalible data for testing.""" self.good_entry = passwd.PasswdMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = 'x' + self.good_entry.name = "foo" + self.good_entry.passwd = "x" self.good_entry.uid = 10 self.good_entry.gid = 10 - self.good_entry.gecos = 'How Now Brown Cow' - self.good_entry.dir = '/home/foo' - self.good_entry.shell = '/bin/bash' + self.good_entry.gecos = "How Now Brown Cow" + self.good_entry.dir = "/home/foo" + self.good_entry.shell = "/bin/bash" self.parser = s3source.S3PasswdMapParser() def testGetMap(self): passwd_map = passwd.PasswdMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ { 
"Key": "foo", "Value": { "uid": 10, "gid": 10, "home": "/home/foo", @@ -59,49 +59,50 @@ def testGetMap(self): "irrelevant_key":"bacon" } } - ]''') + ]""" + ) self.parser.GetMap(cache_info, passwd_map) self.assertEqual(self.good_entry, passwd_map.PopItem()) def testReadEntry(self): data = { - 'uid': '10', - 'gid': '10', - 'comment': 'How Now Brown Cow', - 'shell': '/bin/bash', - 'home': '/home/foo', - 'passwd': 'x' + "uid": "10", + "gid": "10", + "comment": "How Now Brown Cow", + "shell": "/bin/bash", + "home": "/home/foo", + "passwd": "x", } - entry = self.parser._ReadEntry('foo', data) + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultEntryValues(self): - data = {'uid': '10', 'gid': '10'} - entry = self.parser._ReadEntry('foo', data) - self.assertEqual(entry.shell, '/bin/bash') - self.assertEqual(entry.dir, '/home/foo') - self.assertEqual(entry.gecos, '') - self.assertEqual(entry.passwd, 'x') + data = {"uid": "10", "gid": "10"} + entry = self.parser._ReadEntry("foo", data) + self.assertEqual(entry.shell, "/bin/bash") + self.assertEqual(entry.dir, "/home/foo") + self.assertEqual(entry.gecos, "") + self.assertEqual(entry.passwd, "x") def testInvalidEntry(self): - data = {'irrelevant_key': 'bacon'} - entry = self.parser._ReadEntry('foo', data) + data = {"irrelevant_key": "bacon"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(entry, None) class TestS3GroupMapParser(unittest.TestCase): - def setUp(self): self.good_entry = group.GroupMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = 'x' + self.good_entry.name = "foo" + self.good_entry.passwd = "x" self.good_entry.gid = 10 - self.good_entry.members = ['foo', 'bar'] + self.good_entry.members = ["foo", "bar"] self.parser = s3source.S3GroupMapParser() def testGetMap(self): group_map = group.GroupMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ { "Key": "foo", "Value": { "gid": 10, @@ -109,37 +110,37 @@ def testGetMap(self): "irrelevant_key": "bacon" } } - ]''') + ]""" + ) self.parser.GetMap(cache_info, group_map) self.assertEqual(self.good_entry, group_map.PopItem()) def testReadEntry(self): - data = {'passwd': 'x', 'gid': '10', 'members': 'foo\nbar'} - entry = self.parser._ReadEntry('foo', data) + data = {"passwd": "x", "gid": "10", "members": "foo\nbar"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultPasswd(self): - data = {'gid': '10', 'members': 'foo\nbar'} - entry = self.parser._ReadEntry('foo', data) + data = {"gid": "10", "members": "foo\nbar"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testNoMembers(self): - data = {'gid': '10', 'members': ''} - entry = self.parser._ReadEntry('foo', data) - self.assertEqual(entry.members, ['']) + data = {"gid": "10", "members": ""} + entry = self.parser._ReadEntry("foo", data) + self.assertEqual(entry.members, [""]) def testInvalidEntry(self): - data = {'irrelevant_key': 'bacon'} - entry = self.parser._ReadEntry('foo', data) + data = {"irrelevant_key": "bacon"} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(entry, None) class TestS3ShadowMapParser(unittest.TestCase): - def setUp(self): self.good_entry = shadow.ShadowMapEntry() - self.good_entry.name = 'foo' - self.good_entry.passwd = '*' + self.good_entry.name = "foo" + self.good_entry.passwd = "*" self.good_entry.lstchg = 17246 self.good_entry.min = 0 self.good_entry.max = 99999 @@ -148,33 +149,29 @@ def setUp(self): def 
testGetMap(self): shadow_map = shadow.ShadowMap() - cache_info = StringIO('''[ + cache_info = StringIO( + """[ { "Key": "foo", "Value": { "passwd": "*", "lstchg": 17246, "min": 0, "max": 99999, "warn": 7 } } - ]''') + ]""" + ) self.parser.GetMap(cache_info, shadow_map) self.assertEqual(self.good_entry, shadow_map.PopItem()) def testReadEntry(self): - data = { - 'passwd': '*', - 'lstchg': 17246, - 'min': 0, - 'max': 99999, - 'warn': 7 - } - entry = self.parser._ReadEntry('foo', data) + data = {"passwd": "*", "lstchg": 17246, "min": 0, "max": 99999, "warn": 7} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) def testDefaultPasswd(self): - data = {'lstchg': 17246, 'min': 0, 'max': 99999, 'warn': 7} - entry = self.parser._ReadEntry('foo', data) + data = {"lstchg": 17246, "min": 0, "max": 99999, "warn": 7} + entry = self.parser._ReadEntry("foo", data) self.assertEqual(self.good_entry, entry) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/source.py b/nss_cache/sources/source.py index 88228ca2..5775e72b 100644 --- a/nss_cache/sources/source.py +++ b/nss_cache/sources/source.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Base class of data source object for nss_cache.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import logging @@ -39,7 +41,7 @@ def __init__(self, conf): RuntimeError: object wasn't initialised with a dict """ if not isinstance(conf, dict): - raise RuntimeError('Source constructor not passed a dictionary') + raise RuntimeError("Source constructor not passed a dictionary") self.conf = conf @@ -73,7 +75,7 @@ def GetMap(self, map_name, since=None, location=None): elif map_name == config.MAP_AUTOMOUNT: return self.GetAutomountMap(since, location=location) - raise error.UnsupportedMap('Source can not fetch %s' % map_name) + raise error.UnsupportedMap("Source can not fetch %s" % map_name) def GetAutomountMap(self, since=None, location=None): """Get an automount map from this source.""" @@ -105,7 +107,7 @@ def __init__(self, conf): RuntimeError: object wasn't initialised with a dict """ if not isinstance(conf, dict): - raise RuntimeError('Source constructor not passed a dictionary') + raise RuntimeError("Source constructor not passed a dictionary") self.conf = conf @@ -136,8 +138,6 @@ def GetFile(self, map_name, dst_file, current_file, location=None): elif map_name == config.MAP_NETGROUP: return self.GetNetgroupFile(dst_file, current_file) elif map_name == config.MAP_AUTOMOUNT: - return self.GetAutomountFile(dst_file, - current_file, - location=location) + return self.GetAutomountFile(dst_file, current_file, location=location) - raise error.UnsupportedMap('Source can not fetch %s' % map_name) + raise error.UnsupportedMap("Source can not fetch %s" % map_name) diff --git a/nss_cache/sources/source_factory.py b/nss_cache/sources/source_factory.py index 3b3547b4..b114f434 100644 --- a/nss_cache/sources/source_factory.py +++ b/nss_cache/sources/source_factory.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Factory for data source implementations.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) _source_implementations = {} @@ -40,7 +42,7 @@ def RegisterImplementation(source): RuntimeError: no 'name' entry in this source. """ global _source_implementations - if 'name' not in source.__dict__: + if "name" not in source.__dict__: raise RuntimeError("'name' not defined in Source %r" % (source,)) _source_implementations[source.name] = source @@ -49,30 +51,35 @@ def RegisterImplementation(source): # Discover all the known implementations of sources. try: from nss_cache.sources import httpsource + httpsource.RegisterImplementation(RegisterImplementation) except ImportError: pass try: from nss_cache.sources import ldapsource + ldapsource.RegisterImplementation(RegisterImplementation) except ImportError: pass try: from nss_cache.sources import consulsource + consulsource.RegisterImplementation(RegisterImplementation) except ImportError: pass try: from nss_cache.sources import s3source + s3source.RegisterImplementation(RegisterImplementation) except ImportError: pass try: from nss_cache.sources import gcssource + gcssource.RegisterImplementation(RegisterImplementation) except ImportError: pass @@ -93,11 +100,11 @@ def Create(conf): """ global _source_implementations if not _source_implementations: - raise RuntimeError('no source implementations exist') + raise RuntimeError("no source implementations exist") - source_name = conf['name'] + source_name = conf["name"] if source_name not in list(_source_implementations.keys()): - raise RuntimeError('source not implemented: %r' % (source_name,)) + raise RuntimeError("source not implemented: %r" % (source_name,)) return _source_implementations[source_name](conf) diff --git a/nss_cache/sources/source_factory_test.py b/nss_cache/sources/source_factory_test.py index b611ec43..6080bdfe 100644 --- a/nss_cache/sources/source_factory_test.py +++ b/nss_cache/sources/source_factory_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for sources/source.py.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import unittest @@ -31,39 +31,38 @@ def testRegister(self): number_of_sources = len(source_factory._source_implementations) class DummySource(source.Source): - name = 'dummy' + name = "dummy" source_factory.RegisterImplementation(DummySource) - self.assertEqual(number_of_sources + 1, - len(source_factory._source_implementations)) - self.assertEqual(DummySource, - source_factory._source_implementations['dummy']) + self.assertEqual( + number_of_sources + 1, len(source_factory._source_implementations) + ) + self.assertEqual(DummySource, source_factory._source_implementations["dummy"]) def testRegisterWithoutName(self): - class DummySource(source.Source): pass - self.assertRaises(RuntimeError, source_factory.RegisterImplementation, - DummySource) + self.assertRaises( + RuntimeError, source_factory.RegisterImplementation, DummySource + ) def testCreateWithNoImplementations(self): source_factory._source_implementations = {} self.assertRaises(RuntimeError, source_factory.Create, {}) def testCreate(self): - class DummySource(source.Source): - name = 'dummy' + name = "dummy" source_factory.RegisterImplementation(DummySource) - dummy_config = {'name': 'dummy'} + dummy_config = {"name": "dummy"} dummy_source = source_factory.Create(dummy_config) self.assertEqual(DummySource, type(dummy_source)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/sources/source_test.py b/nss_cache/sources/source_test.py index 8c071362..8a47c6c3 100644 --- a/nss_cache/sources/source_test.py +++ b/nss_cache/sources/source_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for sources/source.py.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import unittest @@ -33,7 +33,7 @@ def testCreateNoConfig(self): self.assertRaises(RuntimeError, source.Source, None) - config = 'foo' + config = "foo" self.assertRaises(RuntimeError, source.Source, config) @@ -42,5 +42,5 @@ def testVerify(self): self.assertRaises(NotImplementedError, s.Verify) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/update/files_updater.py b/nss_cache/update/files_updater.py index 8b2e938e..704d9a65 100644 --- a/nss_cache/update/files_updater.py +++ b/nss_cache/update/files_updater.py @@ -23,9 +23,9 @@ """ __author__ = ( - 'jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (V Hoffman)', - 'blaedd@google.com (David MacKinnon)', + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (V Hoffman)", + "blaedd@google.com (David MacKinnon)", ) import errno @@ -41,12 +41,9 @@ class FileMapUpdater(updater.Updater): """Updates simple map files like passwd, group, shadow, and netgroup.""" - def UpdateCacheFromSource(self, - cache, - source, - incremental=False, - force_write=False, - location=None): + def UpdateCacheFromSource( + self, cache, source, incremental=False, force_write=False, location=None + ): """Update a single cache file, from a given source. 
Args: @@ -68,29 +65,30 @@ def UpdateCacheFromSource(self, new_file_fd, new_file = tempfile.mkstemp( dir=os.path.dirname(cache_filename), prefix=os.path.basename(cache_filename), - suffix='.nsscache.tmp') + suffix=".nsscache.tmp", + ) else: - raise error.CacheInvalid('Cache has no filename.') + raise error.CacheInvalid("Cache has no filename.") - self.log.debug('temp source filename: %s', new_file) + self.log.debug("temp source filename: %s", new_file) try: # Writes the source to new_file. # Current file is passed in to allow the source to do partial diffs. # TODO(jaq): refactor this to pass in the whole cache, so that the source # can decide how to reduce downloads, c.f. last-modify-timestamp for ldap. - source.GetFile(self.map_name, - new_file, - current_file=cache.GetCacheFilename(), - location=location) + source.GetFile( + self.map_name, + new_file, + current_file=cache.GetCacheFilename(), + location=location, + ) os.lseek(new_file_fd, 0, os.SEEK_SET) # TODO(jaq): this sucks. - source_cache = cache_factory.Create(self.cache_options, - self.map_name) + source_cache = cache_factory.Create(self.cache_options, self.map_name) source_map = source_cache.GetMap(new_file) # Update the cache from the new file. - return_val += self._FullUpdateFromFile(cache, source_map, - force_write) + return_val += self._FullUpdateFromFile(cache, source_map, force_write) finally: try: os.unlink(new_file) @@ -122,21 +120,22 @@ def _FullUpdateFromFile(self, cache, source_map, force_write=False): for entry in source_map: if not entry.Verify(): - raise error.InvalidMap('Map is not valid. Aborting') + raise error.InvalidMap("Map is not valid. Aborting") if len(source_map) == 0 and not force_write: raise error.EmptyMap( - 'Source map empty during full update, aborting. ' - 'Use --force-write to override.') + "Source map empty during full update, aborting. " + "Use --force-write to override." + ) - return_val += cache.WriteMap(map_data=source_map, - force_write=force_write) + return_val += cache.WriteMap(map_data=source_map, force_write=force_write) # We did an update, write our timestamps unless there is an error. if return_val == 0: mtime = os.stat(cache.GetCacheFilename()).st_mtime - self.log.debug('Cache filename %s has mtime %d', - cache.GetCacheFilename(), mtime) + self.log.debug( + "Cache filename %s has mtime %d", cache.GetCacheFilename(), mtime + ) self.WriteModifyTimestamp(mtime) self.WriteUpdateTimestamp() @@ -155,13 +154,11 @@ class FileAutomountUpdater(updater.Updater): """ # automount-specific options - OPT_LOCAL_MASTER = 'local_automount_master' + OPT_LOCAL_MASTER = "local_automount_master" - def __init__(self, - map_name, - timestamp_dir, - cache_options, - automount_mountpoint=None): + def __init__( + self, map_name, timestamp_dir, cache_options, automount_mountpoint=None + ): """Initialize automount-specific updater options. Args: @@ -170,11 +167,12 @@ def __init__(self, cache_options: A dict containing the options for any caches we create. automount_mountpoint: An optional string containing automount path info. 
""" - updater.Updater.__init__(self, map_name, timestamp_dir, cache_options, - automount_mountpoint) + updater.Updater.__init__( + self, map_name, timestamp_dir, cache_options, automount_mountpoint + ) self.local_master = False if self.OPT_LOCAL_MASTER in cache_options: - if cache_options[self.OPT_LOCAL_MASTER] == 'yes': + if cache_options[self.OPT_LOCAL_MASTER] == "yes": self.local_master = True def UpdateFromSource(self, source, incremental=False, force_write=False): @@ -217,66 +215,73 @@ def UpdateFromSource(self, source, incremental=False, force_write=False): try: if not self.local_master: - self.log.info('Retrieving automount master map.') + self.log.info("Retrieving automount master map.") master_file = source.GetAutomountMasterFile( - os.path.join(self.cache_options['dir'], 'auto.master')) - master_cache = cache_factory.Create(self.cache_options, - self.map_name, None) + os.path.join(self.cache_options["dir"], "auto.master") + ) + master_cache = cache_factory.Create(self.cache_options, self.map_name, None) master_map = master_cache.GetMap() except error.CacheNotFound: return 1 if self.local_master: - self.log.info('Using local master map to determine maps to update.') + self.log.info("Using local master map to determine maps to update.") # we need the local map to determine which of the other maps to update - cache = cache_factory.Create(self.cache_options, - self.map_name, - automount_mountpoint=None) + cache = cache_factory.Create( + self.cache_options, self.map_name, automount_mountpoint=None + ) try: local_master = cache.GetMap() except error.CacheNotFound: - self.log.warning('Local master map specified but no map found! ' - 'No maps will update.') + self.log.warning( + "Local master map specified but no map found! " + "No maps will update." + ) return return_val + 1 # update specific maps, e.g. auto.home and auto.auto for map_entry in master_map: source_location = os.path.basename(map_entry.location) mountpoint = map_entry.key # e.g. /auto mountpoint - self.log.debug('Looking at mountpoint %s', mountpoint) + self.log.debug("Looking at mountpoint %s", mountpoint) # create the cache to update - cache = cache_factory.Create(self.cache_options, - self.map_name, - automount_mountpoint=mountpoint) + cache = cache_factory.Create( + self.cache_options, self.map_name, automount_mountpoint=mountpoint + ) # update the master map with the location of the map in the cache # e.g. /etc/auto.auto replaces ou=auto.auto map_entry.location = cache.GetMapLocation() - self.log.debug('Map location: %s', map_entry.location) + self.log.debug("Map location: %s", map_entry.location) # if configured to use the local master map, skip any not defined there if self.local_master: if map_entry not in local_master: - self.log.info('Skipping entry %s, not in map %s', map_entry, - local_master) + self.log.info( + "Skipping entry %s, not in map %s", map_entry, local_master + ) continue - self.log.info('Updating mountpoint %s', map_entry.key) + self.log.info("Updating mountpoint %s", map_entry.key) # update this map (e.g. 
/etc/auto.auto) - update_obj = FileMapUpdater(self.map_name, - self.timestamp_dir, - self.cache_options, - automount_mountpoint=mountpoint) + update_obj = FileMapUpdater( + self.map_name, + self.timestamp_dir, + self.cache_options, + automount_mountpoint=mountpoint, + ) return_val += update_obj.UpdateCacheFromSource( - cache, source, False, force_write, source_location) + cache, source, False, force_write, source_location + ) # with sub-maps updated, write modified master map to disk if configured to if not self.local_master: # automount_mountpoint=None defaults to master - cache = cache_factory.Create(self.cache_options, - self.map_name, - automount_mountpoint=None) - update_obj = FileMapUpdater(self.map_name, self.timestamp_dir, - self.cache_options) + cache = cache_factory.Create( + self.cache_options, self.map_name, automount_mountpoint=None + ) + update_obj = FileMapUpdater( + self.map_name, self.timestamp_dir, self.cache_options + ) return_val += update_obj.FullUpdateFromMap(cache, master_file) return return_val diff --git a/nss_cache/update/files_updater_test.py b/nss_cache/update/files_updater_test.py index 9e025013..88d6d569 100644 --- a/nss_cache/update/files_updater_test.py +++ b/nss_cache/update/files_updater_test.py @@ -15,9 +15,11 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/files_updater.py.""" -__author__ = ('vasilios@google.com (V Hoffman)', - 'jaq@google.com (Jamie Wilkinson)', - 'blaedd@google.com (David MacKinnon)') +__author__ = ( + "vasilios@google.com (V Hoffman)", + "jaq@google.com (Jamie Wilkinson)", + "blaedd@google.com (David MacKinnon)", +) import os import shutil @@ -50,7 +52,7 @@ def tearDown(self): shutil.rmtree(self.workdir) shutil.rmtree(self.workdir2) - @unittest.skip('timestamp isnt propagaged correctly') + @unittest.skip("timestamp isn't propagated correctly") def testFullUpdate(self): original_modify_stamp = 1 new_modify_stamp = 2 @@ -58,8 +60,8 @@ def testFullUpdate(self): # Construct a fake source. def GetFile(map_name, dst_file, current_file, location): print(("GetFile: %s" % dst_file)) - f = open(dst_file, 'w') - f.write('root:x:0:0:root:/root:/bin/bash\n') + f = open(dst_file, "w") + f.write("root:x:0:0:root:/root:/bin/bash\n") f.close() os.utime(dst_file, (1, 2)) os.system("ls -al %s" % dst_file) @@ -67,75 +69,71 @@ def GetFile(map_name, dst_file, current_file, location): dst_file = mox.Value() source_mock = self.mox.CreateMock(source.FileSource) - source_mock.GetFile(config.MAP_PASSWORD, - mox.Remember(dst_file), - current_file=mox.IgnoreArg(), - location=mox.IgnoreArg()).WithSideEffects( - GetFile).AndReturn(dst_file) + source_mock.GetFile( + config.MAP_PASSWORD, + mox.Remember(dst_file), + current_file=mox.IgnoreArg(), + location=mox.IgnoreArg(), + ).WithSideEffects(GetFile).AndReturn(dst_file)
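# A self-contained sketch (hypothetical mock target, not part of this
# change) of the mox record/replay cycle these tests rely on: calls
# recorded before ReplayAll() set up expectations that the code under
# test must then match, and MoxTestBase verifies them all on tearDown.

from mox3 import mox


class RecordReplayExample(mox.MoxTestBase):
    def testSideEffectAndReturn(self):
        written = []
        fake = self.mox.CreateMockAnything()
        # Record: one expected call, with a side effect and a return value.
        fake.GetFile("passwd", mox.IgnoreArg()).WithSideEffects(
            lambda map_name, dst: written.append(dst)
        ).AndReturn("/tmp/out")
        self.mox.ReplayAll()
        # Replay: the matching call runs the side effect and returns the
        # recorded value.
        self.assertEqual("/tmp/out", fake.GetFile("passwd", "/tmp/x"))
        self.assertEqual(["/tmp/x"], written)

# Construct the cache.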
- cache = files.FilesPasswdMapHandler({'dir': self.workdir2}) - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + cache = files.FilesPasswdMapHandler({"dir": self.workdir2}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) password_map = passwd.PasswdMap() password_map.SetModifyTimestamp(new_modify_stamp) password_map.Add(map_entry) cache.Write(password_map) - updater = files_updater.FileMapUpdater(config.MAP_PASSWORD, - self.workdir, { - 'name': 'files', - 'dir': self.workdir2 - }) + updater = files_updater.FileMapUpdater( + config.MAP_PASSWORD, self.workdir, {"name": "files", "dir": self.workdir2} + ) updater.WriteModifyTimestamp(original_modify_stamp) self.mox.ReplayAll() self.assertEqual( 0, - updater.UpdateCacheFromSource(cache, - source_mock, - force_write=False, - location=None)) + updater.UpdateCacheFromSource( + cache, source_mock, force_write=False, location=None + ), + ) self.assertEqual(new_modify_stamp, updater.GetModifyTimestamp()) self.assertNotEqual(None, updater.GetUpdateTimestamp()) - @unittest.skip('source map empty during full update') + @unittest.skip("source map empty during full update") def testFullUpdateOnEmptyCache(self): """A full update as above, but the initial cache is empty.""" original_modify_stamp = 1 new_modify_stamp = 2 # Construct an updater - self.updater = files_updater.FileMapUpdater(config.MAP_PASSWORD, - self.workdir, { - 'name': 'files', - 'dir': self.workdir2 - }) + self.updater = files_updater.FileMapUpdater( + config.MAP_PASSWORD, self.workdir, {"name": "files", "dir": self.workdir2} + ) self.updater.WriteModifyTimestamp(original_modify_stamp) # Construct a cache - cache = files.FilesPasswdMapHandler({'dir': self.workdir2}) + cache = files.FilesPasswdMapHandler({"dir": self.workdir2}) def GetFileEffects(map_name, dst_file, current_file, location): - f = open(dst_file, 'w') - f.write('root:x:0:0:root:/root:/bin/bash\n') + f = open(dst_file, "w") + f.write("root:x:0:0:root:/root:/bin/bash\n") f.close() os.utime(dst_file, (1, 2)) return dst_file source_mock = self.mox.CreateMock(source.FileSource) - source_mock.GetFile(config.MAP_PASSWORD, - mox.IgnoreArg(), - mox.IgnoreArg(), - location=None).WithSideEffects(GetFileEffects) + source_mock.GetFile( + config.MAP_PASSWORD, mox.IgnoreArg(), mox.IgnoreArg(), location=None + ).WithSideEffects(GetFileEffects) - #source_mock = MockSource() + # source_mock = MockSource() self.assertEqual( 0, - self.updater.UpdateCacheFromSource(cache, - source_mock, - force_write=False, - location=None)) + self.updater.UpdateCacheFromSource( + cache, source_mock, force_write=False, location=None + ), + ) self.assertEqual(new_modify_stamp, self.updater.GetModifyTimestamp()) self.assertNotEqual(None, self.updater.GetUpdateTimestamp()) @@ -144,33 +142,35 @@ def testFullUpdateOnEmptySource(self): original_modify_stamp = 1 new_modify_stamp = 2 # Construct an updater - self.updater = files_updater.FileMapUpdater(config.MAP_PASSWORD, - self.workdir, { - 'name': 'files', - 'dir': self.workdir2 - }) + self.updater = files_updater.FileMapUpdater( + config.MAP_PASSWORD, self.workdir, {"name": "files", "dir": self.workdir2} + ) self.updater.WriteModifyTimestamp(original_modify_stamp) # Construct a cache - cache = files.FilesPasswdMapHandler({'dir': self.workdir2}) - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + cache = files.FilesPasswdMapHandler({"dir": self.workdir2}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) password_map 
= passwd.PasswdMap() password_map.SetModifyTimestamp(new_modify_stamp) password_map.Add(map_entry) cache.Write(password_map) source_mock = self.mox.CreateMock(source.FileSource) - source_mock.GetFile(config.MAP_PASSWORD, - mox.IgnoreArg(), - current_file=mox.IgnoreArg(), - location=None).AndReturn(None) + source_mock.GetFile( + config.MAP_PASSWORD, + mox.IgnoreArg(), + current_file=mox.IgnoreArg(), + location=None, + ).AndReturn(None) self.mox.ReplayAll() - self.assertRaises(error.EmptyMap, - self.updater.UpdateCacheFromSource, - cache, - source_mock, - force_write=False, - location=None) + self.assertRaises( + error.EmptyMap, + self.updater.UpdateCacheFromSource, + cache, + source_mock, + force_write=False, + location=None, + ) self.assertNotEqual(new_modify_stamp, self.updater.GetModifyTimestamp()) self.assertEqual(None, self.updater.GetUpdateTimestamp()) @@ -179,39 +179,38 @@ def testFullUpdateOnEmptySourceForceWrite(self): original_modify_stamp = 1 new_modify_stamp = 2 # Construct an updater - self.updater = files_updater.FileMapUpdater(config.MAP_PASSWORD, - self.workdir, { - 'name': 'files', - 'dir': self.workdir2 - }) + self.updater = files_updater.FileMapUpdater( + config.MAP_PASSWORD, self.workdir, {"name": "files", "dir": self.workdir2} + ) self.updater.WriteModifyTimestamp(original_modify_stamp) # Construct a cache - cache = files.FilesPasswdMapHandler({'dir': self.workdir2}) - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + cache = files.FilesPasswdMapHandler({"dir": self.workdir2}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) password_map = passwd.PasswdMap() password_map.SetModifyTimestamp(new_modify_stamp) password_map.Add(map_entry) cache.Write(password_map) source_mock = self.mox.CreateMock(source.FileSource) - source_mock.GetFile(config.MAP_PASSWORD, - mox.IgnoreArg(), - current_file=mox.IgnoreArg(), - location=None).AndReturn(None) + source_mock.GetFile( + config.MAP_PASSWORD, + mox.IgnoreArg(), + current_file=mox.IgnoreArg(), + location=None, + ).AndReturn(None) self.mox.ReplayAll() self.assertEqual( 0, - self.updater.UpdateCacheFromSource(cache, - source_mock, - force_write=True, - location=None)) - self.assertNotEqual(original_modify_stamp, - self.updater.GetModifyTimestamp()) + self.updater.UpdateCacheFromSource( + cache, source_mock, force_write=True, location=None + ), + ) + self.assertNotEqual(original_modify_stamp, self.updater.GetModifyTimestamp()) self.assertNotEqual(None, self.updater.GetUpdateTimestamp()) -@unittest.skip('disabled') +@unittest.skip("disabled") class AutomountUpdaterTest(mox.MoxTestBase): """Unit tests for FileAutomountUpdater class.""" @@ -225,84 +224,88 @@ def tearDown(self): def testInit(self): """An automount object correctly sets map-specific attributes.""" - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}) + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, {} + ) self.assertEqual(updater.local_master, False) - conf = {files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER: 'yes'} - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, conf) + conf = {files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER: "yes"} + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, conf + ) self.assertEqual(updater.local_master, True) - conf = {files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER: 'no'} - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - 
self.workdir, conf) + conf = {files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER: "no"} + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, conf + ) self.assertEqual(updater.local_master, False) def testUpdate(self): """An update gets a master map and updates each entry.""" map_entry1 = automount.AutomountMapEntry() map_entry2 = automount.AutomountMapEntry() - map_entry1.key = '/home' - map_entry2.key = '/auto' - map_entry1.location = 'ou=auto.home,ou=automounts' - map_entry2.location = 'ou=auto.auto,ou=automounts' + map_entry1.key = "/home" + map_entry2.key = "/auto" + map_entry1.location = "ou=auto.home,ou=automounts" + map_entry2.location = "ou=auto.auto,ou=automounts" master_map = automount.AutomountMap([map_entry1, map_entry2]) source_mock = self.mox.CreateMock(zsyncsource.ZSyncSource) - source_mock.GetAutomountMasterFile( - mox.IgnoreArg()).AndReturn(master_map) + source_mock.GetAutomountMasterFile(mox.IgnoreArg()).AndReturn(master_map) # the auto.home cache cache_mock1 = self.mox.CreateMock(files.FilesCache) cache_mock1.GetCacheFilename().AndReturn(None) - cache_mock1.GetMapLocation().AndReturn('/etc/auto.home') + cache_mock1.GetMapLocation().AndReturn("/etc/auto.home") # the auto.auto cache cache_mock2 = self.mox.CreateMock(files.FilesCache) - cache_mock2.GetMapLocation().AndReturn('/etc/auto.auto') + cache_mock2.GetMapLocation().AndReturn("/etc/auto.auto") cache_mock2.GetCacheFilename().AndReturn(None) # the auto.master cache cache_mock3 = self.mox.CreateMock(files.FilesCache) cache_mock3.GetMap().AndReturn(master_map) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(mox.IgnoreArg(), mox.IgnoreArg(), - None).AndReturn(cache_mock3) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create(mox.IgnoreArg(), mox.IgnoreArg(), None).AndReturn( + cache_mock3 + ) cache_factory.Create( - mox.IgnoreArg(), mox.IgnoreArg(), - automount_mountpoint='/auto').AndReturn(cache_mock2) + mox.IgnoreArg(), mox.IgnoreArg(), automount_mountpoint="/auto" + ).AndReturn(cache_mock2) cache_factory.Create( - mox.IgnoreArg(), mox.IgnoreArg(), - automount_mountpoint='/home').AndReturn(cache_mock1) + mox.IgnoreArg(), mox.IgnoreArg(), automount_mountpoint="/home" + ).AndReturn(cache_mock1) self.mox.ReplayAll() - options = {'name': 'files', 'dir': self.workdir} - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, options) + options = {"name": "files", "dir": self.workdir} + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, options + ) updater.UpdateFromSource(source_mock) - self.assertEqual(map_entry1.location, '/etc/auto.home') - self.assertEqual(map_entry2.location, '/etc/auto.auto') + self.assertEqual(map_entry1.location, "/etc/auto.home") + self.assertEqual(map_entry2.location, "/etc/auto.auto") def testUpdateNoMaster(self): """An update skips updating the master map, and approprate sub maps.""" source_entry1 = automount.AutomountMapEntry() source_entry2 = automount.AutomountMapEntry() - source_entry1.key = '/home' - source_entry2.key = '/auto' - source_entry1.location = 'ou=auto.home,ou=automounts' - source_entry2.location = 'ou=auto.auto,ou=automounts' + source_entry1.key = "/home" + source_entry2.key = "/auto" + source_entry1.location = "ou=auto.home,ou=automounts" + source_entry2.location = "ou=auto.auto,ou=automounts" source_master = automount.AutomountMap([source_entry1, source_entry2]) local_entry1 = automount.AutomountMapEntry() local_entry2 = 
automount.AutomountMapEntry() - local_entry1.key = '/home' - local_entry2.key = '/auto' - local_entry1.location = '/etc/auto.home' - local_entry2.location = '/etc/auto.null' + local_entry1.key = "/home" + local_entry2.key = "/auto" + local_entry1.location = "/etc/auto.home" + local_entry2.location = "/etc/auto.null" local_master = automount.AutomountMap([local_entry1, local_entry2]) source_mock = self.mock() invocation = source_mock.expects(pmock.at_least_once()) @@ -313,37 +316,33 @@ def testUpdateNoMaster(self): cache_mock1 = self.mock() # GetMapLocation() is called, and set to the master map map_entry invocation = cache_mock1.expects(pmock.at_least_once()).GetMapLocation() - invocation.will(pmock.return_value('/etc/auto.home')) + invocation.will(pmock.return_value("/etc/auto.home")) # we should get called inside the DummyUpdater - cache_mock1.expects( - pmock.at_least_once())._CalledUpdateCacheFromSource() + cache_mock1.expects(pmock.at_least_once())._CalledUpdateCacheFromSource() # the auto.auto cache cache_mock2 = self.mock() # GetMapLocation() is called, and set to the master map map_entry invocation = cache_mock2.expects(pmock.at_least_once()).GetMapLocation() - invocation.will(pmock.return_value('/etc/auto.auto')) + invocation.will(pmock.return_value("/etc/auto.auto")) invocation = cache_mock2.expects( - pmock.at_least_once())._CalledUpdateCacheFromSource() + pmock.at_least_once() + )._CalledUpdateCacheFromSource() # the auto.master cache, which should not be written to cache_mock3 = self.mock() invocation = cache_mock3.expects(pmock.once()) - invocation = invocation.method('GetMap') + invocation = invocation.method("GetMap") invocation.will(pmock.return_value(local_master)) invocation = cache_mock3.expects(pmock.once()) - invocation = invocation.method('GetMap') + invocation = invocation.method("GetMap") invocation.will(pmock.return_value(local_master)) - cache_mocks = { - '/home': cache_mock1, - '/auto': cache_mock2, - None: cache_mock3 - } + cache_mocks = {"/home": cache_mock1, "/auto": cache_mock2, None: cache_mock3} # Create needs to return our mock_caches - def DummyCreate(unused_cache_options, - unused_map_name, - automount_mountpoint=None): + def DummyCreate( + unused_cache_options, unused_map_name, automount_mountpoint=None + ): # the order of the master_map iterable is not predictable, so we use the # automount_mountpoint as the key to return the right one. return cache_mocks[automount_mountpoint] @@ -352,9 +351,10 @@ def DummyCreate(unused_cache_options, cache_factory.Create = DummyCreate skip = files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER - options = {skip: 'yes', 'dir': self.workdir} - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, options) + options = {skip: "yes", "dir": self.workdir} + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, options + ) updater.UpdateFromSource(source_mock) cache_factory.Create = original_create @@ -371,9 +371,9 @@ def testUpdateCatchesMissingMaster(self): invocation.will(pmock.raise_exception(error.CacheNotFound)) # Create needs to return our mock_cache - def DummyCreate(unused_cache_options, - unused_map_name, - automount_mountpoint=None): + def DummyCreate( + unused_cache_options, unused_map_name, automount_mountpoint=None + ): # the order of the master_map iterable is not predictable, so we use the # automount_mountpoint as the key to return the right one. 
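# (Unlike testUpdateNoMaster above, only the master-map cache is ever
# requested here, so a single mock suffices and the mountpoint key is
# ignored.)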
return cache_mock @@ -382,9 +382,10 @@ def DummyCreate(unused_cache_options, cache_factory.Create = DummyCreate skip = files_updater.FileAutomountUpdater.OPT_LOCAL_MASTER - options = {skip: 'yes', 'dir': self.workdir} - updater = files_updater.FileAutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, options) + options = {skip: "yes", "dir": self.workdir} + updater = files_updater.FileAutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, options + ) return_value = updater.UpdateFromSource(source_mock) @@ -393,5 +394,5 @@ def DummyCreate(unused_cache_options, cache_factory.Create = original_create -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/update/map_updater.py b/nss_cache/update/map_updater.py index 6a65b678..43c0f283 100644 --- a/nss_cache/update/map_updater.py +++ b/nss_cache/update/map_updater.py @@ -22,8 +22,7 @@ AutomountMapUpdater: Class used for updating automount map caches. """ -__author__ = ('vasilios@google.com (V Hoffman)', - 'jaq@google.com (Jamie Wilkinson)') +__author__ = ("vasilios@google.com (V Hoffman)", "jaq@google.com (Jamie Wilkinson)") from nss_cache import error from nss_cache.caches import cache_factory @@ -33,12 +32,9 @@ class MapUpdater(updater.Updater): """Updates simple maps like passwd, group, shadow, and netgroup.""" - def UpdateCacheFromSource(self, - cache, - source, - incremental=False, - force_write=False, - location=None): + def UpdateCacheFromSource( + self, cache, source, incremental=False, force_write=False, location=None + ): """Update a single cache, from a given source. Args: @@ -58,19 +54,17 @@ def UpdateCacheFromSource(self, timestamp = self.GetModifyTimestamp() if timestamp is None and incremental is True: - self.log.info( - 'Missing previous timestamp, defaulting to a full sync.') + self.log.info("Missing previous timestamp, defaulting to a full sync.") incremental = False if incremental: - source_map = source.GetMap(self.map_name, - since=timestamp, - location=location) + source_map = source.GetMap( + self.map_name, since=timestamp, location=location + ) try: return_val += self._IncrementalUpdateFromMap(cache, source_map) except (error.CacheNotFound, error.EmptyMap): - self.log.warning( - 'Local cache is invalid, faulting to a full sync.') + self.log.warning("Local cache is invalid, faulting to a full sync.") incremental = False # We don't use an if/else, because we give the incremental a chance to @@ -97,10 +91,10 @@ def _IncrementalUpdateFromMap(self, cache, new_map): return_val = 0 if len(new_map) == 0: - self.log.info('Empty map on incremental update, skipping') + self.log.info("Empty map on incremental update, skipping") return 0 - self.log.debug('loading cache map, may be slow for large maps.') + self.log.debug("loading cache map, may be slow for large maps.") cache_map = cache.GetMap() if len(cache_map) == 0: @@ -112,7 +106,7 @@ def _IncrementalUpdateFromMap(self, cache, new_map): self.WriteModifyTimestamp(new_map.GetModifyTimestamp()) else: self.WriteModifyTimestamp(new_map.GetModifyTimestamp()) - self.log.info('Nothing new merged, returning') + self.log.info("Nothing new merged, returning") # We did an update, even if nothing was written, so write our # update timestamp unless there is an error. @@ -140,8 +134,9 @@ def FullUpdateFromMap(self, cache, new_map, force_write=False): if len(new_map) == 0 and not force_write: raise error.EmptyMap( - 'Source map empty during full update, aborting. ' - 'Use --force-write to override.') + "Source map empty during full update, aborting. 
" + "Use --force-write to override." + ) return_val = cache.WriteMap(map_data=new_map, force_write=force_write) @@ -165,13 +160,11 @@ class AutomountUpdater(updater.Updater): """ # automount-specific options - OPT_LOCAL_MASTER = 'local_automount_master' + OPT_LOCAL_MASTER = "local_automount_master" - def __init__(self, - map_name, - timestamp_dir, - cache_options, - automount_mountpoint=None): + def __init__( + self, map_name, timestamp_dir, cache_options, automount_mountpoint=None + ): """Initialize automount-specific updater options. Args: @@ -180,12 +173,12 @@ def __init__(self, cache_options: A dict containing the options for any caches we create. automount_mountpoint: An optional string containing automount path info. """ - super(AutomountUpdater, - self).__init__(map_name, timestamp_dir, cache_options, - automount_mountpoint) + super(AutomountUpdater, self).__init__( + map_name, timestamp_dir, cache_options, automount_mountpoint + ) self.local_master = False if self.OPT_LOCAL_MASTER in cache_options: - if cache_options[self.OPT_LOCAL_MASTER] == 'yes': + if cache_options[self.OPT_LOCAL_MASTER] == "yes": self.local_master = True def UpdateFromSource(self, source, incremental=True, force_write=False): @@ -227,33 +220,36 @@ def UpdateFromSource(self, source, incremental=True, force_write=False): """ return_val = 0 - self.log.info('Retrieving automount master map.') + self.log.info("Retrieving automount master map.") master_map = source.GetAutomountMasterMap() if self.local_master: - self.log.info('Using local master map to determine maps to update.') + self.log.info("Using local master map to determine maps to update.") # we need the local map to determine which of the other maps to update - cache = cache_factory.Create(self.cache_options, - self.map_name, - automount_mountpoint=None) + cache = cache_factory.Create( + self.cache_options, self.map_name, automount_mountpoint=None + ) try: local_master = cache.GetMap() except error.CacheNotFound: - self.log.warning('Local master map specified but no map found! ' - 'No maps will update.') + self.log.warning( + "Local master map specified but no map found! " + "No maps will update." + ) return return_val + 1 # update specific maps, e.g. auto.home and auto.auto for map_entry in master_map: source_location = map_entry.location # e.g. ou=auto.auto in ldap automount_mountpoint = map_entry.key # e.g. /auto mountpoint - self.log.debug('looking at %s mount.', automount_mountpoint) + self.log.debug("looking at %s mount.", automount_mountpoint) # create the cache to update cache = cache_factory.Create( self.cache_options, self.map_name, - automount_mountpoint=automount_mountpoint) + automount_mountpoint=automount_mountpoint, + ) # update the master map with the location of the map in the cache # e.g. /etc/auto.auto replaces ou=auto.auto @@ -262,26 +258,29 @@ def UpdateFromSource(self, source, incremental=True, force_write=False): # if configured to use the local master map, skip any not defined there if self.local_master: if map_entry not in local_master: - self.log.debug('skipping %s, not in %s', map_entry, - local_master) + self.log.debug("skipping %s, not in %s", map_entry, local_master) continue - self.log.info('Updating %s mount.', map_entry.key) + self.log.info("Updating %s mount.", map_entry.key) # update this map (e.g. 
/etc/auto.auto) - update_obj = MapUpdater(self.map_name, - self.timestamp_dir, - self.cache_options, - automount_mountpoint=automount_mountpoint) + update_obj = MapUpdater( + self.map_name, + self.timestamp_dir, + self.cache_options, + automount_mountpoint=automount_mountpoint, + ) return_val += update_obj.UpdateCacheFromSource( - cache, source, incremental, force_write, source_location) + cache, source, incremental, force_write, source_location + ) # with sub-maps updated, write modified master map to disk if configured to if not self.local_master: # automount_mountpoint=None defaults to master - cache = cache_factory.Create(self.cache_options, - self.map_name, - automount_mountpoint=None) - update_obj = MapUpdater(self.map_name, self.timestamp_dir, - self.cache_options) + cache = cache_factory.Create( + self.cache_options, self.map_name, automount_mountpoint=None + ) + update_obj = MapUpdater( + self.map_name, self.timestamp_dir, self.cache_options + ) return_val += update_obj.FullUpdateFromMap(cache, master_map) return return_val diff --git a/nss_cache/update/map_updater_test.py b/nss_cache/update/map_updater_test.py index 5642fe47..039cfa98 100644 --- a/nss_cache/update/map_updater_test.py +++ b/nss_cache/update/map_updater_test.py @@ -15,8 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/map_updater.py.""" -__author__ = ('vasilios@google.com (V Hoffman)', - 'jaq@google.com (Jamie Wilkinson)') +__author__ = ("vasilios@google.com (V Hoffman)", "jaq@google.com (Jamie Wilkinson)") import os import shutil @@ -58,29 +57,27 @@ def testFullUpdate(self): updater = map_updater.MapUpdater(config.MAP_PASSWORD, self.workdir, {}) updater.WriteModifyTimestamp(original_modify_stamp) - map_entry = passwd.PasswdMapEntry({'name': 'foo', 'uid': 10, 'gid': 10}) + map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) password_map = passwd.PasswdMap([map_entry]) password_map.SetModifyTimestamp(new_modify_stamp) cache_mock = self.mox.CreateMock(files.FilesCache) - cache_mock.WriteMap(map_data=password_map, - force_write=False).AndReturn(0) + cache_mock.WriteMap(map_data=password_map, force_write=False).AndReturn(0) source_mock = self.mox.CreateMock(source.Source) - source_mock.GetMap(config.MAP_PASSWORD, - location=None).AndReturn(password_map) + source_mock.GetMap(config.MAP_PASSWORD, location=None).AndReturn(password_map) self.mox.ReplayAll() self.assertEqual( 0, - updater.UpdateCacheFromSource(cache_mock, source_mock, False, False, - None)) + updater.UpdateCacheFromSource(cache_mock, source_mock, False, False, None), + ) self.assertEqual(updater.GetModifyTimestamp(), new_modify_stamp) self.assertNotEqual(updater.GetUpdateTimestamp(), None) def testFullUpdateWithEmptySourceMap(self): - """A full update reads the source, which returns an empty map. + """A full update reads the source, which returns an empty map. 
Need to provide force write flag to proceed.""" original_modify_stamp = 1 new_modify_stamp = 2 @@ -92,19 +89,16 @@ def testFullUpdateWithEmptySourceMap(self): password_map.SetModifyTimestamp(new_modify_stamp) cache_mock = self.mox.CreateMock(files.FilesCache) - cache_mock.WriteMap(map_data=password_map, - force_write=True).AndReturn(0) + cache_mock.WriteMap(map_data=password_map, force_write=True).AndReturn(0) source_mock = self.mox.CreateMock(source.Source) - source_mock.GetMap(config.MAP_PASSWORD, - location=None).AndReturn(password_map) + source_mock.GetMap(config.MAP_PASSWORD, location=None).AndReturn(password_map) self.mox.ReplayAll() self.assertEqual( - 0, - updater.UpdateCacheFromSource(cache_mock, source_mock, False, True, - None)) + 0, updater.UpdateCacheFromSource(cache_mock, source_mock, False, True, None) + ) self.assertEqual(updater.GetModifyTimestamp(), new_modify_stamp) self.assertNotEqual(updater.GetUpdateTimestamp(), None) @@ -120,16 +114,12 @@ def compare_function(map_object): original_modify_stamp = 1 new_modify_stamp = 2 - updater = map_updater.MapUpdater(config.MAP_PASSWORD, - self.workdir, {}, - can_do_incremental=True) + updater = map_updater.MapUpdater( + config.MAP_PASSWORD, self.workdir, {}, can_do_incremental=True + ) updater.WriteModifyTimestamp(original_modify_stamp) - cache_map_entry = passwd.PasswdMapEntry({ - 'name': 'bar', - 'uid': 20, - 'gid': 20 - }) + cache_map_entry = passwd.PasswdMapEntry({"name": "bar", "uid": 20, "gid": 20}) cache_map = passwd.PasswdMap([cache_map_entry]) cache_map.SetModifyTimestamp(original_modify_stamp) @@ -137,28 +127,27 @@ def compare_function(map_object): cache_mock.GetMap().AndReturn(cache_map) cache_mock.WriteMap(map_data=mox.Func(compare_function)).AndReturn(0) - source_map_entry = passwd.PasswdMapEntry({ - 'name': 'foo', - 'uid': 10, - 'gid': 10 - }) + source_map_entry = passwd.PasswdMapEntry({"name": "foo", "uid": 10, "gid": 10}) source_map = passwd.PasswdMap([source_map_entry]) source_map.SetModifyTimestamp(new_modify_stamp) source_mock = self.mox.CreateMock(source.Source) - source_mock.GetMap(config.MAP_PASSWORD, - location=None, - since=original_modify_stamp).AndReturn(source_map) + source_mock.GetMap( + config.MAP_PASSWORD, location=None, since=original_modify_stamp + ).AndReturn(source_map) self.mox.ReplayAll() self.assertEqual( 0, - updater.UpdateCacheFromSource(cache_mock, - source_mock, - incremental=True, - force_write=False, - location=None)) + updater.UpdateCacheFromSource( + cache_mock, + source_mock, + incremental=True, + force_write=False, + location=None, + ), + ) self.assertEqual(updater.GetModifyTimestamp(), new_modify_stamp) self.assertNotEqual(updater.GetUpdateTimestamp(), None) @@ -171,35 +160,33 @@ def testFullUpdateOnMissingCache(self): source_mock = self.mox.CreateMock(source.Source) # Try incremental first. - source_mock.GetMap(config.MAP_PASSWORD, - location=None, - since=original_modify_stamp).AndReturn('first map') + source_mock.GetMap( + config.MAP_PASSWORD, location=None, since=original_modify_stamp + ).AndReturn("first map") # Try full second. 
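# Sketch of the fallback being exercised (hypothetical helper parameters;
# the real logic is MapUpdater.UpdateCacheFromSource, shown earlier):
def _sketch_update(cache, source, incremental, timestamp,
                   incremental_update, full_update):
    if incremental and timestamp is None:
        incremental = False  # no previous timestamp: default to a full sync
    if incremental:
        try:
            return incremental_update(cache, source.GetMap(since=timestamp))
        except (error.CacheNotFound, error.EmptyMap):
            incremental = False  # invalid local cache: fault to a full sync
    return full_update(cache, source.GetMap())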
- source_mock.GetMap(config.MAP_PASSWORD, - location=None).AndReturn('second map') + source_mock.GetMap(config.MAP_PASSWORD, location=None).AndReturn("second map") - updater = map_updater.MapUpdater(config.MAP_PASSWORD, - self.workdir, {}, - can_do_incremental=True) - self.mox.StubOutWithMock(updater, 'GetModifyTimestamp') + updater = map_updater.MapUpdater( + config.MAP_PASSWORD, self.workdir, {}, can_do_incremental=True + ) + self.mox.StubOutWithMock(updater, "GetModifyTimestamp") updater.GetModifyTimestamp().AndReturn(original_modify_stamp) - self.mox.StubOutWithMock(updater, '_IncrementalUpdateFromMap') + self.mox.StubOutWithMock(updater, "_IncrementalUpdateFromMap") # force a cache not found on incremental - updater._IncrementalUpdateFromMap('cache', 'first map').AndRaise( - error.CacheNotFound) - self.mox.StubOutWithMock(updater, 'FullUpdateFromMap') - updater.FullUpdateFromMap(mox.IgnoreArg(), 'second map', - False).AndReturn(0) + updater._IncrementalUpdateFromMap("cache", "first map").AndRaise( + error.CacheNotFound + ) + self.mox.StubOutWithMock(updater, "FullUpdateFromMap") + updater.FullUpdateFromMap(mox.IgnoreArg(), "second map", False).AndReturn(0) self.mox.ReplayAll() self.assertEqual( 0, - updater.UpdateCacheFromSource('cache', - source_mock, - incremental=True, - force_write=False, - location=None)) + updater.UpdateCacheFromSource( + "cache", source_mock, incremental=True, force_write=False, location=None + ), + ) def testFullUpdateOnMissingTimestamp(self): """We fault to a full update if our modify timestamp is missing.""" @@ -208,18 +195,15 @@ def testFullUpdateOnMissingTimestamp(self): # We do not call WriteModifyTimestamp() so we force a full sync. source_mock = self.mox.CreateMock(source.Source) - source_mock.GetMap(config.MAP_PASSWORD, - location=None).AndReturn('second map') + source_mock.GetMap(config.MAP_PASSWORD, location=None).AndReturn("second map") updater = map_updater.MapUpdater(config.MAP_PASSWORD, self.workdir, {}) - self.mox.StubOutWithMock(updater, 'FullUpdateFromMap') - updater.FullUpdateFromMap(mox.IgnoreArg(), 'second map', - False).AndReturn(0) + self.mox.StubOutWithMock(updater, "FullUpdateFromMap") + updater.FullUpdateFromMap(mox.IgnoreArg(), "second map", False).AndReturn(0) self.mox.ReplayAll() self.assertEqual( - 0, - updater.UpdateCacheFromSource('cache', source_mock, True, False, - None)) + 0, updater.UpdateCacheFromSource("cache", source_mock, True, False, None) + ) class MapAutomountUpdaterTest(mox.MoxTestBase): @@ -235,28 +219,25 @@ def tearDown(self): def testInit(self): """An automount object correctly sets map-specific attributes.""" - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}) + updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, self.workdir, {}) self.assertEqual(updater.local_master, False) - conf = {map_updater.AutomountUpdater.OPT_LOCAL_MASTER: 'yes'} - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, conf) + conf = {map_updater.AutomountUpdater.OPT_LOCAL_MASTER: "yes"} + updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, self.workdir, conf) self.assertEqual(updater.local_master, True) - conf = {map_updater.AutomountUpdater.OPT_LOCAL_MASTER: 'no'} - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, conf) + conf = {map_updater.AutomountUpdater.OPT_LOCAL_MASTER: "no"} + updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, self.workdir, conf) self.assertEqual(updater.local_master, False) def testUpdate(self): """An 
update gets a master map and updates each entry.""" map_entry1 = automount.AutomountMapEntry() map_entry2 = automount.AutomountMapEntry() - map_entry1.key = '/home' - map_entry2.key = '/auto' - map_entry1.location = 'ou=auto.home,ou=automounts' - map_entry2.location = 'ou=auto.auto,ou=automounts' + map_entry1.key = "/home" + map_entry2.key = "/auto" + map_entry1.location = "ou=auto.home,ou=automounts" + map_entry2.location = "ou=auto.auto,ou=automounts" master_map = automount.AutomountMap([map_entry1, map_entry2]) source_mock = self.mox.CreateMock(source.Source) @@ -266,70 +247,68 @@ def testUpdate(self): # the auto.home cache cache_home = self.mox.CreateMock(caches.Cache) # GetMapLocation() is called, and set to the master map map_entry - cache_home.GetMapLocation().AndReturn('/etc/auto.home') + cache_home.GetMapLocation().AndReturn("/etc/auto.home") # the auto.auto cache cache_auto = self.mox.CreateMock(caches.Cache) # GetMapLocation() is called, and set to the master map map_entry - cache_auto.GetMapLocation().AndReturn('/etc/auto.auto') + cache_auto.GetMapLocation().AndReturn("/etc/auto.auto") # the auto.master cache cache_master = self.mox.CreateMock(caches.Cache) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(mox.IgnoreArg(), - 'automount', - automount_mountpoint='/home').AndReturn(cache_home) - cache_factory.Create(mox.IgnoreArg(), - 'automount', - automount_mountpoint='/auto').AndReturn(cache_auto) - cache_factory.Create(mox.IgnoreArg(), - 'automount', - automount_mountpoint=None).AndReturn(cache_master) - - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}) - - self.mox.StubOutClassWithMocks(map_updater, 'MapUpdater') - updater_home = map_updater.MapUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}, - automount_mountpoint='/home') + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + mox.IgnoreArg(), "automount", automount_mountpoint="/home" + ).AndReturn(cache_home) + cache_factory.Create( + mox.IgnoreArg(), "automount", automount_mountpoint="/auto" + ).AndReturn(cache_auto) + cache_factory.Create( + mox.IgnoreArg(), "automount", automount_mountpoint=None + ).AndReturn(cache_master) + + updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, self.workdir, {}) + + self.mox.StubOutClassWithMocks(map_updater, "MapUpdater") + updater_home = map_updater.MapUpdater( + config.MAP_AUTOMOUNT, self.workdir, {}, automount_mountpoint="/home" + ) updater_home.UpdateCacheFromSource( - cache_home, source_mock, True, False, - 'ou=auto.home,ou=automounts').AndReturn(0) - updater_auto = map_updater.MapUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}, - automount_mountpoint='/auto') + cache_home, source_mock, True, False, "ou=auto.home,ou=automounts" + ).AndReturn(0) + updater_auto = map_updater.MapUpdater( + config.MAP_AUTOMOUNT, self.workdir, {}, automount_mountpoint="/auto" + ) updater_auto.UpdateCacheFromSource( - cache_auto, source_mock, True, False, - 'ou=auto.auto,ou=automounts').AndReturn(0) - updater_master = map_updater.MapUpdater(config.MAP_AUTOMOUNT, - self.workdir, {}) + cache_auto, source_mock, True, False, "ou=auto.auto,ou=automounts" + ).AndReturn(0) + updater_master = map_updater.MapUpdater(config.MAP_AUTOMOUNT, self.workdir, {}) updater_master.FullUpdateFromMap(cache_master, master_map).AndReturn(0) self.mox.ReplayAll() updater.UpdateFromSource(source_mock) - self.assertEqual(map_entry1.location, '/etc/auto.home') - self.assertEqual(map_entry2.location, '/etc/auto.auto') + 
self.assertEqual(map_entry1.location, "/etc/auto.home") + self.assertEqual(map_entry2.location, "/etc/auto.auto") def testUpdateNoMaster(self): """An update skips updating the master map, and approprate sub maps.""" source_entry1 = automount.AutomountMapEntry() source_entry2 = automount.AutomountMapEntry() - source_entry1.key = '/home' - source_entry2.key = '/auto' - source_entry1.location = 'ou=auto.home,ou=automounts' - source_entry2.location = 'ou=auto.auto,ou=automounts' + source_entry1.key = "/home" + source_entry2.key = "/auto" + source_entry1.location = "ou=auto.home,ou=automounts" + source_entry2.location = "ou=auto.auto,ou=automounts" source_master = automount.AutomountMap([source_entry1, source_entry2]) local_entry1 = automount.AutomountMapEntry() local_entry2 = automount.AutomountMapEntry() - local_entry1.key = '/home' - local_entry2.key = '/auto' - local_entry1.location = '/etc/auto.home' - local_entry2.location = '/etc/auto.null' + local_entry1.key = "/home" + local_entry2.key = "/auto" + local_entry1.location = "/etc/auto.home" + local_entry2.location = "/etc/auto.null" local_master = automount.AutomountMap([local_entry1, local_entry2]) source_mock = self.mox.CreateMock(source.Source) @@ -339,40 +318,43 @@ def testUpdateNoMaster(self): # the auto.home cache cache_home = self.mox.CreateMock(caches.Cache) # GetMapLocation() is called, and set to the master map map_entry - cache_home.GetMapLocation().AndReturn('/etc/auto.home') + cache_home.GetMapLocation().AndReturn("/etc/auto.home") # the auto.auto cache cache_auto = self.mox.CreateMock(caches.Cache) # GetMapLocation() is called, and set to the master map map_entry - cache_auto.GetMapLocation().AndReturn('/etc/auto.auto') + cache_auto.GetMapLocation().AndReturn("/etc/auto.auto") # the auto.master cache, which should not be written to cache_master = self.mox.CreateMock(caches.Cache) cache_master.GetMap().AndReturn(local_master) - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(mox.IgnoreArg(), - mox.IgnoreArg(), - automount_mountpoint=None).AndReturn(cache_master) - cache_factory.Create(mox.IgnoreArg(), - mox.IgnoreArg(), - automount_mountpoint='/home').AndReturn(cache_home) - cache_factory.Create(mox.IgnoreArg(), - mox.IgnoreArg(), - automount_mountpoint='/auto').AndReturn(cache_auto) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + mox.IgnoreArg(), mox.IgnoreArg(), automount_mountpoint=None + ).AndReturn(cache_master) + cache_factory.Create( + mox.IgnoreArg(), mox.IgnoreArg(), automount_mountpoint="/home" + ).AndReturn(cache_home) + cache_factory.Create( + mox.IgnoreArg(), mox.IgnoreArg(), automount_mountpoint="/auto" + ).AndReturn(cache_auto) skip = map_updater.AutomountUpdater.OPT_LOCAL_MASTER - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, {skip: 'yes'}) - - self.mox.StubOutClassWithMocks(map_updater, 'MapUpdater') - updater_home = map_updater.MapUpdater(config.MAP_AUTOMOUNT, - self.workdir, - {'local_automount_master': 'yes'}, - automount_mountpoint='/home') + updater = map_updater.AutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, {skip: "yes"} + ) + + self.mox.StubOutClassWithMocks(map_updater, "MapUpdater") + updater_home = map_updater.MapUpdater( + config.MAP_AUTOMOUNT, + self.workdir, + {"local_automount_master": "yes"}, + automount_mountpoint="/home", + ) updater_home.UpdateCacheFromSource( - cache_home, source_mock, True, False, - 'ou=auto.home,ou=automounts').AndReturn(0) + cache_home, source_mock, True, False, 
"ou=auto.home,ou=automounts" + ).AndReturn(0) self.mox.ReplayAll() @@ -380,7 +362,6 @@ def testUpdateNoMaster(self): class AutomountUpdaterMoxTest(mox.MoxTestBase): - def setUp(self): super(AutomountUpdaterMoxTest, self).setUp() self.workdir = tempfile.mkdtemp() @@ -403,22 +384,23 @@ def testUpdateCatchesMissingMaster(self): cache_mock.GetMap().AndRaise(error.CacheNotFound) skip = map_updater.AutomountUpdater.OPT_LOCAL_MASTER - cache_options = {skip: 'yes'} + cache_options = {skip: "yes"} - self.mox.StubOutWithMock(cache_factory, 'Create') - cache_factory.Create(cache_options, - 'automount', - automount_mountpoint=None).AndReturn(cache_mock) + self.mox.StubOutWithMock(cache_factory, "Create") + cache_factory.Create( + cache_options, "automount", automount_mountpoint=None + ).AndReturn(cache_mock) self.mox.ReplayAll() - updater = map_updater.AutomountUpdater(config.MAP_AUTOMOUNT, - self.workdir, cache_options) + updater = map_updater.AutomountUpdater( + config.MAP_AUTOMOUNT, self.workdir, cache_options + ) return_value = updater.UpdateFromSource(source_mock) self.assertEqual(return_value, 1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/update/updater.py b/nss_cache/update/updater.py index f02fd889..6f2e3f95 100644 --- a/nss_cache/update/updater.py +++ b/nss_cache/update/updater.py @@ -24,8 +24,7 @@ """ import errno -__author__ = ('vasilios@google.com (V Hoffman)', - 'jaq@google.com (Jamie Wilkinson)') +__author__ = ("vasilios@google.com (V Hoffman)", "jaq@google.com (Jamie Wilkinson)") import calendar import logging @@ -53,12 +52,14 @@ class Updater(object): update_file: A string with our last updated timestamp filename. """ - def __init__(self, - map_name, - timestamp_dir, - cache_options, - automount_mountpoint=None, - can_do_incremental=False): + def __init__( + self, + map_name, + timestamp_dir, + cache_options, + automount_mountpoint=None, + can_do_incremental=False, + ): """Construct an updater object. 
Args: @@ -82,15 +83,18 @@ def __init__(self, # Calculate our timestamp files if automount_mountpoint is None: - timestamp_prefix = '%s/timestamp-%s' % (timestamp_dir, map_name) + timestamp_prefix = "%s/timestamp-%s" % (timestamp_dir, map_name) else: # turn /auto into auto.auto, and /usr/local into /auto.usr_local - automount_mountpoint = automount_mountpoint.lstrip('/') - automount_mountpoint = automount_mountpoint.replace('/', '_') - timestamp_prefix = '%s/timestamp-%s-%s' % (timestamp_dir, map_name, - automount_mountpoint) - self.modify_file = '%s-modify' % timestamp_prefix - self.update_file = '%s-update' % timestamp_prefix + automount_mountpoint = automount_mountpoint.lstrip("/") + automount_mountpoint = automount_mountpoint.replace("/", "_") + timestamp_prefix = "%s/timestamp-%s-%s" % ( + timestamp_dir, + map_name, + automount_mountpoint, + ) + self.modify_file = "%s-modify" % timestamp_prefix + self.update_file = "%s-update" % timestamp_prefix # Timestamp info is cached here self.modify_time = None @@ -120,37 +124,42 @@ def _ReadTimestamp(self, filename): return None try: - timestamp_file = open(filename, 'r') + timestamp_file = open(filename, "r") timestamp_string = timestamp_file.read().strip() except IOError as e: - self.log.warning('error opening timestamp file: %s', e) + self.log.warning("error opening timestamp file: %s", e) timestamp_string = None else: timestamp_file.close() - self.log.debug('read timestamp %s from file %r', timestamp_string, - filename) + self.log.debug("read timestamp %s from file %r", timestamp_string, filename) if timestamp_string is not None: try: # Append UTC to force the timezone to parse the string in. timestamp = int( calendar.timegm( - time.strptime(timestamp_string + ' UTC', - '%Y-%m-%dT%H:%M:%SZ %Z'))) + time.strptime( + timestamp_string + " UTC", "%Y-%m-%dT%H:%M:%SZ %Z" + ) + ) + ) except ValueError as e: - self.log.error('cannot parse timestamp file %r: %s', filename, - e) + self.log.error("cannot parse timestamp file %r: %s", filename, e) timestamp = None else: timestamp = None now = self._GetCurrentTime() if timestamp and timestamp > now: - self.log.warning('timestamp %r from %r is in the future, now is %r', - timestamp_string, filename, now) + self.log.warning( + "timestamp %r from %r is in the future, now is %r", + timestamp_string, + filename, + now, + ) if timestamp - now >= 60 * 60: - self.log.info('Resetting timestamp to now.') + self.log.info("Resetting timestamp to now.") timestamp = now return timestamp @@ -179,24 +188,25 @@ def _WriteTimestamp(self, timestamp, filename): else: raise - (filedesc, temp_filename) = tempfile.mkstemp(prefix='nsscache-update-', - dir=self.timestamp_dir) - time_string = time.strftime('%Y-%m-%dT%H:%M:%SZ', - time.gmtime(timestamp)) + (filedesc, temp_filename) = tempfile.mkstemp( + prefix="nsscache-update-", dir=self.timestamp_dir + ) + time_string = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(timestamp)) try: - os.write(filedesc, b'%s\n' % time_string.encode()) + os.write(filedesc, b"%s\n" % time_string.encode()) os.fsync(filedesc) os.close(filedesc) except OSError: os.unlink(temp_filename) - self.log.warning('writing timestamp failed!') + self.log.warning("writing timestamp failed!") return False - os.chmod(temp_filename, - stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) + os.chmod( + temp_filename, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH + ) os.rename(temp_filename, filename) - self.log.debug('wrote timestamp %s to file %r', time_string, filename) + self.log.debug("wrote 
timestamp %s to file %r", time_string, filename) return True def GetUpdateTimestamp(self): @@ -278,8 +288,6 @@ def UpdateFromSource(self, source, incremental=True, force_write=False): # Create the single cache we write to cache = cache_factory.Create(self.cache_options, self.map_name) - return self.UpdateCacheFromSource(cache, - source, - incremental, - force_write, - location=None) + return self.UpdateCacheFromSource( + cache, source, incremental, force_write, location=None + ) diff --git a/nss_cache/update/updater_test.py b/nss_cache/update/updater_test.py index fd0edf70..2164642e 100644 --- a/nss_cache/update/updater_test.py +++ b/nss_cache/update/updater_test.py @@ -15,21 +15,20 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Unit tests for nss_cache/update/base.py.""" -__author__ = ('vasilios@google.com (V Hoffman)', - 'jaq@google.com (Jamie Wilkinson)') +__author__ = ("vasilios@google.com (V Hoffman)", "jaq@google.com (Jamie Wilkinson)") import os import shutil import tempfile import time import unittest -from mox3 import mox +from unittest import mock from nss_cache import config from nss_cache.update import updater -class TestUpdater(mox.MoxTestBase): +class TestUpdater(unittest.TestCase): """Unit tests for the Updater class.""" def setUp(self): @@ -56,13 +55,19 @@ def testTimestampDir(self): self.assertEqual( update_time, update_stamp, - msg=('retrieved a different update time than we stored: ' - 'Expected: %r, observed: %r' % (update_time, update_stamp))) + msg=( + "retrieved a different update time than we stored: " + "Expected: %r, observed: %r" % (update_time, update_stamp) + ), + ) self.assertEqual( modify_time, modify_stamp, - msg=('retrieved a different modify time than we stored: ' - 'Expected %r, observed: %r' % (modify_time, modify_stamp))) + msg=( + "retrieved a different modify time than we stored: " + "Expected %r, observed: %r" % (modify_time, modify_stamp) + ), + ) def testWriteWhenTimestampIsNone(self): update_obj = updater.Updater(config.MAP_PASSWORD, self.workdir, {}) @@ -76,16 +81,12 @@ def testTimestampDefaultsToNone(self): update_stamp = update_obj.GetUpdateTimestamp() modify_stamp = update_obj.GetModifyTimestamp() - self.assertEqual(None, - update_stamp, - msg='update time did not default to None') - self.assertEqual(None, - modify_stamp, - msg='modify time did not default to None') + self.assertEqual(None, update_stamp, msg="update time did not default to None") + self.assertEqual(None, modify_stamp, msg="modify time did not default to None") # touch a file, make it unreadable - update_file = open(update_obj.update_file, 'w') - modify_file = open(update_obj.modify_file, 'w') + update_file = open(update_obj.update_file, "w") + modify_file = open(update_obj.modify_file, "w") update_file.close() modify_file.close() os.chmod(update_obj.update_file, 0000) @@ -94,26 +95,27 @@ def testTimestampDefaultsToNone(self): update_stamp = update_obj.GetUpdateTimestamp() modify_stamp = update_obj.GetModifyTimestamp() - self.assertEqual(None, - update_stamp, - msg='unreadable update time did not default to None') - self.assertEqual(None, - modify_stamp, - msg='unreadable modify time did not default to None') + self.assertEqual( + None, update_stamp, msg="unreadable update time did not default to None" + ) + self.assertEqual( + None, modify_stamp, msg="unreadable modify time did not default to None" + ) def testTimestampInTheFuture(self): """Timestamps in the future are turned into now.""" update_obj = updater.Updater(config.MAP_PASSWORD, 
self.workdir, {}) expected_time = 1 update_time = 3601 - update_file = open(update_obj.update_file, 'w') + update_file = open(update_obj.update_file, "w") update_obj.WriteUpdateTimestamp(update_time) update_file.close() - self.mox.StubOutWithMock(update_obj, '_GetCurrentTime') - update_obj._GetCurrentTime().AndReturn(expected_time) - self.mox.ReplayAll() - self.assertEqual(expected_time, update_obj.GetUpdateTimestamp()) + + with mock.patch.object( + update_obj, "_GetCurrentTime", return_value=expected_time + ) as ct: + self.assertEqual(expected_time, update_obj.GetUpdateTimestamp()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/util/curl.py b/nss_cache/util/curl.py index 059f3241..e17dc60a 100644 --- a/nss_cache/util/curl.py +++ b/nss_cache/util/curl.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Minor curl methods.""" -__author__ = 'blaedd@google.com (David MacKinnon)' +__author__ = "blaedd@google.com (David MacKinnon)" import logging import pycurl @@ -42,8 +42,7 @@ def CurlFetch(url, conn=None, logger=None): HandleCurlError(e, logger) raise error.Error(e) resp_code = conn.getinfo(pycurl.RESPONSE_CODE) - return (resp_code, conn.headers.getvalue().decode('utf-8'), - conn.body.getvalue()) + return (resp_code, conn.headers.getvalue().decode("utf-8"), conn.body.getvalue()) def HandleCurlError(e, logger=None): @@ -68,18 +67,29 @@ def HandleCurlError(e, logger=None): msg = e.args[1] # Config errors - if code in (pycurl.E_UNSUPPORTED_PROTOCOL, pycurl.E_URL_MALFORMAT, - pycurl.E_SSL_ENGINE_NOTFOUND, pycurl.E_SSL_ENGINE_SETFAILED, - pycurl.E_SSL_CACERT_BADFILE): + if code in ( + pycurl.E_UNSUPPORTED_PROTOCOL, + pycurl.E_URL_MALFORMAT, + pycurl.E_SSL_ENGINE_NOTFOUND, + pycurl.E_SSL_ENGINE_SETFAILED, + pycurl.E_SSL_CACERT_BADFILE, + ): raise error.ConfigurationError(msg) # Possibly transient errors, try again - if code in (pycurl.E_FAILED_INIT, pycurl.E_COULDNT_CONNECT, - pycurl.E_PARTIAL_FILE, pycurl.E_WRITE_ERROR, - pycurl.E_READ_ERROR, pycurl.E_OPERATION_TIMEOUTED, - pycurl.E_SSL_CONNECT_ERROR, pycurl.E_COULDNT_RESOLVE_PROXY, - pycurl.E_COULDNT_RESOLVE_HOST, pycurl.E_GOT_NOTHING): - logger.debug('Possibly transient error: %s', msg) + if code in ( + pycurl.E_FAILED_INIT, + pycurl.E_COULDNT_CONNECT, + pycurl.E_PARTIAL_FILE, + pycurl.E_WRITE_ERROR, + pycurl.E_READ_ERROR, + pycurl.E_OPERATION_TIMEOUTED, + pycurl.E_SSL_CONNECT_ERROR, + pycurl.E_COULDNT_RESOLVE_PROXY, + pycurl.E_COULDNT_RESOLVE_HOST, + pycurl.E_GOT_NOTHING, + ): + logger.debug("Possibly transient error: %s", msg) return # SSL issues diff --git a/nss_cache/util/file_formats.py b/nss_cache/util/file_formats.py index 618dd056..df2e32d6 100644 --- a/nss_cache/util/file_formats.py +++ b/nss_cache/util/file_formats.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Parsing methods for file cache types.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import logging @@ -31,6 +33,7 @@ SetType = set except NameError: import sets + SetType = sets.Set @@ -50,19 +53,19 @@ def GetMap(self, cache_info, data): A child of Map containing the cache data. 
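Blank lines and "#" comment lines are skipped; lines that cannot be parsed into an entry, or whose entry cannot be added to the map, are logged and skipped as well.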
""" for line in cache_info: - line = line.rstrip('\n') - if not line or line[0] == '#': + line = line.rstrip("\n") + if not line or line[0] == "#": continue entry = self._ReadEntry(line) if entry is None: self.log.warning( - 'Could not create entry from line %r in cache, skipping', - line) + "Could not create entry from line %r in cache, skipping", line + ) continue if not data.Add(entry): self.log.warning( - 'Could not add entry %r read from line %r in cache', entry, - line) + "Could not add entry %r read from line %r in cache", entry, line + ) return data @@ -71,7 +74,7 @@ class FilesSshkeyMapParser(FilesMapParser): def _ReadEntry(self, entry): """Return a SshkeyMapEntry from a record in the target cache.""" - entry = entry.split(':') + entry = entry.split(":") map_entry = sshkey.SshkeyMapEntry() # maps expect strict typing, so convert to int as appropriate. map_entry.name = entry[0] @@ -84,7 +87,7 @@ class FilesPasswdMapParser(FilesMapParser): def _ReadEntry(self, entry): """Return a PasswdMapEntry from a record in the target cache.""" - entry = entry.split(':') + entry = entry.split(":") map_entry = passwd.PasswdMapEntry() # maps expect strict typing, so convert to int as appropriate. map_entry.name = entry[0] @@ -102,13 +105,13 @@ class FilesGroupMapParser(FilesMapParser): def _ReadEntry(self, line): """Return a GroupMapEntry from a record in the target cache.""" - line = line.split(':') + line = line.split(":") map_entry = group.GroupMapEntry() # map entries expect strict typing, so convert as appropriate map_entry.name = line[0] map_entry.passwd = line[1] map_entry.gid = int(line[2]) - map_entry.members = line[3].split(',') + map_entry.members = line[3].split(",") return map_entry @@ -117,7 +120,7 @@ class FilesShadowMapParser(FilesMapParser): def _ReadEntry(self, line): """Return a ShadowMapEntry from a record in the target cache.""" - line = line.split(':') + line = line.split(":") map_entry = shadow.ShadowMapEntry() # map entries expect strict typing, so convert as appropriate map_entry.name = line[0] @@ -148,20 +151,20 @@ def _ReadEntry(self, line): # the first word is our name, but since the whole line is space delimited # avoid .split(' ') since groups can have thousands of members. - index = line.find(' ') + index = line.find(" ") if index == -1: if line: # empty group is OK, as long as the line isn't blank map_entry.name = line return map_entry - raise RuntimeError('Failed to parse entry: %s' % line) + raise RuntimeError("Failed to parse entry: %s" % line) map_entry.name = line[0:index] # the rest is our entries, and for better or for worse this preserves extra # leading spaces - map_entry.entries = line[index + 1:] + map_entry.entries = line[index + 1 :] return map_entry diff --git a/nss_cache/util/file_formats_test.py b/nss_cache/util/file_formats_test.py index 1d5c28e8..efeb876b 100644 --- a/nss_cache/util/file_formats_test.py +++ b/nss_cache/util/file_formats_test.py @@ -15,8 +15,10 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for nss_cache/util/file_formats.py.""" -__author__ = ('jaq@google.com (Jamie Wilkinson)', - 'vasilios@google.com (Vasilios Hoffman)') +__author__ = ( + "jaq@google.com (Jamie Wilkinson)", + "vasilios@google.com (Vasilios Hoffman)", +) import unittest @@ -24,40 +26,39 @@ class TestFilesUtils(unittest.TestCase): - def testReadPasswdEntry(self): """We correctly parse a typical entry in /etc/passwd format.""" parser = file_formats.FilesPasswdMapParser() - file_entry = 'root:x:0:0:Rootsy:/root:/bin/bash' + file_entry = "root:x:0:0:Rootsy:/root:/bin/bash" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.name, 'root') - self.assertEqual(map_entry.passwd, 'x') + self.assertEqual(map_entry.name, "root") + self.assertEqual(map_entry.passwd, "x") self.assertEqual(map_entry.uid, 0) self.assertEqual(map_entry.gid, 0) - self.assertEqual(map_entry.gecos, 'Rootsy') - self.assertEqual(map_entry.dir, '/root') - self.assertEqual(map_entry.shell, '/bin/bash') + self.assertEqual(map_entry.gecos, "Rootsy") + self.assertEqual(map_entry.dir, "/root") + self.assertEqual(map_entry.shell, "/bin/bash") def testReadGroupEntry(self): """We correctly parse a typical entry in /etc/group format.""" parser = file_formats.FilesGroupMapParser() - file_entry = 'root:x:0:zero_cool,acid_burn' + file_entry = "root:x:0:zero_cool,acid_burn" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.name, 'root') - self.assertEqual(map_entry.passwd, 'x') + self.assertEqual(map_entry.name, "root") + self.assertEqual(map_entry.passwd, "x") self.assertEqual(map_entry.gid, 0) - self.assertEqual(map_entry.members, ['zero_cool', 'acid_burn']) + self.assertEqual(map_entry.members, ["zero_cool", "acid_burn"]) def testReadShadowEntry(self): """We correctly parse a typical entry in /etc/shadow format.""" parser = file_formats.FilesShadowMapParser() - file_entry = 'root:$1$zomgmd5support:::::::' + file_entry = "root:$1$zomgmd5support:::::::" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.name, 'root') - self.assertEqual(map_entry.passwd, '$1$zomgmd5support') + self.assertEqual(map_entry.name, "root") + self.assertEqual(map_entry.passwd, "$1$zomgmd5support") self.assertEqual(map_entry.lstchg, None) self.assertEqual(map_entry.min, None) self.assertEqual(map_entry.max, None) @@ -69,49 +70,48 @@ def testReadShadowEntry(self): def testReadNetgroupEntry(self): """We correctly parse a typical entry in /etc/netgroup format.""" parser = file_formats.FilesNetgroupMapParser() - file_entry = 'administrators unix_admins noc_monkeys (-,zero_cool,)' + file_entry = "administrators unix_admins noc_monkeys (-,zero_cool,)" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.name, 'administrators') - self.assertEqual(map_entry.entries, - 'unix_admins noc_monkeys (-,zero_cool,)') + self.assertEqual(map_entry.name, "administrators") + self.assertEqual(map_entry.entries, "unix_admins noc_monkeys (-,zero_cool,)") def testReadEmptyNetgroupEntry(self): """We correctly parse a memberless netgroup entry.""" parser = file_formats.FilesNetgroupMapParser() - file_entry = 'administrators' + file_entry = "administrators" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.name, 'administrators') - self.assertEqual(map_entry.entries, '') + self.assertEqual(map_entry.name, "administrators") + self.assertEqual(map_entry.entries, "") def testReadAutomountEntry(self): """We correctly parse a typical entry in /etc/auto.* format.""" parser = 
file_formats.FilesAutomountMapParser() - file_entry = 'scratch -tcp,rw,intr,bg fileserver:/scratch' + file_entry = "scratch -tcp,rw,intr,bg fileserver:/scratch" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.key, 'scratch') - self.assertEqual(map_entry.options, '-tcp,rw,intr,bg') - self.assertEqual(map_entry.location, 'fileserver:/scratch') + self.assertEqual(map_entry.key, "scratch") + self.assertEqual(map_entry.options, "-tcp,rw,intr,bg") + self.assertEqual(map_entry.location, "fileserver:/scratch") def testReadAutmountEntryWithExtraWhitespace(self): """Extra whitespace doesn't break the parsing.""" parser = file_formats.FilesAutomountMapParser() - file_entry = 'scratch fileserver:/scratch' + file_entry = "scratch fileserver:/scratch" map_entry = parser._ReadEntry(file_entry) - self.assertEqual(map_entry.key, 'scratch') + self.assertEqual(map_entry.key, "scratch") self.assertEqual(map_entry.options, None) - self.assertEqual(map_entry.location, 'fileserver:/scratch') + self.assertEqual(map_entry.location, "fileserver:/scratch") def testReadBadAutomountEntry(self): """Cope with empty data.""" parser = file_formats.FilesAutomountMapParser() - file_entry = '' + file_entry = "" map_entry = parser._ReadEntry(file_entry) self.assertEqual(None, map_entry) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/nss_cache/util/timestamps.py b/nss_cache/util/timestamps.py index a76e0814..f1f0bac7 100644 --- a/nss_cache/util/timestamps.py +++ b/nss_cache/util/timestamps.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. """Timestamp handling routines.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import logging import os.path @@ -45,36 +45,41 @@ def ReadTimestamp(filename): return None try: - timestamp_file = open(filename, 'r') + timestamp_file = open(filename, "r") timestamp_string = timestamp_file.read().strip() except IOError as e: - logging.warning('error opening timestamp file: %s', e) + logging.warning("error opening timestamp file: %s", e) timestamp_string = None else: timestamp_file.close() - logging.debug('read timestamp %s from file %r', timestamp_string, filename) + logging.debug("read timestamp %s from file %r", timestamp_string, filename) if timestamp_string is not None: try: # Append UTC to force the timezone to parse the string in. 
- timestamp = time.strptime(timestamp_string + ' UTC', - '%Y-%m-%dT%H:%M:%SZ %Z') + timestamp = time.strptime( + timestamp_string + " UTC", "%Y-%m-%dT%H:%M:%SZ %Z" + ) except ValueError as e: - logging.error('cannot parse timestamp file %r: %s', filename, e) + logging.error("cannot parse timestamp file %r: %s", filename, e) timestamp = None else: timestamp = None - logging.debug('Timestamp is: %r', timestamp) + logging.debug("Timestamp is: %r", timestamp) now = time.gmtime() - logging.debug(' Now is: %r', now) + logging.debug(" Now is: %r", now) if timestamp > now: - logging.warning('timestamp %r (%r) from %r is in the future, now is %r', - timestamp_string, time.mktime(timestamp), filename, - time.mktime(now)) + logging.warning( + "timestamp %r (%r) from %r is in the future, now is %r", + timestamp_string, + time.mktime(timestamp), + filename, + time.mktime(now), + ) if time.mktime(timestamp) - time.mktime(now) >= 60 * 60: - logging.info('Resetting timestamp to now.') + logging.info("Resetting timestamp to now.") timestamp = now return timestamp @@ -101,24 +106,24 @@ def WriteTimestamp(timestamp, filename): timestamp_dir = os.path.dirname(filename) - (filedesc, temp_filename) = tempfile.mkstemp(prefix='nsscache-update-', - dir=timestamp_dir) + (filedesc, temp_filename) = tempfile.mkstemp( + prefix="nsscache-update-", dir=timestamp_dir + ) - time_string = time.strftime('%Y-%m-%dT%H:%M:%SZ', timestamp) + time_string = time.strftime("%Y-%m-%dT%H:%M:%SZ", timestamp) try: - os.write(filedesc, b'%s\n' % time_string.encode()) + os.write(filedesc, b"%s\n" % time_string.encode()) os.fsync(filedesc) os.close(filedesc) except OSError: os.unlink(temp_filename) - logging.warning('writing timestamp failed!') + logging.warning("writing timestamp failed!") return False - os.chmod(temp_filename, - stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) + os.chmod(temp_filename, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH) os.rename(temp_filename, filename) - logging.debug('wrote timestamp %s to file %r', time_string, filename) + logging.debug("wrote timestamp %s to file %r", time_string, filename) return True diff --git a/nss_cache/util/timestamps_test.py b/nss_cache/util/timestamps_test.py index 12a63728..a951f6c2 100644 --- a/nss_cache/util/timestamps_test.py +++ b/nss_cache/util/timestamps_test.py @@ -15,7 +15,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Unit tests for nss_cache/util/timestamps.py.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" import datetime from datetime import timezone @@ -30,7 +30,6 @@ class TestTimestamps(unittest.TestCase): - def setUp(self): super(TestTimestamps, self).setUp() self.workdir = tempfile.mkdtemp() @@ -40,9 +39,9 @@ def tearDown(self): shutil.rmtree(self.workdir) def testReadTimestamp(self): - ts_filename = os.path.join(self.workdir, 'tsr') - ts_file = open(ts_filename, 'w') - ts_file.write('1970-01-01T00:00:01Z\n') + ts_filename = os.path.join(self.workdir, "tsr") + ts_file = open(ts_filename, "w") + ts_file.write("1970-01-01T00:00:01Z\n") ts_file.close() ts = timestamps.ReadTimestamp(ts_filename) @@ -51,49 +50,52 @@ def testReadTimestamp(self): def testReadTimestamp(self): # TZ=UTC date -d @1306428781 # Thu May 26 16:53:01 UTC 2011 - ts_filename = os.path.join(self.workdir, 'tsr') - ts_file = open(ts_filename, 'w') - ts_file.write('2011-05-26T16:53:01Z\n') + ts_filename = os.path.join(self.workdir, "tsr") + ts_file = open(ts_filename, "w") + ts_file.write("2011-05-26T16:53:01Z\n") ts_file.close() ts = timestamps.ReadTimestamp(ts_filename) self.assertEqual(time.gmtime(1306428781), ts) def testReadTimestampInFuture(self): - ts_filename = os.path.join(self.workdir, 'tsr') - ts_file = open(ts_filename, 'w') - ts_file.write('2011-05-26T16:02:00Z') + ts_filename = os.path.join(self.workdir, "tsr") + ts_file = open(ts_filename, "w") + ts_file.write("2011-05-26T16:02:00Z") ts_file.close() now = time.gmtime(1) - with mock.patch('time.gmtime') as gmtime: + with mock.patch("time.gmtime") as gmtime: gmtime.return_value = now ts = timestamps.ReadTimestamp(ts_filename) self.assertEqual(now, ts) def testWriteTimestamp(self): - ts_filename = os.path.join(self.workdir, 'tsw') + ts_filename = os.path.join(self.workdir, "tsw") good_ts = time.gmtime(1) timestamps.WriteTimestamp(good_ts, ts_filename) self.assertEqual(good_ts, timestamps.ReadTimestamp(ts_filename)) - ts_file = open(ts_filename, 'r') - self.assertEqual('1970-01-01T00:00:01Z\n', ts_file.read()) + ts_file = open(ts_filename, "r") + self.assertEqual("1970-01-01T00:00:01Z\n", ts_file.read()) ts_file.close() def testTimestampToDateTime(self): now = datetime.datetime.now(timezone.utc) - self.assertEqual(timestamps.FromTimestampToDateTime(now.timestamp()), - now.replace(tzinfo=None)) + self.assertEqual( + timestamps.FromTimestampToDateTime(now.timestamp()), + now.replace(tzinfo=None), + ) def testDateTimeToTimestamp(self): now = datetime.datetime.now(timezone.utc) self.assertEqual( now.replace(microsecond=0).timestamp(), - timestamps.FromDateTimeToTimestamp(now)) + timestamps.FromDateTimeToTimestamp(now), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/setup.py b/setup.py index 57891bec..307c0697 100755 --- a/setup.py +++ b/setup.py @@ -17,44 +17,50 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
"""Distutils setup for nsscache tool and nss_cache package.""" -__author__ = 'jaq@google.com (Jamie Wilkinson)' +__author__ = "jaq@google.com (Jamie Wilkinson)" from setuptools import setup, find_packages import nss_cache setup( - name='nsscache', + name="nsscache", version=nss_cache.__version__, - author='Jamie Wilkinson', - author_email='jaq@google.com', - url='https://github.com/google/nsscache', - description='nsscache tool and library', - license='GPL', - long_description= - """nsscache is a Python library and a commandline frontend to that library + author="Jamie Wilkinson", + author_email="jaq@google.com", + url="https://github.com/google/nsscache", + description="nsscache tool and library", + license="GPL", + long_description="""nsscache is a Python library and a commandline frontend to that library that synchronises a local NSS cache against a remote directory service, such as LDAP.""", classifiers=[ - 'Development Status :: 4 - Beta', 'Environment :: Console', - 'Indended Audience :: System Administrators', - 'License :: OSI Approved :: GPL', 'Operating System :: POSIX', - 'Programming Language :: Python', 'Topic :: System' + "Development Status :: 4 - Beta", + "Environment :: Console", + "Indended Audience :: System Administrators", + "License :: OSI Approved :: GPL", + "Operating System :: POSIX", + "Programming Language :: Python", + "Topic :: System", ], packages=[ - 'nss_cache', 'nss_cache.caches', 'nss_cache.maps', 'nss_cache.util', - 'nss_cache.update', 'nss_cache.sources' + "nss_cache", + "nss_cache.caches", + "nss_cache.maps", + "nss_cache.util", + "nss_cache.update", + "nss_cache.sources", ], - scripts=['nsscache'], - data_files=[('config', ['nsscache.conf'])], - python_requires='~=3.4', - setup_requires=['pytest-runner'], - tests_require=['pytest', 'mox3', 'pytest-cov', 'python-coveralls'], + scripts=["nsscache"], + data_files=[("config", ["nsscache.conf"])], + python_requires="~=3.4", + setup_requires=["pytest-runner"], + tests_require=["pytest", "mox3", "pytest-cov", "python-coveralls"], extras_require={ - 'ldap': ['python3-ldap', 'python-ldap'], - 'http': ['pycurl'], - 's3': ['boto3'], - 'consul': ['pycurl'], - 'gcs': ['google-cloud-storage'], + "ldap": ["python3-ldap", "python-ldap"], + "http": ["pycurl"], + "s3": ["boto3"], + "consul": ["pycurl"], + "gcs": ["google-cloud-storage"], }, )
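Note on the one behavioral change mixed into this otherwise mechanical reformat: the timestamps updater test is migrated from mox (`StubOutWithMock` / `AndReturn` / `ReplayAll`) to the standard library's `mock.patch.object` used as a context manager, which installs the stub on entry to the `with` block and restores the original attribute on exit. The `as ct` binding introduced in that hunk is never used and could be dropped. Below is a minimal, self-contained sketch of the pattern; `Updater`, `_GetCurrentTime`, and `GetUpdateTimestamp` here are hypothetical stand-ins rather than the real classes in this repo.

    import unittest
    from unittest import mock


    class Updater:
        """Hypothetical stand-in; only the stubbed method matters here."""

        def _GetCurrentTime(self):
            raise NotImplementedError  # the real method would consult the clock

        def GetUpdateTimestamp(self):
            return self._GetCurrentTime()


    class UpdaterTest(unittest.TestCase):
        def testGetUpdateTimestamp(self):
            update_obj = Updater()
            expected_time = 1
            # patch.object replaces the mox StubOutWithMock/ReplayAll/VerifyAll
            # lifecycle: the attribute is stubbed with a MagicMock returning
            # expected_time, and automatically restored when the block exits.
            with mock.patch.object(
                update_obj, "_GetCurrentTime", return_value=expected_time
            ):
                self.assertEqual(expected_time, update_obj.GetUpdateTimestamp())


    if __name__ == "__main__":
        unittest.main()

Unlike mox, no replay/verify bracketing is needed; if call verification is wanted, keep the `as` binding and assert explicitly, e.g. `ct.assert_called_once_with()`.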