diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 0b821fd54..54093cf89 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -1,5 +1,6 @@ # 0.93.0 - 2022-05-23 - Add support for mysql `_binary` charset. +- Handle properly null values in MySQL # 0.93.0 - 2022-05-20 - Fix normalization of integers during an insertion. diff --git a/crypto/envelope_detector.go b/crypto/envelope_detector.go index 473116e44..dc1100dd7 100644 --- a/crypto/envelope_detector.go +++ b/crypto/envelope_detector.go @@ -196,6 +196,9 @@ func (wrapper *OldContainerDetectorWrapper) OnCryptoEnvelope(ctx context.Context // OnColumn callback which finds serializedContainer or AcraStruct/AcraBlock for backward compatibility func (wrapper *OldContainerDetectorWrapper) OnColumn(ctx context.Context, inBuffer []byte) (context.Context, []byte, error) { + if inBuffer == nil { + return ctx, inBuffer, nil + } // we should track that if incoming data contains any signs of new container and if it is we return data as is // otherwise try to search for AcraBlock or AcraStruct to save backward compatibility // so before any OnColumn we should reset hasMatchedEnvelope flag to track if its changed during EnvelopeDetector OnColumn via BackWrapper.OnCryptoEnvelope callback diff --git a/decryptor/mysql/response_proxy.go b/decryptor/mysql/response_proxy.go index 034608a15..57a3a07b6 100644 --- a/decryptor/mysql/response_proxy.go +++ b/decryptor/mysql/response_proxy.go @@ -513,8 +513,14 @@ func (handler *Handler) processTextDataRow(ctx context.Context, rowData []byte, if err != nil { return nil, err } - - decrCtx, value, err := handler.onColumnDecryption(ctx, i, value, false, fields[i]) + var decrCtx context.Context + // skip processing if value is NULL/nil + if value == nil { + output = append(output, rowData[pos:pos+n]...) + pos += n + continue + } + decrCtx, value, err = handler.onColumnDecryption(ctx, i, value, false, fields[i]) if err != nil { fieldLogger.WithField(logging.FieldKeyEventCode, logging.EventCodeErrorGeneral). WithError(err).Errorln("Failed to process column data") diff --git a/tests/test.py b/tests/test.py index c2f14f9bd..62ad241a7 100644 --- a/tests/test.py +++ b/tests/test.py @@ -5718,13 +5718,8 @@ def testEmptyValues(self): # check null values result = self.engine1.execute(sa.select([self.temp_table]).where(self.temp_table.c.id == null_value_id)) row = result.fetchone() - if TEST_MYSQL: - # PyMySQL returns empty strings for NULL values - self.assertEqual(row['text'], '') - self.assertEqual(row['binary'], b'') - else: - self.assertIsNone(row['text']) - self.assertIsNone(row['binary']) + self.assertIsNone(row['text']) + self.assertIsNone(row['binary']) # check empty values result = self.engine1.execute(sa.select([self.temp_table]).where(self.temp_table.c.id == empty_value_id)) @@ -8696,12 +8691,14 @@ def testClientIDRead(self): 'value_int64': 64, 'value_bytes': b'value_bytes', 'value_str': 'value_str', - 'value_empty_str': '' + 'value_empty_str': '', + 'value_null_str': None, + 'value_null_int32': None, } self.schema_table.create(bind=self.engine_raw, checkfirst=True) - columns = ('value_bytes', 'value_int32', 'value_int64', 'value_empty_str', 'value_str') - null_columns = ('value_null_str', 'value_null_int32') + columns = ('value_bytes', 'value_int32', 'value_int64', 'value_empty_str', 'value_str', 'value_null_str', + 'value_null_int32') self.engine1.execute(self.test_table.insert(), data) @@ -8714,18 +8711,12 @@ def testClientIDRead(self): self.assertEqual(data[column], row[column]) self.assertIsInstance(row[column], type(data[column])) - # mysql.connector represent null value as empty string - for column in null_columns: - self.assertEqual(row[column], '') row = self.executor2.execute(query)[0] for column in columns: self.assertEqual(row[column], default_expected_values[column]) self.assertIsInstance(row[column], type(default_expected_values[column])) - for column in null_columns: - self.assertEqual(row[column], '') - row = self.engine_raw.execute(sa.select([self.test_table]) .where(self.test_table.c.id == data['id'])).fetchone() for column in columns: @@ -8993,8 +8984,8 @@ def testClientIDRead(self): } self.schema_table.create(bind=self.engine_raw, checkfirst=True) self.engine1.execute(self.test_table.insert(), data) - columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str') - null_columns = ('value_null_str', 'value_null_int32') + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str', 'value_null_str', + 'value_null_int32') compile_kwargs = {"literal_binds": True} query = sa.select([self.test_table]).where(self.test_table.c.id == data['id']) @@ -9005,10 +8996,6 @@ def testClientIDRead(self): self.assertEqual(data[column], row[column]) self.assertIsInstance(row[column], type(data[column])) - # mysql.connector represent null value as empty string - for column in null_columns: - self.assertEqual(row[column], '') - # field types should be rollbacked in case of invalid encoding row = self.executor2.execute(query)[0] @@ -9516,8 +9503,8 @@ def testClientIDRead(self): } self.schema_table.create(bind=self.engine_raw, checkfirst=True) self.engine1.execute(self.test_table.insert(), data) - columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str') - null_columns = ('value_null_str', 'value_null_int32') + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str', 'value_null_str', + 'value_null_int32') compile_kwargs = {"literal_binds": True} query = sa.select([self.test_table]).where(self.test_table.c.id == data['id']) @@ -9528,10 +9515,6 @@ def testClientIDRead(self): self.assertEqual(data[column], row[column]) self.assertIsInstance(row[column], type(data[column])) - # mysql.connector represent null value as empty string - for column in null_columns: - self.assertEqual(row[column], '') - # field types should be rollbacked in case of invalid encoding row = self.executor2.execute(query)[0] @@ -9653,8 +9636,8 @@ def testClientIDRead(self): } self.schema_table.create(bind=self.engine_raw, checkfirst=True) self.engine1.execute(self.test_table.insert(), data) - columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str') - null_columns = ('value_null_str', 'value_null_int32') + columns = ('value_str', 'value_bytes', 'value_int32', 'value_int64', 'value_empty_str', 'value_null_str', + 'value_null_int32') compile_kwargs = {"literal_binds": True} query = sa.select([self.test_table]).where(self.test_table.c.id == data['id']) @@ -9665,10 +9648,6 @@ def testClientIDRead(self): self.assertEqual(data[column], row[column]) self.assertIsInstance(row[column], type(data[column])) - # mysql.connector represent null value as empty string - for column in null_columns: - self.assertEqual(row[column], '') - # we expect an exception because of decryption error with self.assertRaises(mysql.connector.errors.DatabaseError) as ex: self.executor2.execute(query)[0]