From 5e64ee3ee91449860232799ea0587d7fe26a3b4c Mon Sep 17 00:00:00 2001 From: AN Long Date: Tue, 14 Nov 2023 18:02:04 +0800 Subject: [PATCH] add 'strict_decode' to cybin protocol --- tests/test_protocol_cybinary.py | 8 ++++++ thriftpy2/protocol/cybin/cybin.pyx | 41 +++++++++++++++++++----------- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/tests/test_protocol_cybinary.py b/tests/test_protocol_cybinary.py index d96e23d4..482ea703 100644 --- a/tests/test_protocol_cybinary.py +++ b/tests/test_protocol_cybinary.py @@ -147,6 +147,14 @@ def test_read_binary(): b, TType.STRING, decode_response=False) +def test_strict_decode(): + bs = TCyMemoryBuffer(b"\x00\x00\x00\x0c\x00" # there is a redundant '\x00' + b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c") + with pytest.raises(UnicodeDecodeError): + proto.read_val(bs, TType.STRING, decode_response=True, + strict_decode=True) + + def test_write_message_begin(): trans = TCyMemoryBuffer() b = proto.TCyBinaryProtocol(trans) diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx index 14a38392..8ecf722c 100644 --- a/thriftpy2/protocol/cybin/cybin.pyx +++ b/thriftpy2/protocol/cybin/cybin.pyx @@ -170,7 +170,8 @@ cdef inline write_dict(CyTransportBase buf, object val, spec): c_write_val(buf, v_type, v, v_spec) -cdef inline read_struct(CyTransportBase buf, obj, decode_response=True): +cdef inline read_struct(CyTransportBase buf, obj, decode_response=True, + strict_decode=False): cdef dict field_specs = obj.thrift_spec cdef int fid cdef TType field_type, ttype @@ -199,7 +200,8 @@ cdef inline read_struct(CyTransportBase buf, obj, decode_response=True): else: spec = field_spec[2] - setattr(obj, name, c_read_val(buf, ttype, spec, decode_response)) + setattr(obj, name, c_read_val(buf, ttype, spec, decode_response, + strict_decode)) return obj @@ -251,16 +253,19 @@ cdef inline c_read_binary(CyTransportBase buf, int32_t size): return py_data -cdef inline c_read_string(CyTransportBase buf, int32_t size): +cdef inline c_read_string(CyTransportBase buf, int32_t size, + strict_decode=False): py_data = c_read_binary(buf, size) try: return (py_data)[:size].decode("utf-8") except: # noqa + if strict_decode: + raise return py_data cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, - decode_response=True): + decode_response=True, strict_decode=False): cdef int size cdef int64_t n cdef TType v_type, k_type, orig_type, orig_key_type @@ -291,7 +296,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, elif ttype == T_STRING: size = read_i32(buf) if decode_response: - return c_read_string(buf, size) + return c_read_string(buf, size, strict_decode) else: return c_read_binary(buf, size) @@ -311,7 +316,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, skip(buf, orig_type) return [] - return [c_read_val(buf, v_type, v_spec, decode_response) + return [c_read_val(buf, v_type, v_spec, decode_response, strict_decode) for _ in range(size)] elif ttype == T_MAP: @@ -345,13 +350,13 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, return {} return { - c_read_val(buf, k_type, k_spec, decode_response): - c_read_val(buf, v_type, v_spec, decode_response) + c_read_val(buf, k_type, k_spec, decode_response, strict_decode): + c_read_val(buf, v_type, v_spec, decode_response, strict_decode) for _ in range(size) } elif ttype == T_STRUCT: - return read_struct(buf, spec(), decode_response) + return read_struct(buf, spec(), decode_response, strict_decode) cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None): @@ -432,8 +437,9 @@ cpdef skip(CyTransportBase buf, TType ttype): skip(buf, f_type) -def read_val(CyTransportBase buf, TType ttype, decode_response=True): - return c_read_val(buf, ttype, None, decode_response) +def read_val(CyTransportBase buf, TType ttype, decode_response=True, + strict_decode=False): + return c_read_val(buf, ttype, None, decode_response, strict_decode) def write_val(CyTransportBase buf, TType ttype, val, spec=None): @@ -445,13 +451,15 @@ cdef class TCyBinaryProtocol(object): cdef public bool strict_read cdef public bool strict_write cdef public bool decode_response + cdef public bool strict_decode def __init__(self, trans, strict_read=True, strict_write=True, - decode_response=True): + decode_response=True, strict_decode=False): self.trans = trans self.strict_read = strict_read self.strict_write = strict_write self.decode_response = decode_response + self.strict_decode = strict_decode def skip(self, ttype): skip(self.trans, (ttype)) @@ -498,7 +506,8 @@ cdef class TCyBinaryProtocol(object): def read_struct(self, obj): try: - return read_struct(self.trans, obj, self.decode_response) + return read_struct(self.trans, obj, self.decode_response, + self.strict_decode) except Exception: self.trans.clean() raise @@ -513,11 +522,13 @@ cdef class TCyBinaryProtocol(object): class TCyBinaryProtocolFactory(object): def __init__(self, strict_read=True, strict_write=True, - decode_response=True): + decode_response=True, strict_decode=False): self.strict_read = strict_read self.strict_write = strict_write self.decode_response = decode_response + self.strict_decode = strict_decode def get_protocol(self, trans): return TCyBinaryProtocol( - trans, self.strict_read, self.strict_write, self.decode_response) + trans, self.strict_read, self.strict_write, self.decode_response, + self.strict_decode)