Skip to content

Commit

Permalink
[feat] Support schema field type promotion (#159)
Browse files Browse the repository at this point in the history
## Motivation

The Python client does not correctly follow [Avro's type promotion rules](https://avro.apache.org/docs/1.11.1/specification/#schema-resolution), which can lead to problems with data serialization and deserialization.

The expected behavior is that the Python client follows Avro's type promotion rules and performs type conversion when necessary, ensuring compatibility. However, the actual behavior is that the Python client's schema deserialization is too strict, so type promotion does not happen as expected.

## Modification

- Support schema field type promotion when validating the Python type
- Convert the field value to the desired compatible Python type
  • Loading branch information
RobertIndie authored Oct 30, 2023
1 parent 995e491 commit dfd163a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
25 changes: 16 additions & 9 deletions pulsar/schema/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def validate_type(self, name, val):
if val is None and not self._required:
return self.default()

if type(val) != self.python_type():
if not isinstance(val, self.python_type()):
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
return val

Expand Down Expand Up @@ -309,7 +309,7 @@ def type(self):
return 'float'

def python_type(self):
return float
return float, int

def default(self):
if self._default is not None:
Expand All @@ -323,7 +323,7 @@ def type(self):
return 'double'

def python_type(self):
return float
return float, int

def default(self):
if self._default is not None:
Expand All @@ -337,30 +337,37 @@ def type(self):
return 'bytes'

def python_type(self):
return bytes
return bytes, str

def default(self):
if self._default is not None:
return self._default
else:
return None

def validate_type(self, name, val):
    """Validate a value for a bytes field, promoting ``str`` to ``bytes``.

    Args:
        name: Field name, used only in the error message.
        val: Candidate value for the field.

    Returns:
        ``self.default()`` when ``val`` is ``None`` and the field is not
        required; ``val.encode()`` when ``val`` is a ``str``; otherwise
        ``val`` unchanged (it is already ``bytes``).

    Raises:
        TypeError: If ``val`` is neither ``bytes`` nor ``str`` (this
            includes ``None`` on a required field).
    """
    # An unset optional field falls back to its default, as String does.
    if val is None and not self._required:
        return self.default()

    # Type promotion: string => bytes.
    if isinstance(val, str):
        return val.encode()

    # Without this check any object (int, list, ...) would be accepted
    # silently and only fail later, during serialization.
    if not isinstance(val, bytes):
        raise TypeError("Invalid type '%s' for field '%s'. Expected bytes or str" % (type(val), name))
    return val


class String(Field):
def type(self):
return 'string'

def python_type(self):
return str
return str, bytes

def validate_type(self, name, val):
t = type(val)

if val is None and not self._required:
return self.default()

if not (t is str or t.__name__ == 'unicode'):
if not (isinstance(val, (str, bytes)) or t.__name__ == 'unicode'):
raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (t, name))
if isinstance(val, bytes):
return val.decode()
return val

def default(self):
Expand Down Expand Up @@ -406,7 +413,7 @@ def validate_type(self, name, val):
else:
raise TypeError(
"Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))
elif type(val) != self.python_type():
elif not isinstance(val, self.python_type()):
raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
else:
return val
Expand Down Expand Up @@ -450,7 +457,7 @@ def validate_type(self, name, val):
super(Array, self).validate_type(name, val)

for x in val:
if type(x) != self.array_type.python_type():
if not isinstance(x, self.array_type.python_type()):
raise TypeError('Array field ' + name + ' items should all be of type ' +
_string_representation(self.array_type.type()))
return val
Expand Down Expand Up @@ -493,7 +500,7 @@ def validate_type(self, name, val):
for k, v in val.items():
if type(k) != str and not is_unicode(k):
raise TypeError('Map keys for field ' + name + ' should all be strings')
if type(v) != self.value_type.python_type():
if not isinstance(v, self.value_type.python_type()):
raise TypeError('Map values for field ' + name + ' should all be of type '
+ _string_representation(self.value_type.python_type()))

Expand Down
2 changes: 2 additions & 0 deletions pulsar/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def __init__(self, record_cls):
def _get_serialized_value(self, o):
if isinstance(o, enum.Enum):
return o.value
elif isinstance(o, bytes):
return o.decode()
else:
data = o.__dict__.copy()
remove_reserved_key(data)
Expand Down
67 changes: 65 additions & 2 deletions tests/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
format='%(asctime)s %(levelname)-5s %(message)s')


# Record with one String, Integer, Float and Bytes field. It is used by
# test_schema_type_promotion, which builds it with values of promotable types
# and round-trips it through the Avro and JSON schemas.
class ExampleRecord(Record):
    str_field = String()
    int_field = Integer()
    float_field = Float()
    bytes_field = Bytes()

class SchemaTest(TestCase):

serviceUrl = 'pulsar://localhost:6650'
Expand Down Expand Up @@ -87,6 +93,31 @@ class Example(Record):
]
})

def test_type_promotion(self):
    """Check that validate_type accepts promotable values.

    Each case is ``(input value, target Python type, expected result)``.
    The target type selects which schema field validates the input.
    """
    # Target Python type -> schema field class used to validate it.
    field_for_target = {
        int: Integer,
        float: Double,
        str: String,
        bytes: Bytes,
    }

    test_cases = [
        (20, int, 20),  # No promotion necessary: int => int
        (20, float, 20.0),  # Promotion: int => float
        (20.0, float, 20.0),  # No Promotion necessary: float => float
        ("Test text1", bytes, b"Test text1"),  # Promotion: str => bytes
        (b"Test text1", str, "Test text1"),  # Promotion: bytes => str
    ]

    for value_from, type_to, value_to in test_cases:
        # Unknown target types fall back to a String field.
        field = field_for_target.get(type_to, String)()
        promoted = field.validate_type("test_field", value_from)
        self.assertEqual(value_to, promoted)


def test_complex(self):
class Color(Enum):
red = 1
Expand Down Expand Up @@ -229,7 +260,7 @@ class E3(Record):
a = Float()

E3(a=1.0) # Ok
self._expectTypeError(lambda: E3(a=1))
E3(a=1) # Ok Type promotion: int -> float

class E4(Record):
a = Null()
Expand Down Expand Up @@ -259,7 +290,7 @@ class E7(Record):
a = Double()

E7(a=1.0) # Ok
self._expectTypeError(lambda: E3(a=1))
E7(a=1) # Ok Type promotion: int -> double

class Color(Enum):
red = 1
Expand Down Expand Up @@ -1346,5 +1377,37 @@ def verify_messages(msgs: List[pulsar.Message]):

client.close()

def test_schema_type_promotion(self):
    """Round-trip a record built from promotable values through each schema.

    For both the Avro and the JSON schema of ExampleRecord, a record is
    created with bytes for the String field, an int for the Float field
    and a str for the Bytes field. It is then produced and consumed on a
    dedicated topic, and every received field is compared with the sent
    record's field.

    Connects to the service at ``self.serviceUrl``.
    """
    client = pulsar.Client(self.serviceUrl)

    # Close the client even when a send/receive or an assertion fails, so a
    # failing run does not leave the connection open.
    try:
        schemas = [("avro", AvroSchema(ExampleRecord)), ("json", JsonSchema(ExampleRecord))]

        for schema_name, schema in schemas:
            topic = f'test_schema_type_promotion_{schema_name}'

            # Subscribe before producing so the message is retained for
            # this subscription.
            consumer = client.subscribe(
                topic=topic,
                subscription_name=f'my-sub-{schema_name}',
                schema=schema
            )
            producer = client.create_producer(
                topic=topic,
                schema=schema
            )
            sendValue = ExampleRecord(str_field=b'test', int_field=1, float_field=3, bytes_field='str')

            producer.send(sendValue)

            msg = consumer.receive()
            msg_value = msg.value()
            self.assertEqual(msg_value.str_field, sendValue.str_field)
            self.assertEqual(msg_value.int_field, sendValue.int_field)
            self.assertEqual(msg_value.float_field, sendValue.float_field)
            self.assertEqual(msg_value.bytes_field, sendValue.bytes_field)
            consumer.acknowledge(msg)
    finally:
        client.close()


if __name__ == '__main__':
main()

0 comments on commit dfd163a

Please sign in to comment.