Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Streaming Replication 7th] Replication Mode Messages Handling #58

Merged
merged 8 commits into from
Sep 10, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/postgres.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ library postgres;
export 'src/connection.dart';
export 'src/execution_context.dart';
export 'src/models.dart';
export 'src/replication.dart' show ReplicationMode;
export 'src/substituter.dart';
export 'src/types.dart';
93 changes: 70 additions & 23 deletions lib/src/logical_replication_messages.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,110 @@ import 'types.dart';
abstract class LogicalReplicationMessage
implements ReplicationMessage, ServerMessage {}

class XLogDataLogicalMessage extends XLogDataMessage {
late final LogicalReplicationMessage _message;
class XLogDataLogicalMessage implements XLogDataMessage {
@override
// this and others are `late` to comply with super class
late final Uint8List bytes;

@override
LogicalReplicationMessage get data => _message;
late final DateTime time;

XLogDataLogicalMessage(Uint8List bytes) : super(bytes) {
_message = parseLogicalReplicationMessage(this.bytes);
}
@override
late final LSN walEnd;

@override
late final LSN walStart;

@override
final LSN walDataLength;

late final LogicalReplicationMessage message;

@override
LogicalReplicationMessage get data => message;

XLogDataLogicalMessage({
required this.message,
required this.bytes,
required this.time,
required this.walEnd,
required this.walStart,
required this.walDataLength,
});

@override
String toString() => super.toString();
}

LogicalReplicationMessage parseLogicalReplicationMessage(Uint8List bytesList) {
/// Tries to check if the [XLogDataMessage.bytes] is a [LogicalReplicationMessage]
/// If so, it'll return [XLogDataLogicalMessage], otherwise it reutnrs [message]
XLogDataMessage tryParseLogicalReplicationMessage(XLogDataMessage message) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experience tryX usually returns T?: either T if successful, or null if not. This method seems to be better places on the XLogDataMessage (either directly or through an extension method) and named cast or castToLogicalReplication, especially if those message have a LogicalReplicationMessage interface or superclass. wdyt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch and I totally agree with you.

I added XLogDataMessage.parse and made tryParseLogicalReplicationMessage returns LogicalReplicationMessage?. Let me know if it looks better.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, thank you!

// take the message bytes to check if it's a replication message
final bytesList = message.bytes;
// the first byte is the msg type
final firstByte = bytesList.first;
final msgType = LogicalReplicationMessageTypes.fromByte(firstByte);
// remaining bytes are the data
// remaining bytes are the data
LogicalReplicationMessage? logicalMessage;
final bytes = bytesList.sublist(1);

switch (msgType) {
case LogicalReplicationMessageTypes.Begin:
return BeginMessage(bytes);
logicalMessage = BeginMessage(bytes);
break;
case LogicalReplicationMessageTypes.Commit:
return CommitMessage(bytes);
logicalMessage = CommitMessage(bytes);
break;
case LogicalReplicationMessageTypes.Origin:
return OriginMessage(bytes);
logicalMessage = OriginMessage(bytes);
break;
case LogicalReplicationMessageTypes.Relation:
return RelationMessage(bytes);
logicalMessage = RelationMessage(bytes);
break;
case LogicalReplicationMessageTypes.Type:
return TypeMessage(bytes);
logicalMessage = TypeMessage(bytes);
break;
case LogicalReplicationMessageTypes.Insert:
return InsertMessage(bytes);
logicalMessage = InsertMessage(bytes);
break;
case LogicalReplicationMessageTypes.Update:
return UpdateMessage(bytes);
logicalMessage = UpdateMessage(bytes);
break;
case LogicalReplicationMessageTypes.Delete:
return DeleteMessage(bytes);
logicalMessage = DeleteMessage(bytes);
break;
case LogicalReplicationMessageTypes.Truncate:
return TruncateMessage(bytes);
logicalMessage = TruncateMessage(bytes);
break;
case LogicalReplicationMessageTypes.Unsupported:
default:
return _parseJsonMessageOrReturnUnknownMessage(bytes);
// note this needs the full set of bytes unlike other cases
logicalMessage = _tryParseJsonMessage(bytesList);
break;
}
if (logicalMessage != null) {
return XLogDataLogicalMessage(
message: logicalMessage,
bytes: message.bytes,
time: message.time,
walEnd: message.walEnd,
walStart: message.walStart,
walDataLength: message.walDataLength,
);
} else {
return message;
}
}

LogicalReplicationMessage _parseJsonMessageOrReturnUnknownMessage(
Uint8List bytes) {
LogicalReplicationMessage? _tryParseJsonMessage(Uint8List bytes) {
// wal2json messages starts with `{` as the first byte
if (bytes.first == '{'.codeUnits.first) {
try {
return JsonMessage(utf8.decode(bytes));
} catch (e) {
return UnknownLogicalReplicationMessage(bytes);
return null;
}
}
return UnknownLogicalReplicationMessage(bytes);
return null;
}

enum LogicalReplicationMessageTypes {
Expand Down
48 changes: 43 additions & 5 deletions lib/src/message_window.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ import 'dart:typed_data';

import 'package:buffer/buffer.dart';

import 'logical_replication_messages.dart';
import 'server_messages.dart';
import 'shared_messages.dart';

const int _headerByteSize = 5;
final _emptyData = Uint8List(0);
Expand All @@ -21,7 +23,9 @@ Map<int, _ServerMessageFn> _messageTypeMap = {
82: (d) => AuthenticationMessage(d),
83: (d) => ParameterStatusMessage(d),
84: (d) => RowDescriptionMessage(d),
87: (d) => CopyBothResponseMessage(d),
90: (d) => ReadyForQueryMessage(d),
100: (d) => CopyDataMessage(d),
110: (d) => NoDataMessage(),
116: (d) => ParameterDescriptionMessage(d),
};
Expand Down Expand Up @@ -51,20 +55,54 @@ class MessageFramer {
_expectedLength = _reader.readUint32() - 4;
}

if (_hasReadHeader && _isComplete) {
// special case
if (_type == SharedMessages.copyDoneIdentifier) {
// unlike other messages, CopyDoneMessage only takes the length as an
// argument (must be the full length including the length bytes)
final msg = CopyDoneMessage(_expectedLength + 4);
_addMsg(msg);
evaluateNextMessage = true;
} else if (_hasReadHeader && _isComplete) {
final data =
_expectedLength == 0 ? _emptyData : _reader.read(_expectedLength);
final msgMaker = _messageTypeMap[_type];
final msg =
var msg =
msgMaker == null ? UnknownMessage(_type, data) : msgMaker(data);
messageQueue.add(msg);
_type = null;
_expectedLength = 0;

// Copy Data message is a wrapper around data stream messages
// such as replication messages.
if (msg is CopyDataMessage) {
// checks if it's a replication message, otherwise returns given msg
msg = _extractReplicationMessageIfAny(msg);
}

_addMsg(msg);
evaluateNextMessage = true;
}
}
}

void _addMsg(ServerMessage msg) {
messageQueue.add(msg);
_type = null;
_expectedLength = 0;
}

/// Returns a [ReplicationMessage] if the [CopyDataMessage] contains such message.
/// Otherwise, it'll just return the provided [copyData].
ServerMessage _extractReplicationMessageIfAny(CopyDataMessage copyData) {
final bytes = copyData.bytes;
final code = bytes.first;
final data = bytes.sublist(1);
if (code == ReplicationMessage.primaryKeepAliveIdentifier) {
return PrimaryKeepAliveMessage(data);
} else if (code == ReplicationMessage.xLogDataIdentifier) {
return tryParseLogicalReplicationMessage(XLogDataMessage(data));
} else {
return copyData;
}
}

bool get hasMessage => messageQueue.isNotEmpty;

ServerMessage popMessage() {
Expand Down
135 changes: 135 additions & 0 deletions test/framer_test.dart
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import 'dart:typed_data';

import 'package:buffer/buffer.dart';
import 'package:postgres/src/logical_replication_messages.dart';
import 'package:postgres/src/message_window.dart';
import 'package:postgres/src/server_messages.dart';
import 'package:postgres/src/shared_messages.dart';
import 'package:test/test.dart';

void main() {
Expand Down Expand Up @@ -173,6 +175,139 @@ void main() {
final messages = framer.messageQueue.toList();
expect(messages, [UnknownMessage(10, Uint8List(0))]);
});

test('Identify CopyDoneMessage with length equals size length (min)', () {
// min length
final length = [0, 0, 0, 4]; // min length (4 bytes) as 32-bit
final bytes = Uint8List.fromList([
SharedMessages.copyDoneIdentifier,
...length,
]);
framer.addBytes(bytes);

final message = framer.messageQueue.toList().first;
expect(message, isA<CopyDoneMessage>());
expect((message as CopyDoneMessage).length, 4);
});

test('Identify CopyDoneMessage when length larger than size length', () {
final length = (ByteData(4)..setUint32(0, 42)).buffer.asUint8List();
final bytes = Uint8List.fromList([
SharedMessages.copyDoneIdentifier,
...length,
]);
framer.addBytes(bytes);

final message = framer.messageQueue.toList().first;
expect(message, isA<CopyDoneMessage>());
expect((message as CopyDoneMessage).length, 42);
});

test('Adds XLogDataMessage to queue', () {
final bits64 = (ByteData(8)..setUint64(0, 42)).buffer.asUint8List();
// random data bytes
final dataBytes = [1, 2, 3, 4, 5, 6, 7, 8];

/// This represent a raw [XLogDataMessage]
final xlogDataMessage = <int>[
ReplicationMessage.xLogDataIdentifier,
...bits64, // walStart (64bit)
...bits64, // walEnd (64bit)
...bits64, // time (64bit)
...dataBytes // bytes (any)
];
final length = ByteData(4)..setUint32(0, xlogDataMessage.length + 4);

// this represents the [CopyDataMessage] which is a wrapper for [XLogDataMessage]
// and such
final copyDataBytes = <int>[
SharedMessages.copyDataIdentifier,
...length.buffer.asUint8List(),
...xlogDataMessage,
];

framer.addBytes(Uint8List.fromList(copyDataBytes));
final message = framer.messageQueue.toList().first;
expect(message, isA<XLogDataMessage>());
expect(message, isNot(isA<XLogDataLogicalMessage>()));
});

test('Adds XLogDataLogicalMessage with JsonMessage to queue', () {
final bits64 = (ByteData(8)..setUint64(0, 42)).buffer.asUint8List();

/// represent an empty json object so we should get a XLogDataLogicalMessage
/// with JsonMessage as its message.
final dataBytes = '{}'.codeUnits;

/// This represent a raw [XLogDataMessage]
final xlogDataMessage = <int>[
ReplicationMessage.xLogDataIdentifier,
...bits64, // walStart (64bit)
...bits64, // walEnd (64bit)
...bits64, // time (64bit)
...dataBytes, // bytes (any)
];

final length = ByteData(4)..setUint32(0, xlogDataMessage.length + 4);

/// this represents the [CopyDataMessage] in which [XLogDataMessage]
/// is delivered per protocol
final copyDataMessage = <int>[
SharedMessages.copyDataIdentifier,
...length.buffer.asUint8List(),
...xlogDataMessage,
];

framer.addBytes(Uint8List.fromList(copyDataMessage));
final message = framer.messageQueue.toList().first;
expect(message, isA<XLogDataLogicalMessage>());
expect((message as XLogDataLogicalMessage).message, isA<JsonMessage>());
});

test('Adds PrimaryKeepAliveMessage to queue', () {
final bits64 = (ByteData(8)..setUint64(0, 42)).buffer.asUint8List();

/// This represent a raw [PrimaryKeepAliveMessage]
final xlogDataMessage = <int>[
ReplicationMessage.primaryKeepAliveIdentifier,
...bits64, // walEnd (64bits)
...bits64, // time (64bits)
0, // mustReply (1bit)
];
final length = ByteData(4)..setUint32(0, xlogDataMessage.length + 4);

/// This represents the [CopyDataMessage] in which [PrimaryKeepAliveMessage]
/// is delivered per protocol
final copyDataMessage = <int>[
SharedMessages.copyDataIdentifier,
...length.buffer.asUint8List(),
...xlogDataMessage,
];

framer.addBytes(Uint8List.fromList(copyDataMessage));
final message = framer.messageQueue.toList().first;
expect(message, isA<PrimaryKeepAliveMessage>());
});

test('Adds raw CopyDataMessage for unknown stream message', () {
final xlogDataBytes = <int>[
-1, // unknown id
];

final length = ByteData(4)..setUint32(0, xlogDataBytes.length + 4);

/// This represents the [CopyDataMessage] in which data is delivered per protocol
/// typically contains [XLogData] and such but this tests unknown content
final copyDataMessage = <int>[
SharedMessages.copyDataIdentifier,
...length.buffer.asUint8List(),
...xlogDataBytes,
];

framer.addBytes(Uint8List.fromList(copyDataMessage));
final message = framer.messageQueue.toList().first;
expect(message, isA<CopyDataMessage>());
});
}

List<int> messageWithBytes(List<int> bytes, int messageID) {
Expand Down
Loading