From 37c1540ff492595b4dfa0a74147c3cf87b7b4fc1 Mon Sep 17 00:00:00 2001 From: Ajpantuso Date: Mon, 3 May 2021 15:10:48 -0400 Subject: [PATCH] New module_utils openssh (#213) * Adding openssh utils and unit tests * Adding changelog fragment and correcting RSA default size * Adding changelog fragment * Added passphrase update, test cases, and check for SSH private key loader * corrected ecdsa type when loading * Resolving inital review comments * Fixed import in unit tests * Cleaning up validation functions * Separating private/public key related errors; Adding verify method * Expressed generate/load functions as classmethods and cleaned up method comments * Added support for loading asymmetric key pairs of PEM and DER formats * Refactored loading/generation for Asym keypairs into classmethods * Rescoped helper functions and classmethods for OpenSSH Keypair * Corrected docstring for OpenSSH_Keypair.generate() * Fixed import errors for sanity tests * Improvements to comparison, key verification, and password validation * Added comparison tests, simplified password validation, fixed Ed25519 load bug * Adding additional equivalence tests with passphrases --- .../213-cryptography-openssh-module-utils.yml | 2 + .../openssh/cryptography_openssh.py | 648 ++++++++++++++++++ .../openssh/test_cryptography_openssh.py | 400 +++++++++++ 3 files changed, 1050 insertions(+) create mode 100644 changelogs/fragments/213-cryptography-openssh-module-utils.yml create mode 100644 plugins/module_utils/openssh/cryptography_openssh.py create mode 100644 tests/unit/plugins/module_utils/openssh/test_cryptography_openssh.py diff --git a/changelogs/fragments/213-cryptography-openssh-module-utils.yml b/changelogs/fragments/213-cryptography-openssh-module-utils.yml new file mode 100644 index 000000000..05abf1d66 --- /dev/null +++ b/changelogs/fragments/213-cryptography-openssh-module-utils.yml @@ -0,0 +1,2 @@ +minor_changes: + - cryptography_openssh module utils - new module_utils for managing asymmetric keypairs and OpenSSH formatted/encoded asymmetric keypairs (https://github.com/ansible-collections/community.crypto/pull/213). diff --git a/plugins/module_utils/openssh/cryptography_openssh.py b/plugins/module_utils/openssh/cryptography_openssh.py new file mode 100644 index 000000000..ebd1cbdff --- /dev/null +++ b/plugins/module_utils/openssh/cryptography_openssh.py @@ -0,0 +1,648 @@ +# -*- coding: utf-8 -*- +# +# Copyright: (c) 2021, Andrew Pantuso (@ajpantuso) +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from base64 import b64encode, b64decode +from getpass import getuser +from socket import gethostname + +from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import ( + HAS_CRYPTOGRAPHY, + CRYPTOGRAPHY_HAS_ED25519, +) + +try: + from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm + from cryptography.hazmat.backends.openssl import backend + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa, padding + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey +except ImportError: + pass + +if HAS_CRYPTOGRAPHY and CRYPTOGRAPHY_HAS_ED25519: + HAS_OPENSSH_SUPPORT = True + + _ALGORITHM_PARAMETERS = { + 'rsa': { + 'default_size': 2048, + 'valid_sizes': range(1024, 16384), + 'signer_params': { + 'padding': padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH, + ), + 'algorithm': hashes.SHA256(), + }, + }, + 'dsa': { + 'default_size': 1024, + 'valid_sizes': [1024], + 'signer_params': { + 'algorithm': hashes.SHA256(), + }, + }, + 'ed25519': { + 'default_size': 256, + 'valid_sizes': [256], + 'signer_params': {}, + }, + 'ecdsa': { + 'default_size': 256, + 'valid_sizes': [256, 384, 521], + 'signer_params': { + 'signature_algorithm': ec.ECDSA(hashes.SHA256()), + }, + 'curves': { + 256: ec.SECP256R1(), + 384: ec.SECP384R1(), + 521: ec.SECP521R1(), + } + } + } +else: + HAS_OPENSSH_SUPPORT = False + _ALGORITHM_PARAMETERS = {} + +_TEXT_ENCODING = 'UTF-8' + + +class OpenSSHError(Exception): + pass + + +class InvalidAlgorithmError(OpenSSHError): + pass + + +class InvalidCommentError(OpenSSHError): + pass + + +class InvalidDataError(OpenSSHError): + pass + + +class InvalidPrivateKeyFileError(OpenSSHError): + pass + + +class InvalidPublicKeyFileError(OpenSSHError): + pass + + +class InvalidKeyFormatError(OpenSSHError): + pass + + +class InvalidKeySizeError(OpenSSHError): + pass + + +class InvalidKeyTypeError(OpenSSHError): + pass + + +class InvalidPassphraseError(OpenSSHError): + pass + + +class InvalidSignatureError(OpenSSHError): + pass + + +class Asymmetric_Keypair(object): + """Container for newly generated asymmetric key pairs or those loaded from existing files""" + + @classmethod + def generate(cls, keytype='rsa', size=None, passphrase=None): + """Returns an Asymmetric_Keypair object generated with the supplied parameters + or defaults to an unencrypted RSA-2048 key + + :keytype: One of rsa, dsa, ecdsa, ed25519 + :size: The key length for newly generated keys + :passphrase: Secret of type Bytes used to encrypt the private key being generated + """ + + if keytype not in _ALGORITHM_PARAMETERS.keys(): + raise InvalidKeyTypeError( + "%s is not a valid keytype. Valid keytypes are %s" % ( + keytype, ", ".join(_ALGORITHM_PARAMETERS.keys()) + ) + ) + + if not size: + size = _ALGORITHM_PARAMETERS[keytype]['default_size'] + else: + if size not in _ALGORITHM_PARAMETERS[keytype]['valid_sizes']: + raise InvalidKeySizeError( + "%s is not a valid key size for %s keys" % (size, keytype) + ) + + if passphrase: + encryption_algorithm = get_encryption_algorithm(passphrase) + else: + encryption_algorithm = serialization.NoEncryption() + + if keytype == 'rsa': + privatekey = rsa.generate_private_key( + # Public exponent should always be 65537 to prevent issues + # if improper padding is used during signing + public_exponent=65537, + key_size=size, + backend=backend, + ) + elif keytype == 'dsa': + privatekey = dsa.generate_private_key( + key_size=size, + backend=backend, + ) + elif keytype == 'ed25519': + privatekey = Ed25519PrivateKey.generate() + elif keytype == 'ecdsa': + privatekey = ec.generate_private_key( + _ALGORITHM_PARAMETERS['ecdsa']['curves'][size], + backend=backend, + ) + + publickey = privatekey.public_key() + + return cls( + keytype=keytype, + size=size, + privatekey=privatekey, + publickey=publickey, + encryption_algorithm=encryption_algorithm + ) + + @classmethod + def load(cls, path, passphrase=None, key_format='PEM'): + """Returns an Asymmetric_Keypair object loaded from the supplied file path + + :path: A path to an existing private key to be loaded + :passphrase: Secret of type bytes used to decrypt the private key being loaded + :key_format: Format of key files to be loaded + """ + + if passphrase: + encryption_algorithm = get_encryption_algorithm(passphrase) + else: + encryption_algorithm = serialization.NoEncryption() + + privatekey = load_privatekey(path, passphrase, key_format) + publickey = load_publickey(path + '.pub', key_format) + + # Ed25519 keys are always of size 256 and do not have a key_size attribute + if isinstance(privatekey, Ed25519PrivateKey): + size = _ALGORITHM_PARAMETERS['ed25519']['default_size'] + else: + size = privatekey.key_size + + if isinstance(privatekey, rsa.RSAPrivateKey): + keytype = 'rsa' + elif isinstance(privatekey, dsa.DSAPrivateKey): + keytype = 'dsa' + elif isinstance(privatekey, ec.EllipticCurvePrivateKey): + keytype = 'ecdsa' + elif isinstance(privatekey, Ed25519PrivateKey): + keytype = 'ed25519' + else: + raise InvalidKeyTypeError("Key type '%s' is not supported" % type(privatekey)) + + return cls( + keytype=keytype, + size=size, + privatekey=privatekey, + publickey=publickey, + encryption_algorithm=encryption_algorithm + ) + + def __init__(self, keytype, size, privatekey, publickey, encryption_algorithm): + """ + :keytype: One of rsa, dsa, ecdsa, ed25519 + :size: The key length for the private key of this key pair + :privatekey: Private key object of this key pair + :publickey: Public key object of this key pair + :encryption_algorithm: Hashed secret used to encrypt the private key of this key pair + """ + + self.__size = size + self.__keytype = keytype + self.__privatekey = privatekey + self.__publickey = publickey + self.__encryption_algorithm = encryption_algorithm + + try: + self.verify(self.sign(b'message'), b'message') + except InvalidSignatureError: + raise InvalidPublicKeyFileError( + "The private key and public key of this keypair do not match" + ) + + def __eq__(self, other): + if not isinstance(other, Asymmetric_Keypair): + return NotImplemented + + return (compare_publickeys(self.public_key, other.public_key) and + compare_encryption_algorithms(self.encryption_algorithm, other.encryption_algorithm)) + + def __ne__(self, other): + return not self == other + + @property + def private_key(self): + """Returns the private key of this key pair""" + + return self.__privatekey + + @property + def public_key(self): + """Returns the public key of this key pair""" + + return self.__publickey + + @property + def size(self): + """Returns the size of the private key of this key pair""" + + return self.__size + + @property + def key_type(self): + """Returns the key type of this key pair""" + + return self.__keytype + + @property + def encryption_algorithm(self): + """Returns the key encryption algorithm of this key pair""" + + return self.__encryption_algorithm + + def sign(self, data): + """Returns signature of data signed with the private key of this key pair + + :data: byteslike data to sign + """ + + try: + signature = self.__privatekey.sign( + data, + **_ALGORITHM_PARAMETERS[self.__keytype]['signer_params'] + ) + except TypeError as e: + raise InvalidDataError(e) + + return signature + + def verify(self, signature, data): + """Verifies that the signature associated with the provided data was signed + by the private key of this key pair. + + :signature: signature to verify + :data: byteslike data signed by the provided signature + """ + try: + return self.__publickey.verify( + signature, + data, + **_ALGORITHM_PARAMETERS[self.__keytype]['signer_params'] + ) + except InvalidSignature: + raise InvalidSignatureError + + def update_passphrase(self, passphrase=None): + """Updates the encryption algorithm of this key pair + + :passphrase: Byte secret used to encrypt this key pair + """ + + if passphrase: + self.__encryption_algorithm = get_encryption_algorithm(passphrase) + else: + self.__encryption_algorithm = serialization.NoEncryption() + + +class OpenSSH_Keypair(object): + """Container for OpenSSH encoded asymmetric key pairs""" + + @classmethod + def generate(cls, keytype='rsa', size=None, passphrase=None, comment=None): + """Returns an Openssh_Keypair object generated using the supplied parameters or defaults to a RSA-2048 key + + :keytype: One of rsa, dsa, ecdsa, ed25519 + :size: The key length for newly generated keys + :passphrase: Secret of type Bytes used to encrypt the newly generated private key + :comment: Comment for a newly generated OpenSSH public key + """ + + if not comment: + comment = "%s@%s" % (getuser(), gethostname()) + + asym_keypair = Asymmetric_Keypair.generate(keytype, size, passphrase) + openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair) + openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment) + fingerprint = calculate_fingerprint(openssh_publickey) + + return cls( + asym_keypair=asym_keypair, + openssh_privatekey=openssh_privatekey, + openssh_publickey=openssh_publickey, + fingerprint=fingerprint, + comment=comment + ) + + @classmethod + def load(cls, path, passphrase=None): + """Returns an Openssh_Keypair object loaded from the supplied file path + + :path: A path to an existing private key to be loaded + :passphrase: Secret used to decrypt the private key being loaded + """ + + comment = extract_comment(path + '.pub') + asym_keypair = Asymmetric_Keypair.load(path, passphrase, 'SSH') + openssh_privatekey = cls.encode_openssh_privatekey(asym_keypair) + openssh_publickey = cls.encode_openssh_publickey(asym_keypair, comment) + fingerprint = calculate_fingerprint(openssh_publickey) + + return cls( + asym_keypair=asym_keypair, + openssh_privatekey=openssh_privatekey, + openssh_publickey=openssh_publickey, + fingerprint=fingerprint, + comment=comment + ) + + @staticmethod + def encode_openssh_privatekey(asym_keypair): + """Returns an OpenSSH encoded private key for a given keypair + + :asym_keypair: Asymmetric_Keypair from the private key is extracted + """ + + # OpenSSH formatted private keys are not available in Cryptography <3.0 + try: + privatekey_format = serialization.PrivateFormat.OpenSSH + except AttributeError: + privatekey_format = serialization.PrivateFormat.PKCS8 + + encoded_privatekey = asym_keypair.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=privatekey_format, + encryption_algorithm=asym_keypair.encryption_algorithm + ) + + return encoded_privatekey + + @staticmethod + def encode_openssh_publickey(asym_keypair, comment): + """Returns an OpenSSH encoded public key for a given keypair + + :asym_keypair: Asymmetric_Keypair from the public key is extracted + :comment: Comment to apply to the end of the returned OpenSSH encoded public key + """ + encoded_publickey = asym_keypair.public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + ) + + validate_comment(comment) + + encoded_publickey += (" %s" % comment).encode(encoding=_TEXT_ENCODING) + + return encoded_publickey + + def __init__(self, asym_keypair, openssh_privatekey, openssh_publickey, fingerprint, comment): + """ + :asym_keypair: An Asymmetric_Keypair object from which the OpenSSH encoded keypair is derived + :openssh_privatekey: An OpenSSH encoded private key + :openssh_privatekey: An OpenSSH encoded public key + :fingerprint: The fingerprint of the OpenSSH encoded public key of this keypair + :comment: Comment applied to the OpenSSH public key of this keypair + """ + + self.__asym_keypair = asym_keypair + self.__openssh_privatekey = openssh_privatekey + self.__openssh_publickey = openssh_publickey + self.__fingerprint = fingerprint + self.__comment = comment + + def __eq__(self, other): + if not isinstance(other, OpenSSH_Keypair): + return NotImplemented + + return self.asymmetric_keypair == other.asymmetric_keypair and self.comment == other.comment + + @property + def asymmetric_keypair(self): + """Returns the underlying asymmetric key pair of this OpenSSH encoded key pair""" + + return self.__asym_keypair + + @property + def private_key(self): + """Returns the OpenSSH formatted private key of this key pair""" + + return self.__openssh_privatekey + + @property + def public_key(self): + """Returns the OpenSSH formatted public key of this key pair""" + + return self.__openssh_publickey + + @property + def size(self): + """Returns the size of the private key of this key pair""" + + return self.__asym_keypair.size + + @property + def key_type(self): + """Returns the key type of this key pair""" + + return self.__asym_keypair.key_type + + @property + def fingerprint(self): + """Returns the fingerprint (SHA256 Hash) of the public key of this key pair""" + + return self.__fingerprint + + @property + def comment(self): + """Returns the comment applied to the OpenSSH formatted public key of this key pair""" + + return self.__comment + + @comment.setter + def comment(self, comment): + """Updates the comment applied to the OpenSSH formatted public key of this key pair + + :comment: Text to update the OpenSSH public key comment + """ + + validate_comment(comment) + + self.__comment = comment + encoded_comment = (" %s" % self.__comment).encode(encoding=_TEXT_ENCODING) + self.__openssh_publickey = b' '.join(self.__openssh_publickey.split(b' ', 2)[:2]) + encoded_comment + return self.__openssh_publickey + + def update_passphrase(self, passphrase): + """Updates the passphrase used to encrypt the private key of this keypair + + :passphrase: Text secret used for encryption + """ + + self.__asym_keypair.update_passphrase(passphrase) + self.__openssh_privatekey = OpenSSH_Keypair.encode_openssh_privatekey(self.__asym_keypair) + + +def load_privatekey(path, passphrase, key_format): + privatekey_loaders = { + 'PEM': serialization.load_pem_private_key, + 'DER': serialization.load_der_private_key, + } + + # OpenSSH formatted private keys are not available in Cryptography <3.0 + if hasattr(serialization, 'load_ssh_private_key'): + privatekey_loaders['SSH'] = serialization.load_ssh_private_key + else: + privatekey_loaders['SSH'] = serialization.load_pem_private_key + + try: + privatekey_loader = privatekey_loaders[key_format] + except KeyError: + raise InvalidKeyFormatError( + "%s is not a valid key format (%s)" % ( + key_format, + ','.join(privatekey_loaders.keys()) + ) + ) + + try: + with open(path, 'rb') as f: + content = f.read() + + privatekey = privatekey_loader( + data=content, + password=passphrase, + backend=backend, + ) + + except ValueError as e: + raise InvalidPrivateKeyFileError(e) + except TypeError as e: + raise InvalidPassphraseError(e) + except UnsupportedAlgorithm as e: + raise InvalidAlgorithmError(e) + + return privatekey + + +def load_publickey(path, key_format): + publickey_loaders = { + 'PEM': serialization.load_pem_public_key, + 'DER': serialization.load_der_public_key, + 'SSH': serialization.load_ssh_public_key, + } + + try: + publickey_loader = publickey_loaders[key_format] + except KeyError: + raise InvalidKeyFormatError( + "%s is not a valid key format (%s)" % ( + key_format, + ','.join(publickey_loaders.keys()) + ) + ) + + try: + with open(path, 'rb') as f: + content = f.read() + + publickey = publickey_loader( + data=content, + backend=backend, + ) + except ValueError as e: + raise InvalidPublicKeyFileError(e) + except UnsupportedAlgorithm as e: + raise InvalidAlgorithmError(e) + + return publickey + + +def compare_publickeys(pk1, pk2): + a = isinstance(pk1, Ed25519PublicKey) + b = isinstance(pk2, Ed25519PublicKey) + if a or b: + if not a or not b: + return False + a = pk1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + b = pk2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + return a == b + else: + return pk1.public_numbers() == pk2.public_numbers() + + +def compare_encryption_algorithms(ea1, ea2): + if isinstance(ea1, serialization.NoEncryption) and isinstance(ea2, serialization.NoEncryption): + return True + elif (isinstance(ea1, serialization.BestAvailableEncryption) and + isinstance(ea2, serialization.BestAvailableEncryption)): + return ea1.password == ea2.password + else: + return False + + +def get_encryption_algorithm(passphrase): + try: + return serialization.BestAvailableEncryption(passphrase) + except ValueError as e: + raise InvalidPassphraseError(e) + + +def validate_comment(comment): + if not hasattr(comment, 'encode'): + raise InvalidCommentError("%s cannot be encoded to text" % comment) + + +def extract_comment(path): + try: + with open(path, 'rb') as f: + fields = f.read().split(b' ', 2) + if len(fields) == 3: + comment = fields[2].decode(_TEXT_ENCODING) + else: + comment = "" + except OSError as e: + raise InvalidPublicKeyFileError(e) + + return comment + + +def calculate_fingerprint(openssh_publickey): + digest = hashes.Hash(hashes.SHA256(), backend=backend) + decoded_pubkey = b64decode(openssh_publickey.split(b' ')[1]) + digest.update(decoded_pubkey) + + return b64encode(digest.finalize()).decode(encoding=_TEXT_ENCODING).rstrip('=') diff --git a/tests/unit/plugins/module_utils/openssh/test_cryptography_openssh.py b/tests/unit/plugins/module_utils/openssh/test_cryptography_openssh.py new file mode 100644 index 000000000..b77b7db01 --- /dev/null +++ b/tests/unit/plugins/module_utils/openssh/test_cryptography_openssh.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + +import os.path +from getpass import getuser +from os import remove, rmdir +from socket import gethostname +from tempfile import mkdtemp + +from ansible_collections.community.crypto.plugins.module_utils.openssh.cryptography_openssh import ( + Asymmetric_Keypair, + HAS_OPENSSH_SUPPORT, + InvalidCommentError, + InvalidPrivateKeyFileError, + InvalidPublicKeyFileError, + InvalidKeySizeError, + InvalidKeyTypeError, + InvalidPassphraseError, + OpenSSH_Keypair +) + +DEFAULT_KEY_PARAMS = [ + ( + 'rsa', + None, + None, + None, + ), + ( + 'dsa', + None, + None, + None, + ), + ( + 'ecdsa', + None, + None, + None, + ), + ( + 'ed25519', + None, + None, + None, + ), +] + +VALID_USER_KEY_PARAMS = [ + ( + 'rsa', + 8192, + 'change_me'.encode('UTF-8'), + 'comment', + ), + ( + 'dsa', + 1024, + 'change_me'.encode('UTF-8'), + 'comment', + ), + ( + 'ecdsa', + 521, + 'change_me'.encode('UTF-8'), + 'comment', + ), + ( + 'ed25519', + 256, + 'change_me'.encode('UTF-8'), + 'comment', + ), +] + +INVALID_USER_KEY_PARAMS = [ + ( + 'dne', + None, + None, + None, + ), + ( + 'rsa', + None, + [1, 2, 3], + 'comment', + ), + ( + 'ecdsa', + None, + None, + [1, 2, 3], + ), +] + +INVALID_KEY_SIZES = [ + ( + 'rsa', + 1023, + None, + None, + ), + ( + 'rsa', + 16385, + None, + None, + ), + ( + 'dsa', + 256, + None, + None, + ), + ( + 'ecdsa', + 1024, + None, + None, + ), + ( + 'ed25519', + 1024, + None, + None, + ), +] + + +@pytest.mark.parametrize("keytype,size,passphrase,comment", DEFAULT_KEY_PARAMS) +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_default_key_params(keytype, size, passphrase, comment): + result = True + + default_sizes = { + 'rsa': 2048, + 'dsa': 1024, + 'ecdsa': 256, + 'ed25519': 256, + } + + default_comment = "%s@%s" % (getuser(), gethostname()) + pair = OpenSSH_Keypair.generate(keytype=keytype, size=size, passphrase=passphrase, comment=comment) + try: + pair = OpenSSH_Keypair.generate(keytype=keytype, size=size, passphrase=passphrase, comment=comment) + if pair.size != default_sizes[pair.key_type] or pair.comment != default_comment: + result = False + except Exception as e: + print(e) + result = False + + assert result + + +@pytest.mark.parametrize("keytype,size,passphrase,comment", VALID_USER_KEY_PARAMS) +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_valid_user_key_params(keytype, size, passphrase, comment): + result = True + + try: + pair = OpenSSH_Keypair.generate(keytype=keytype, size=size, passphrase=passphrase, comment=comment) + if pair.key_type != keytype or pair.size != size or pair.comment != comment: + result = False + except Exception as e: + print(e) + result = False + + assert result + + +@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_USER_KEY_PARAMS) +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_invalid_user_key_params(keytype, size, passphrase, comment): + result = False + + try: + OpenSSH_Keypair.generate(keytype=keytype, size=size, passphrase=passphrase, comment=comment) + except (InvalidCommentError, InvalidKeyTypeError, InvalidPassphraseError): + result = True + except Exception as e: + print(e) + pass + + assert result + + +@pytest.mark.parametrize("keytype,size,passphrase,comment", INVALID_KEY_SIZES) +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_invalid_key_sizes(keytype, size, passphrase, comment): + result = False + + try: + OpenSSH_Keypair.generate(keytype=keytype, size=size, passphrase=passphrase, comment=comment) + except InvalidKeySizeError: + result = True + except Exception as e: + print(e) + pass + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_valid_comment_update(): + + pair = OpenSSH_Keypair.generate() + new_comment = "comment" + try: + pair.comment = new_comment + except Exception as e: + print(e) + pass + + assert pair.comment == new_comment and pair.public_key.split(b' ', 2)[2].decode() == new_comment + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_invalid_comment_update(): + result = False + + pair = OpenSSH_Keypair.generate() + new_comment = [1, 2, 3] + try: + pair.comment = new_comment + except InvalidCommentError: + result = True + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_valid_passphrase_update(): + result = False + + passphrase = "change_me".encode('UTF-8') + + try: + tmpdir = mkdtemp() + keyfilename = os.path.join(tmpdir, "id_rsa") + + pair1 = OpenSSH_Keypair.generate() + pair1.update_passphrase(passphrase) + + with open(keyfilename, "w+b") as keyfile: + keyfile.write(pair1.private_key) + + with open(keyfilename + '.pub', "w+b") as pubkeyfile: + pubkeyfile.write(pair1.public_key) + + pair2 = OpenSSH_Keypair.load(path=keyfilename, passphrase=passphrase) + + if pair1 == pair2: + result = True + finally: + if os.path.exists(keyfilename): + remove(keyfilename) + if os.path.exists(keyfilename + '.pub'): + remove(keyfilename + '.pub') + if os.path.exists(tmpdir): + rmdir(tmpdir) + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_invalid_passphrase_update(): + result = False + + passphrase = [1, 2, 3] + pair = OpenSSH_Keypair.generate() + try: + pair.update_passphrase(passphrase) + except InvalidPassphraseError: + result = True + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_invalid_privatekey(): + result = False + + try: + tmpdir = mkdtemp() + keyfilename = os.path.join(tmpdir, "id_rsa") + + pair = OpenSSH_Keypair.generate() + + with open(keyfilename, "w+b") as keyfile: + keyfile.write(pair.private_key[1:]) + + with open(keyfilename + '.pub', "w+b") as pubkeyfile: + pubkeyfile.write(pair.public_key) + + OpenSSH_Keypair.load(path=keyfilename) + except InvalidPrivateKeyFileError: + result = True + finally: + if os.path.exists(keyfilename): + remove(keyfilename) + if os.path.exists(keyfilename + '.pub'): + remove(keyfilename + '.pub') + if os.path.exists(tmpdir): + rmdir(tmpdir) + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_mismatched_keypair(): + result = False + + try: + tmpdir = mkdtemp() + keyfilename = os.path.join(tmpdir, "id_rsa") + + pair1 = OpenSSH_Keypair.generate() + pair2 = OpenSSH_Keypair.generate() + + with open(keyfilename, "w+b") as keyfile: + keyfile.write(pair1.private_key) + + with open(keyfilename + '.pub', "w+b") as pubkeyfile: + pubkeyfile.write(pair2.public_key) + + OpenSSH_Keypair.load(path=keyfilename) + except InvalidPublicKeyFileError: + result = True + finally: + if os.path.exists(keyfilename): + remove(keyfilename) + if os.path.exists(keyfilename + '.pub'): + remove(keyfilename + '.pub') + if os.path.exists(tmpdir): + rmdir(tmpdir) + + assert result + + +@pytest.mark.skipif(not HAS_OPENSSH_SUPPORT, reason="requires cryptography") +def test_keypair_comparison(): + assert OpenSSH_Keypair.generate() != OpenSSH_Keypair.generate() + assert OpenSSH_Keypair.generate() != OpenSSH_Keypair.generate(keytype='dsa') + assert OpenSSH_Keypair.generate() != OpenSSH_Keypair.generate(keytype='ed25519') + assert OpenSSH_Keypair.generate(keytype='ed25519') != OpenSSH_Keypair.generate(keytype='ed25519') + try: + tmpdir = mkdtemp() + + keys = { + 'rsa': { + 'pair': OpenSSH_Keypair.generate(), + 'filename': os.path.join(tmpdir, "id_rsa"), + }, + 'dsa': { + 'pair': OpenSSH_Keypair.generate(keytype='dsa', passphrase='change_me'.encode('UTF-8')), + 'filename': os.path.join(tmpdir, "id_dsa"), + }, + 'ed25519': { + 'pair': OpenSSH_Keypair.generate(keytype='ed25519'), + 'filename': os.path.join(tmpdir, "id_ed25519"), + } + } + + for v in keys.values(): + with open(v['filename'], "w+b") as keyfile: + keyfile.write(v['pair'].private_key) + with open(v['filename'] + '.pub', "w+b") as pubkeyfile: + pubkeyfile.write(v['pair'].public_key) + + assert keys['rsa']['pair'] == OpenSSH_Keypair.load(path=keys['rsa']['filename']) + + loaded_dsa_key = OpenSSH_Keypair.load(path=keys['dsa']['filename'], passphrase='change_me'.encode('UTF-8')) + assert keys['dsa']['pair'] == loaded_dsa_key + + loaded_dsa_key.update_passphrase('change_me_again'.encode('UTF-8')) + assert keys['dsa']['pair'] != loaded_dsa_key + + loaded_dsa_key.update_passphrase('change_me'.encode('UTF-8')) + assert keys['dsa']['pair'] == loaded_dsa_key + + loaded_dsa_key.comment = "comment" + assert keys['dsa']['pair'] != loaded_dsa_key + + assert keys['ed25519']['pair'] == OpenSSH_Keypair.load(path=keys['ed25519']['filename']) + finally: + for v in keys.values(): + if os.path.exists(v['filename']): + remove(v['filename']) + if os.path.exists(v['filename'] + '.pub'): + remove(v['filename'] + '.pub') + if os.path.exists(tmpdir): + rmdir(tmpdir) + assert OpenSSH_Keypair.generate() != []