diff --git a/nvflare/fuel/f3/cellnet/cell_cipher.py b/nvflare/fuel/f3/cellnet/cell_cipher.py index f8cc640019..39621fb9d2 100644 --- a/nvflare/fuel/f3/cellnet/cell_cipher.py +++ b/nvflare/fuel/f3/cellnet/cell_cipher.py @@ -15,10 +15,8 @@ import os from cryptography.exceptions import InvalidKey, InvalidSignature -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives import padding as sym_padding -from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives import asymmetric, ciphers, hashes, padding +from cryptography.x509 import Certificate HASH_LENGTH = 4 # Adjustable to avoid collision NONCE_LENGTH = 16 # For AES, this is 128 bits (i.e. block size) @@ -27,6 +25,7 @@ PADDING_LENGTH = NONCE_LENGTH * 8 # in bits KEY_ENC_LENGTH = 256 SIGNATURE_LENGTH = 256 +SIMPLE_HEADER_LENGTH = NONCE_LENGTH + KEY_ENC_LENGTH + SIGNATURE_LENGTH def get_hash(value): @@ -43,6 +42,63 @@ class InvalidCertChain(Exception): pass +def _asym_enc(k, m): + return k.encrypt( + m, + asymmetric.padding.OAEP( + mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) + + +def _asym_dec(k, m): + return k.decrypt( + m, + asymmetric.padding.OAEP( + mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) + + +def _sign(k, m): + return k.sign( + data=m, + padding=asymmetric.padding.PSS( + mgf=asymmetric.padding.MGF1(hashes.SHA256()), + salt_length=asymmetric.padding.PSS.MAX_LENGTH, + ), + algorithm=hashes.SHA256(), + ) + + +def _verify(k, m, s): + k.verify( + s, + m, + asymmetric.padding.PSS( + mgf=asymmetric.padding.MGF1(hashes.SHA256()), salt_length=asymmetric.padding.PSS.MAX_LENGTH + ), + hashes.SHA256(), + ) + + +def _sym_enc(k, n, m): + cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) + encryptor = cipher.encryptor() + padder = padding.PKCS7(PADDING_LENGTH).padder() + padded_data = padder.update(m) + padder.finalize() + return encryptor.update(padded_data) + encryptor.finalize() + + +def _sym_dec(k, n, m): + cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n)) + decryptor = cipher.decryptor() + plain_text = decryptor.update(m) + plain_text = plain_text + decryptor.finalize() + unpadder = padding.PKCS7(PADDING_LENGTH).unpadder() + return unpadder.update(plain_text) + unpadder.finalize() + + class SessionKeyManager: def __init__(self, root_ca): self.key_hash_dict = dict() @@ -51,29 +107,19 @@ def __init__(self, root_ca): def validate_cert_chain(self, cert): self.root_ca_pub_key.verify( - cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), cert.signature_hash_algorithm + cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm ) def key_request(self, remote_cert, local_cert, local_pri_key): session_key = os.urandom(KEY_LENGTH) - signature = local_pri_key.sign( - data=session_key, - padding=padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH, - ), - algorithm=hashes.SHA256(), - ) + signature = _sign(local_pri_key, session_key) try: self.validate_cert_chain(remote_cert) except InvalidSignature: return False remote_pub_key = remote_cert.public_key() - key_enc = remote_pub_key.encrypt( - session_key, - padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None), - ) + key_enc = _asym_enc(remote_pub_key, session_key) self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key key_response = key_enc + signature return key_response @@ -81,18 +127,10 @@ def key_request(self, remote_cert, local_cert, local_pri_key): def process_key_response(self, remote_cert, local_cert, local_pri_key, key_response): key_enc, signature = key_response[:KEY_ENC_LENGTH], key_response[KEY_ENC_LENGTH:] try: - session_key = local_pri_key.decrypt( - key_enc, - padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None), - ) + session_key = _asym_dec(local_pri_key, key_enc) self.validate_cert_chain(remote_cert) public_key = remote_cert.public_key() - public_key.verify( - signature, - session_key, - padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), - hashes.SHA256(), - ) + _verify(public_key, session_key, signature) self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key except (InvalidKey, InvalidSignature): return False @@ -120,12 +158,7 @@ def encrypt(self, message): key = self.session_key_manager.get_latest_key() key_hash = get_hash(key) nonce = os.urandom(NONCE_LENGTH) - cipher = Cipher(algorithms.AES(key), modes.CBC(nonce)) - encryptor = cipher.encryptor() - padder = sym_padding.PKCS7(PADDING_LENGTH).padder() - padded_data = padder.update(message) + padder.finalize() - ct = nonce + key_hash[-HASH_LENGTH:] + encryptor.update(padded_data) + encryptor.finalize() - return ct + return nonce + key_hash[-HASH_LENGTH:] + _sym_enc(key, nonce, message) def decrypt(self, message): nonce, key_hash, message = ( @@ -136,10 +169,55 @@ def decrypt(self, message): key = self.session_key_manager.get_key(key_hash) if key is None: raise SessionKeyUnavailable("No session key found for received message") - cipher = Cipher(algorithms.AES(key), modes.CBC(nonce)) - decryptor = cipher.decryptor() - plain_text = decryptor.update(message) - plain_text = plain_text + decryptor.finalize() - unpadder = sym_padding.PKCS7(PADDING_LENGTH).unpadder() - data = unpadder.update(plain_text) + unpadder.finalize() - return data + return _sym_dec(key, nonce, message) + + +class SimpleCellCipher: + def __init__(self, root_ca: Certificate, pri_key: asymmetric.rsa.RSAPrivateKey, cert: Certificate): + self._root_ca = root_ca + self._root_ca_pub_key = root_ca.public_key() + self._pri_key = pri_key + self._cert = cert + self._pub_key = cert.public_key() + self._validate_cert_chain(self._cert) + self._cached_enc = dict() + self._cached_dec = dict() + + def _validate_cert_chain(self, cert: Certificate): + self._root_ca_pub_key.verify( + cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm + ) + + def encrypt(self, message: bytes, target_cert: Certificate): + cert_hash = hash(target_cert) + secret = self._cached_enc.get(cert_hash) + if secret is None: + self._validate_cert_chain(target_cert) + key = os.urandom(KEY_LENGTH) + remote_pub_key = target_cert.public_key() + key_enc = _asym_enc(remote_pub_key, key) + signature = _sign(self._pri_key, key_enc) + self._cached_enc[cert_hash] = (key, key_enc, signature) + else: + (key, key_enc, signature) = secret + nonce = os.urandom(NONCE_LENGTH) + ct = nonce + key_enc + signature + _sym_enc(key, nonce, message) + return ct + + def decrypt(self, message: bytes, origin_cert: Certificate): + nonce, key_enc, signature = ( + message[:NONCE_LENGTH], + message[NONCE_LENGTH : NONCE_LENGTH + KEY_ENC_LENGTH], + message[NONCE_LENGTH + KEY_ENC_LENGTH : SIMPLE_HEADER_LENGTH], + ) + key_hash = hash(key_enc) + dec = self._cached_dec.get(key_hash) + if dec is None: + self._validate_cert_chain(origin_cert) + public_key = origin_cert.public_key() + _verify(public_key, key_enc, signature) + key = _asym_dec(self._pri_key, key_enc) + self._cached_dec[key_hash] = key + else: + key = dec + return _sym_dec(key, nonce, message[SIMPLE_HEADER_LENGTH:])