Skip to content

Commit

Permalink
PYTHON-1352 Add vector type, codec + support for parsing CQL type (da…
Browse files Browse the repository at this point in the history
  • Loading branch information
absurdfarce authored and dkropachev committed Nov 18, 2024
1 parent 3a9ac29 commit 6b46906
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cassandra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def emit(self, record):

logging.getLogger('cassandra').addHandler(NullHandler())

__version_info__ = (3, 27, 0)
__version_info__ = (3, 28, 0b1)
__version__ = '.'.join(map(str, __version_info__))


Expand Down
36 changes: 33 additions & 3 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,15 @@ def parse_casstype_args(typestring):
else:
names.append(None)

ctype = lookup_casstype_simple(tok)
try:
ctype = int(tok)
except ValueError:
ctype = lookup_casstype_simple(tok)
types.append(ctype)

# return the first (outer) type, which will have all parameters applied
return args[0][0][0]


def lookup_casstype(casstype):
"""
Given a Cassandra type as a string (possibly including parameters), hand
Expand Down Expand Up @@ -286,7 +288,7 @@ class _CassandraType(object, metaclass=CassandraTypeType):
"""

def __repr__(self):
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
return '<%s>' % (self.cql_parameterized_type())

@classmethod
def from_binary(cls, byts, protocol_version):
Expand Down Expand Up @@ -1402,3 +1404,31 @@ def serialize(cls, v, protocol_version):
buf.write(int8_pack(cls._encode_precision(bound.precision)))

return buf.getvalue()

class VectorType(_CassandraType):
typename = 'org.apache.cassandra.db.marshal.VectorType'
vector_size = 0
subtype = None

@classmethod
def apply_parameters(cls, params, names):
assert len(params) == 2
subtype = lookup_casstype(params[0])
vsize = params[1]
return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype})

@classmethod
def deserialize(cls, byts, protocol_version):
indexes = (4 * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + 4], protocol_version) for idx in indexes]

@classmethod
def serialize(cls, v, protocol_version):
buf = io.BytesIO()
for item in v:
buf.write(cls.subtype.serialize(item, protocol_version))
return buf.getvalue()

@classmethod
def cql_parameterized_type(cls):
return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)
22 changes: 21 additions & 1 deletion tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
EmptyValue, LongType, SetType, UTF8Type,
cql_typename, int8_pack, int64_pack, lookup_casstype,
lookup_casstype_simple, parse_casstype_args,
int32_pack, Int32Type, ListType, MapType
int32_pack, Int32Type, ListType, MapType, VectorType,
FloatType
)
from cassandra.encoder import cql_quote
from cassandra.pool import Host
Expand Down Expand Up @@ -188,6 +189,12 @@ class BarType(FooType):
self.assertEqual(UTF8Type, ctype.subtypes[2])
self.assertEqual([b'city', None, b'zip'], ctype.names)

def test_parse_casstype_vector(self):
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)")
self.assertTrue(issubclass(ctype, VectorType))
self.assertEqual(3, ctype.vector_size)
self.assertEqual(FloatType, ctype.subtype)

def test_empty_value(self):
self.assertEqual(str(EmptyValue()), 'EMPTY')

Expand Down Expand Up @@ -301,6 +308,19 @@ def test_cql_quote(self):
self.assertEqual(cql_quote('test'), "'test'")
self.assertEqual(cql_quote(0), '0')

def test_vector_round_trip(self):
base = [3.4, 2.9, 41.6, 12.0]
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
base_bytes = ctype.serialize(base, 0)
self.assertEqual(16, len(base_bytes))
result = ctype.deserialize(base_bytes, 0)
self.assertEqual(len(base), len(result))
for idx in range(0,len(base)):
self.assertAlmostEqual(base[idx], result[idx], places=5)

def test_vector_cql_parameterized_type(self):
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType<float, 4>")

ZERO = datetime.timedelta(0)

Expand Down

0 comments on commit 6b46906

Please sign in to comment.