PYTHON-1341 Impl of client-side column-level encryption/decryption (#…
absurdfarce authored and weideng1 committed May 3, 2023
1 parent b370bc4 commit 7e9b6f8
Showing 17 changed files with 573 additions and 48 deletions.
1 change: 0 additions & 1 deletion .travis.yml
@@ -5,7 +5,6 @@ language: python
python:
- "3.7"
- "3.8"
- "pypy3.5"

env:
- CASS_DRIVER_NO_CYTHON=1
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -57,7 +57,7 @@ matrices = [
"SMOKE": [
"SERVER": ['3.11', '4.0', 'dse-6.8.30'],
"RUNTIME": ['3.7.7', '3.8.3'],
"CYTHON": ["False"]
"CYTHON": ["True", "False"]
]
]

20 changes: 18 additions & 2 deletions cassandra/cluster.py
@@ -1003,6 +1003,12 @@ def default_retry_policy(self, policy):
load the configuration and certificates.
"""

column_encryption_policy = None
"""
An instance of :class:`cassandra.policies.ColumnEncryptionPolicy` specifying encryption materials to be
used for columns in this cluster.
"""

@property
def schema_metadata_enabled(self):
"""
@@ -1104,7 +1110,8 @@ def __init__(self,
monitor_reporting_enabled=True,
monitor_reporting_interval=30,
client_id=None,
cloud=None):
cloud=None,
column_encryption_policy=None):
"""
``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as
establishing connection pools or refreshing metadata.
@@ -1152,6 +1159,9 @@ def __init__(self,

self.port = port

if column_encryption_policy is not None:
self.column_encryption_policy = column_encryption_policy

self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port)
self.endpoint_factory.configure(self)

@@ -2535,6 +2545,12 @@ def __init__(self, cluster, hosts, keyspace=None):

self.encoder = Encoder()

if self.cluster.column_encryption_policy is not None:
try:
self.client_protocol_handler.column_encryption_policy = self.cluster.column_encryption_policy
except AttributeError:
log.info("Unable to set column encryption policy for session")

# create connection pools in parallel
self._initial_connect_futures = set()
for host in hosts:
@@ -3074,7 +3090,7 @@ def prepare(self, query, custom_payload=None, keyspace=None):
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id)
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload

self.cluster.add_prepared(response.query_id, prepared_statement)
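Taken together, the cluster.py changes route a policy supplied at Cluster construction time through to sessions, the protocol handler, and prepared statements. A minimal usage sketch under assumed names (the ks.users.ssn column, its CQL type, and the key source are hypothetical; a real key would come from a key-management system):

import os

from cassandra.cluster import Cluster
from cassandra.policies import AES256ColumnEncryptionPolicy, ColDesc

key = os.urandom(32)  # AES-256 requires a 32-byte key
cle_policy = AES256ColumnEncryptionPolicy()
# Register the hypothetical ks.users.ssn column with its key and cleartext CQL type
cle_policy.add_column(ColDesc('ks', 'users', 'ssn'), key, 'int')

cluster = Cluster(column_encryption_policy=cle_policy)
session = cluster.connect()

# Per the policy docstring, values bound to prepared statements for registered
# columns are encrypted before they are sent, and those columns are decrypted
# transparently in any rows read back.
insert = session.prepare("INSERT INTO ks.users (id, ssn) VALUES (?, ?)")
session.execute(insert, (1, 123456789))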
16 changes: 15 additions & 1 deletion cassandra/obj_parser.pyx
@@ -17,9 +17,12 @@ include "ioutils.pyx"
from cassandra import DriverException
from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer, from_binary
from cassandra.deserializers import find_deserializer
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
from cassandra.tuple cimport tuple_new, tuple_set

from cpython.bytes cimport PyBytes_AsStringAndSize


cdef class ListParser(ColumnParser):
"""Decode a ResultMessage into a list of tuples (or other objects)"""
@@ -58,18 +61,29 @@ cdef class TupleRowParser(RowParser):
assert desc.rowsize >= 0

cdef Buffer buf
cdef Buffer newbuf
cdef Py_ssize_t i, rowsize = desc.rowsize
cdef Deserializer deserializer
cdef tuple res = tuple_new(desc.rowsize)

ce_policy = desc.column_encryption_policy
for i in range(rowsize):
# Read the next few bytes
get_buf(reader, &buf)

# Deserialize bytes to python object
deserializer = desc.deserializers[i]
coldesc = desc.coldescs[i]
uses_ce = ce_policy and ce_policy.contains_column(coldesc)
try:
val = from_binary(deserializer, &buf, desc.protocol_version)
if uses_ce:
col_type = ce_policy.column_type(coldesc)
decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf))
PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size)
deserializer = find_deserializer(ce_policy.column_type(coldesc))
val = from_binary(deserializer, &newbuf, desc.protocol_version)
else:
val = from_binary(deserializer, &buf, desc.protocol_version)
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
desc.coltypes[i].cql_parameterized_type(),
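The branch added to TupleRowParser above decrypts a column's raw bytes and then deserializes them with the cleartext type registered in the policy, falling back to the normal path for unencrypted columns. In plain Python the logic is roughly the following sketch (the real code is Cython and works on Buffer structs and Deserializer objects; the CQL type's deserialize stands in for that machinery here):

def deserialize_column(ce_policy, coldesc, raw_bytes, wire_type, protocol_version):
    # Columns covered by the policy: decrypt first, then deserialize using the
    # cleartext CQL type registered via add_column rather than the wire type.
    if ce_policy is not None and ce_policy.contains_column(coldesc):
        decrypted = ce_policy.decrypt(coldesc, raw_bytes)
        return ce_policy.column_type(coldesc).deserialize(decrypted, protocol_version)
    # Columns not covered by the policy are deserialized as before.
    return wire_type.deserialize(raw_bytes, protocol_version)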
2 changes: 2 additions & 0 deletions cassandra/parsing.pxd
@@ -18,6 +18,8 @@ from cassandra.deserializers cimport Deserializer
cdef class ParseDesc:
cdef public object colnames
cdef public object coltypes
cdef public object column_encryption_policy
cdef public list coldescs
cdef Deserializer[::1] deserializers
cdef public int protocol_version
cdef Py_ssize_t rowsize
4 changes: 3 additions & 1 deletion cassandra/parsing.pyx
@@ -19,9 +19,11 @@ Module containing the definitions and declarations (parsing.pxd) for parsers.
cdef class ParseDesc:
"""Description of what structure to parse"""

def __init__(self, colnames, coltypes, deserializers, protocol_version):
def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version):
self.colnames = colnames
self.coltypes = coltypes
self.column_encryption_policy = column_encryption_policy
self.coldescs = coldescs
self.deserializers = deserializers
self.protocol_version = protocol_version
self.rowsize = len(colnames)
181 changes: 173 additions & 8 deletions cassandra/policies.py
@@ -12,13 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from functools import lru_cache
from itertools import islice, cycle, groupby, repeat
import logging
import os
from random import randint, shuffle
from threading import Lock
import socket
import warnings

from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from cassandra import WriteType as WT
from cassandra.cqltypes import _cqltypes


# This is done this way because WriteType was originally
@@ -455,7 +463,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
and a single-argument predicate. This policy defers to the child policy for
hosts where ``predicate(host)`` is truthy. Hosts for which
``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
``predicate(host)`` is falsy will be considered :attr:`.IGNORED`, and will
not be used in a query plan.
This can be used in the cases where you need a whitelist or blacklist
@@ -491,7 +499,7 @@ def __init__(self, child_policy, predicate):
:param child_policy: an instantiated :class:`.LoadBalancingPolicy`
that this one will defer to.
:param predicate: a one-parameter function that takes a :class:`.Host`.
If it returns a falsey value, the :class:`.Host` will
If it returns a falsy value, the :class:`.Host` will
be :attr:`.IGNORED` and not returned in query plans.
"""
super(HostFilterPolicy, self).__init__()
@@ -527,7 +535,7 @@ def predicate(self):
def distance(self, host):
"""
Checks if ``predicate(host)``, then returns
:attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
:attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy
otherwise.
"""
if self.predicate(host):
@@ -616,7 +624,7 @@ class ReconnectionPolicy(object):
def new_schedule(self):
"""
This should return a finite or infinite iterable of delays (each as a
floating point number of seconds) inbetween each failed reconnection
floating point number of seconds) in-between each failed reconnection
attempt. Note that if the iterable is finite, reconnection attempts
will cease once the iterable is exhausted.
"""
@@ -626,12 +634,12 @@ def new_schedule(self):
class ConstantReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay
inbetween each reconnection attempt.
in-between each reconnection attempt.
"""

def __init__(self, delay, max_attempts=64):
"""
`delay` should be a floating point number of seconds to wait inbetween
`delay` should be a floating point number of seconds to wait in-between
each attempt.
`max_attempts` should be a total number of attempts to be made before
@@ -655,7 +663,7 @@ def new_schedule(self):
class ExponentialReconnectionPolicy(ReconnectionPolicy):
"""
A :class:`.ReconnectionPolicy` subclass which exponentially increases
the length of the delay inbetween each reconnection attempt up to
the length of the delay in-between each reconnection attempt up to
a set maximum delay.
A random amount of jitter (+/- 15%) will be added to the pure exponential
@@ -715,7 +723,7 @@ class RetryPolicy(object):
timeout and unavailable failures. These are failures reported from the
server side. Timeouts are configured by
`settings in cassandra.yaml <https://github.com/apache/cassandra/blob/cassandra-2.1.4/conf/cassandra.yaml#L568-L584>`_.
Unavailable failures occur when the coordinator cannot acheive the consistency
Unavailable failures occur when the coordinator cannot achieve the consistency
level for a request. For further information see the method descriptions
below.
@@ -1181,3 +1189,160 @@ def _rethrow(self, *args, **kwargs):
on_read_timeout = _rethrow
on_write_timeout = _rethrow
on_unavailable = _rethrow


ColDesc = namedtuple('ColDesc', ['ks', 'table', 'col'])
ColData = namedtuple('ColData', ['key','type'])

class ColumnEncryptionPolicy(object):
"""
A policy enabling (mostly) transparent encryption and decryption of data before it is
sent to the cluster.
Key materials and other configurations are specified on a per-column basis. This policy can
then be used by driver structures which are aware of the underlying columns involved in their
work. In practice this includes the following cases:
* Prepared statements - data for columns specified by the cluster's policy will be transparently
encrypted before they are sent
* Rows returned from any query - data for columns specified by the cluster's policy will be
transparently decrypted before they are returned to the user
To enable this functionality, create an instance of this class (or more likely a subclass)
before creating a cluster. This policy should then be configured and supplied to the Cluster
at creation time via the :attr:`.Cluster.column_encryption_policy` attribute.
"""

def encrypt(self, coldesc, obj_bytes):
"""
Encrypt the specified bytes using the cryptography materials for the specified column.
Largely used internally, although this could also be used to encrypt values supplied
to non-prepared statements in a way that is consistent with this policy.
"""
raise NotImplementedError()

def decrypt(self, coldesc, encrypted_bytes):
"""
Decrypt the specified (encrypted) bytes using the cryptography materials for the
specified column. Used internally; could be used externally as well but there's
not currently an obvious use case.
"""
raise NotImplementedError()

def add_column(self, coldesc, key):
"""
Provide cryptography materials to be used when encrypting and/or decrypting data
for the specified column.
"""
raise NotImplementedError()

def contains_column(self, coldesc):
"""
Predicate to determine if a specific column is supported by this policy.
Currently only used internally.
"""
raise NotImplementedError()

def encode_and_encrypt(self, coldesc, obj):
"""
Helper function to enable use of this policy on simple (i.e. non-prepared)
statements.
"""
raise NotImplementedError()

AES256_BLOCK_SIZE = 128
AES256_BLOCK_SIZE_BYTES = int(AES256_BLOCK_SIZE / 8)
AES256_KEY_SIZE = 256
AES256_KEY_SIZE_BYTES = int(AES256_KEY_SIZE / 8)

class AES256ColumnEncryptionPolicy(ColumnEncryptionPolicy):

# CBC uses an IV that's the same size as the block size
#
# TODO: Need to find some way to expose mode options
# (CBC etc.) without leaking classes from the underlying
# impl here
def __init__(self, mode = modes.CBC, iv = os.urandom(AES256_BLOCK_SIZE_BYTES)):

self.mode = mode
self.iv = iv

# ColData for a given ColDesc is always preserved. We only create a Cipher
# when there's an actual need to do so for a given ColDesc
self.coldata = {}
self.ciphers = {}

def encrypt(self, coldesc, obj_bytes):

# AES256 has a 128-bit block size so if the input bytes don't align perfectly on
# those blocks we have to pad them. There's plenty of room for optimization here:
#
# * Instances of the PKCS7 padder should be managed in a bounded pool
# * It would be nice if we could get a flag from encrypted data to indicate
# whether it was padded or not
# * Might be able to make this happen with a leading block of flags in encrypted data
padder = padding.PKCS7(AES256_BLOCK_SIZE).padder()
padded_bytes = padder.update(obj_bytes) + padder.finalize()

cipher = self._get_cipher(coldesc)
encryptor = cipher.encryptor()
return encryptor.update(padded_bytes) + encryptor.finalize()

def decrypt(self, coldesc, encrypted_bytes):

cipher = self._get_cipher(coldesc)
decryptor = cipher.decryptor()
padded_bytes = decryptor.update(encrypted_bytes) + decryptor.finalize()

unpadder = padding.PKCS7(AES256_BLOCK_SIZE).unpadder()
return unpadder.update(padded_bytes) + unpadder.finalize()

def add_column(self, coldesc, key, type):

if not coldesc:
raise ValueError("ColDesc supplied to add_column cannot be None")
if not key:
raise ValueError("Key supplied to add_column cannot be None")
if not type:
raise ValueError("Type supplied to add_column cannot be None")
if type not in _cqltypes.keys():
raise ValueError("Type %s is not a supported type".format(type))
if not len(key) == AES256_KEY_SIZE_BYTES:
raise ValueError("AES256 column encryption policy expects a 256-bit encryption key")
self.coldata[coldesc] = ColData(key, _cqltypes[type])

def contains_column(self, coldesc):
return coldesc in self.coldata

def encode_and_encrypt(self, coldesc, obj):
if not coldesc:
raise ValueError("ColDesc supplied to encode_and_encrypt cannot be None")
if not obj:
raise ValueError("Object supplied to encode_and_encrypt cannot be None")
coldata = self.coldata.get(coldesc)
if not coldata:
raise ValueError("Could not find ColData for ColDesc %s".format(coldesc))
return self.encrypt(coldesc, coldata.type.serialize(obj, None))

def cache_info(self):
return AES256ColumnEncryptionPolicy._build_cipher.cache_info()

def column_type(self, coldesc):
return self.coldata[coldesc].type

def _get_cipher(self, coldesc):
"""
Access relevant state from this instance necessary to create a Cipher and then get one,
hopefully returning a cached instance if we've already done so (and it hasn't been evicted)
"""

try:
coldata = self.coldata[coldesc]
return AES256ColumnEncryptionPolicy._build_cipher(coldata.key, self.mode, self.iv)
except KeyError:
raise ValueError("Could not find column {}".format(coldesc))

# Explicitly use a class method here to avoid caching self
@lru_cache(maxsize=128)
def _build_cipher(key, mode, iv):
return Cipher(algorithms.AES256(key), mode(iv))
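For simple (non-prepared) statements the policy is not applied to parameters automatically; encode_and_encrypt serializes a value with the column's registered cleartext type and encrypts it so it can be bound explicitly. A sketch that reuses the hypothetical cle_policy, session, and ks.users.ssn column from the earlier example:

coldesc = ColDesc('ks', 'users', 'ssn')
session.execute(
    "INSERT INTO ks.users (id, ssn) VALUES (%s, %s)",
    (2, cle_policy.encode_and_encrypt(coldesc, 987654321))
)

# Rows returned from any query have registered columns decrypted transparently,
# as long as the cluster was created with the policy.
row = session.execute("SELECT ssn FROM ks.users WHERE id = %s", (2,)).one()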