diff --git a/happybase/connection.py b/happybase/connection.py index e52d256..091958e 100644 --- a/happybase/connection.py +++ b/happybase/connection.py @@ -9,11 +9,14 @@ from thrift.transport.TSocket import TSocket from thrift.transport.TTransport import TBufferedTransport, TFramedTransport from thrift.protocol import TBinaryProtocol, TCompactProtocol +import sasl +from os import path from .hbase import Hbase from .hbase.ttypes import ColumnDescriptor from .table import Table from .util import pep8_to_camel_case +from .thrift_sasl import TSaslClientTransport logger = logging.getLogger(__name__) @@ -81,6 +84,11 @@ class Connection(object): process as well. ``TBinaryAccelerated`` is the default protocol that happybase uses. + The optional `use_kerberos` argument allows you to establish a + secure connection to HBase. This argument requires the buffered + `transport` mode. You must first authenticate with + your KDC by using kinit (e.g. kinit -kt my.keytab user@REALM). + .. versionadded:: 0.9 `protocol` argument @@ -101,11 +109,14 @@ class Connection(object): :param str table_prefix_separator: Separator used for `table_prefix` :param str compat: Compatibility mode (optional) :param str transport: Thrift transport mode (optional) + :param bool use_kerberos: Connect to HBase via a secure connection (default: False) + :param str sasl_service: The name of the SASL service (default: hbase) """ def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, timeout=None, autoconnect=True, table_prefix=None, table_prefix_separator='_', compat=DEFAULT_COMPAT, - transport=DEFAULT_TRANSPORT, protocol=DEFAULT_PROTOCOL): + transport=DEFAULT_TRANSPORT, protocol=DEFAULT_PROTOCOL, + use_kerberos=False, sasl_service="hbase"): if transport not in THRIFT_TRANSPORTS: raise ValueError("'transport' must be one of %s" @@ -135,6 +146,8 @@ def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, timeout=None, self.table_prefix_separator = table_prefix_separator self.compat = compat + 
self._use_kerberos = use_kerberos + self._sasl_service = sasl_service self._transport_class = THRIFT_TRANSPORTS[transport] self._protocol_class = THRIFT_PROTOCOLS[protocol] self._refresh_thrift_client() @@ -150,7 +163,20 @@ def _refresh_thrift_client(self): if self.timeout is not None: socket.setTimeout(self.timeout) - self.transport = self._transport_class(socket) + if not self._use_kerberos: + self.transport = self._transport_class(socket) + else: + # Check for required arguments for kerberos + if self._transport_class is not TBufferedTransport: + raise ValueError("Must use a buffered transport " + " when use_kerberos is enabled") + + saslc = sasl.Client() + saslc.setAttr("host", self.host) + saslc.setAttr("service", self._sasl_service) + saslc.init() + self.transport = TSaslClientTransport(saslc, "GSSAPI", socket) + protocol = self._protocol_class(self.transport) self.client = Hbase.Client(protocol) diff --git a/happybase/thrift_sasl.py b/happybase/thrift_sasl.py new file mode 100644 index 0000000..8af8f4e --- /dev/null +++ b/happybase/thrift_sasl.py @@ -0,0 +1,169 @@ +""" SASL transports for Thrift. """ + +from thrift.transport.TTransport import CReadableTransport, TTransportBase, TTransportException, StringIO +import struct + +class TSaslClientTransport(TTransportBase, CReadableTransport): + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + def __init__(self, sasl_client_factory, mechanism, trans): + """ + @param sasl_client_factory: a sasl.Client instance; despite the name it is stored and used as-is by open(), never called + @param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN") + @param trans: the underlying transport over which to communicate. 
+ """ + self._trans = trans + self.sasl_client_factory = sasl_client_factory + self.sasl = None + self.mechanism = mechanism + self.__wbuf = StringIO() + self.__rbuf = StringIO() + self.opened = False + self.encode = None + + def isOpen(self): + return self._trans.isOpen() + + def open(self): + if not self._trans.isOpen(): + self._trans.open() + + if self.sasl is not None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Already open!") + self.sasl = self.sasl_client_factory + + ret, chosen_mech, initial_response = self.sasl.start(self.mechanism) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Could not start SASL: %s" % self.sasl.getError())) + + # Send initial response + self._send_message(self.START, chosen_mech) + self._send_message(self.OK, initial_response) + + # SASL negotiation loop + while True: + status, payload = self._recv_sasl_message() + if status not in (self.OK, self.COMPLETE): + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad status: %d (%s)" % (status, payload))) + if status == self.COMPLETE: + break + ret, response = self.sasl.step(payload) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad SASL result: %s" % (self.sasl.getError()))) + self._send_message(self.OK, response) + + def _send_message(self, status, body): + header = struct.pack(">BI", status, len(body)) + self._trans.write(header + body) + self._trans.flush() + + def _recv_sasl_message(self): + header = self._trans.readAll(5) + status, length = struct.unpack(">BI", header) + if length > 0: + payload = self._trans.readAll(length) + else: + payload = "" + return status, payload + + def write(self, data): + self.__wbuf.write(data) + + def flush(self): + buffer = self.__wbuf.getvalue() + # The first time we flush data, we send it to sasl.encode() + # If the length doesn't change, then we must be using a QOP + # of auth and we should no longer call 
sasl.encode(), otherwise + # we encode every time. + if self.encode == None: + success, encoded = self.sasl.encode(buffer) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + if (len(encoded)==len(buffer)): + self.encode = False + self._flushPlain(buffer) + else: + self.encode = True + self._trans.write(encoded) + elif self.encode: + self._flushEncoded(buffer) + else: + self._flushPlain(buffer) + + self._trans.flush() + self.__wbuf = StringIO() + + def _flushEncoded(self, buffer): + # sasl.encode() does the encoding and adds the length header, so nothing + # to do but call it and write the result. + success, encoded = self.sasl.encode(buffer) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + self._trans.write(encoded) + + def _flushPlain(self, buffer): + # When we have QOP of auth, sasl.encode() will pass the input to the output + # but won't put a length header, so we have to do that. + + # Note stolen from TFramedTransport: + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + self._trans.write(struct.pack(">I", len(buffer)) + buffer) + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self._read_frame() + return self.__rbuf.read(sz) + + def _read_frame(self): + header = self._trans.readAll(4) + (length,) = struct.unpack(">I", header) + if self.encode: + # If the frames are encoded (i.e. 
you're using a QOP of auth-int or + # auth-conf), then make sure to include the header in the bytes you send to + # sasl.decode() + encoded = header + self._trans.readAll(length) + success, decoded = self.sasl.decode(encoded) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + else: + # If the frames are not encoded, just pass it through + decoded = self._trans.readAll(length) + self.__rbuf = StringIO(decoded) + + def close(self): + self._trans.close() + self.sasl = None + + # Implement the CReadableTransport interface. + # Stolen shamelessly from TFramedTransport + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self._read_frame() + prefix += self.__rbuf.getvalue() + self.__rbuf = StringIO(prefix) + return self.__rbuf \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 85cebac..b5bbab9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ thrift>=0.8.0 +sasl