Skip to content

Commit

Permalink
Add SimpleCellCipher to remove session key manager
Browse files Browse the repository at this point in the history
Refactor common functions to serve both designs
  • Loading branch information
IsaacYangSLA committed Sep 1, 2023
1 parent 0c28a26 commit eee330a
Showing 1 changed file with 118 additions and 40 deletions.
158 changes: 118 additions & 40 deletions nvflare/fuel/f3/cellnet/cell_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
import os

from cryptography.exceptions import InvalidKey, InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import padding as sym_padding
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import asymmetric, ciphers, hashes, padding
from cryptography.x509 import Certificate

HASH_LENGTH = 4 # Adjustable to avoid collision
NONCE_LENGTH = 16 # For AES, this is 128 bits (i.e. block size)
Expand All @@ -27,6 +25,7 @@
PADDING_LENGTH = NONCE_LENGTH * 8 # in bits
KEY_ENC_LENGTH = 256
SIGNATURE_LENGTH = 256
SIMPLE_HEADER_LENGTH = NONCE_LENGTH + KEY_ENC_LENGTH + SIGNATURE_LENGTH


def get_hash(value):
Expand All @@ -43,6 +42,63 @@ class InvalidCertChain(Exception):
pass


def _asym_enc(k, m):
return k.encrypt(
m,
asymmetric.padding.OAEP(
mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None
),
)


def _asym_dec(k, m):
return k.decrypt(
m,
asymmetric.padding.OAEP(
mgf=asymmetric.padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None
),
)


def _sign(k, m):
return k.sign(
data=m,
padding=asymmetric.padding.PSS(
mgf=asymmetric.padding.MGF1(hashes.SHA256()),
salt_length=asymmetric.padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
)


def _verify(k, m, s):
k.verify(
s,
m,
asymmetric.padding.PSS(
mgf=asymmetric.padding.MGF1(hashes.SHA256()), salt_length=asymmetric.padding.PSS.MAX_LENGTH
),
hashes.SHA256(),
)


def _sym_enc(k, n, m):
cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n))
encryptor = cipher.encryptor()
padder = padding.PKCS7(PADDING_LENGTH).padder()
padded_data = padder.update(m) + padder.finalize()
return encryptor.update(padded_data) + encryptor.finalize()


def _sym_dec(k, n, m):
cipher = ciphers.Cipher(ciphers.algorithms.AES(k), ciphers.modes.CBC(n))
decryptor = cipher.decryptor()
plain_text = decryptor.update(m)
plain_text = plain_text + decryptor.finalize()
unpadder = padding.PKCS7(PADDING_LENGTH).unpadder()
return unpadder.update(plain_text) + unpadder.finalize()


class SessionKeyManager:
def __init__(self, root_ca):
self.key_hash_dict = dict()
Expand All @@ -51,48 +107,30 @@ def __init__(self, root_ca):

def validate_cert_chain(self, cert):
self.root_ca_pub_key.verify(
cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), cert.signature_hash_algorithm
cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm
)

def key_request(self, remote_cert, local_cert, local_pri_key):
session_key = os.urandom(KEY_LENGTH)
signature = local_pri_key.sign(
data=session_key,
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
)
signature = _sign(local_pri_key, session_key)
try:
self.validate_cert_chain(remote_cert)
except InvalidSignature:
return False

remote_pub_key = remote_cert.public_key()
key_enc = remote_pub_key.encrypt(
session_key,
padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
key_enc = _asym_enc(remote_pub_key, session_key)
self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key
key_response = key_enc + signature
return key_response

def process_key_response(self, remote_cert, local_cert, local_pri_key, key_response):
key_enc, signature = key_response[:KEY_ENC_LENGTH], key_response[KEY_ENC_LENGTH:]
try:
session_key = local_pri_key.decrypt(
key_enc,
padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
session_key = _asym_dec(local_pri_key, key_enc)
self.validate_cert_chain(remote_cert)
public_key = remote_cert.public_key()
public_key.verify(
signature,
session_key,
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
hashes.SHA256(),
)
_verify(public_key, session_key, signature)
self.key_hash_dict[get_hash(session_key)[-HASH_LENGTH:]] = session_key
except (InvalidKey, InvalidSignature):
return False
Expand Down Expand Up @@ -120,12 +158,7 @@ def encrypt(self, message):
key = self.session_key_manager.get_latest_key()
key_hash = get_hash(key)
nonce = os.urandom(NONCE_LENGTH)
cipher = Cipher(algorithms.AES(key), modes.CBC(nonce))
encryptor = cipher.encryptor()
padder = sym_padding.PKCS7(PADDING_LENGTH).padder()
padded_data = padder.update(message) + padder.finalize()
ct = nonce + key_hash[-HASH_LENGTH:] + encryptor.update(padded_data) + encryptor.finalize()
return ct
return nonce + key_hash[-HASH_LENGTH:] + _sym_enc(key, nonce, message)

def decrypt(self, message):
nonce, key_hash, message = (
Expand All @@ -136,10 +169,55 @@ def decrypt(self, message):
key = self.session_key_manager.get_key(key_hash)
if key is None:
raise SessionKeyUnavailable("No session key found for received message")
cipher = Cipher(algorithms.AES(key), modes.CBC(nonce))
decryptor = cipher.decryptor()
plain_text = decryptor.update(message)
plain_text = plain_text + decryptor.finalize()
unpadder = sym_padding.PKCS7(PADDING_LENGTH).unpadder()
data = unpadder.update(plain_text) + unpadder.finalize()
return data
return _sym_dec(key, nonce, message)


class SimpleCellCipher:
def __init__(self, root_ca: Certificate, pri_key: asymmetric.rsa.RSAPrivateKey, cert: Certificate):
self._root_ca = root_ca
self._root_ca_pub_key = root_ca.public_key()
self._pri_key = pri_key
self._cert = cert
self._pub_key = cert.public_key()
self._validate_cert_chain(self._cert)
self._cached_enc = dict()
self._cached_dec = dict()

def _validate_cert_chain(self, cert: Certificate):
self._root_ca_pub_key.verify(
cert.signature, cert.tbs_certificate_bytes, asymmetric.padding.PKCS1v15(), cert.signature_hash_algorithm
)

def encrypt(self, message: bytes, target_cert: Certificate):
cert_hash = hash(target_cert)
secret = self._cached_enc.get(cert_hash)
if secret is None:
self._validate_cert_chain(target_cert)
key = os.urandom(KEY_LENGTH)
remote_pub_key = target_cert.public_key()
key_enc = _asym_enc(remote_pub_key, key)
signature = _sign(self._pri_key, key_enc)
self._cached_enc[cert_hash] = (key, key_enc, signature)
else:
(key, key_enc, signature) = secret
nonce = os.urandom(NONCE_LENGTH)
ct = nonce + key_enc + signature + _sym_enc(key, nonce, message)
return ct

def decrypt(self, message: bytes, origin_cert: Certificate):
nonce, key_enc, signature = (
message[:NONCE_LENGTH],
message[NONCE_LENGTH : NONCE_LENGTH + KEY_ENC_LENGTH],
message[NONCE_LENGTH + KEY_ENC_LENGTH : SIMPLE_HEADER_LENGTH],
)
key_hash = hash(key_enc)
dec = self._cached_dec.get(key_hash)
if dec is None:
self._validate_cert_chain(origin_cert)
public_key = origin_cert.public_key()
_verify(public_key, key_enc, signature)
key = _asym_dec(self._pri_key, key_enc)
self._cached_dec[key_hash] = key
else:
key = dec
return _sym_dec(key, nonce, message[SIMPLE_HEADER_LENGTH:])

0 comments on commit eee330a

Please sign in to comment.