Skip to content

Commit

Permalink
Make MessageFactory static
Browse files Browse the repository at this point in the history
This change makes MessageFactory into a static class similar to ClassConfigurator. This greatly simplifies adding custom message types, and also allows custom types to be used in Batch/CompositeMessag which otherwise only allows the default types.
This comes with the side effect that it's no longer possible to change how the default message types are generated, but this must be a much less common use case than registering custom types. And as a workaround, a new custom type can be registered instead of altering the default types.
  • Loading branch information
cfredri4 authored and belaban committed Nov 14, 2024
1 parent 32998fc commit 84a36ca
Show file tree
Hide file tree
Showing 24 changed files with 86 additions and 195 deletions.
35 changes: 4 additions & 31 deletions doc/manual/api.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -269,48 +269,22 @@ to better performance.

JGroups 5.0 comes with a number of message types (see the next sections). If none of them are a fit for the application's
requirements, new message types can be defined and registered. To do this, the new message type needs to implement
`Message` (typically by subclassing `BaseMessage`) and registering it with the `MessageFactory` in the transport:
`Message` (typically by subclassing `BaseMessage`) and registering it with the `MessageFactory`:

[source,java]
----
CustomMessage msg=new CustomMessage(...);
JChannel ch;
TP transport=ch.getProtocolStack().getTransport();
MessageFactory mf=transport.getMessageFactory();
mf.register((short)12345, CustomMessage::new)
MessageFactory.register((short)12345, CustomMessage::new)
----

A (unique) ID has to be assigned with the message type, and then it has to be registered with the message factory
in the transport. This has to be done before sending an instance of the new message type.
If the ID has already been registered before, or is taken, an exception will be thrown.
Note that the default implementation of `MessageFactory` requires all IDs to be greater than 32, so that there's room
for adding built-in message types.
`MessageFactory` requires all IDs to be greater than 32, so that there's room for adding built-in message types.

NOTE: It is recommended to register all custom message types _before_ connecting the channel, so that potential errors
are detected early.

[[CustomMessageFactory]]
==== Custom `MessageFactory`
`MessageFactory` is a simple interface:

[source,java]
----
public interface MessageFactory {
<T extends Message> T create(short id);
void register(short type, Supplier<? extends Message> generator);
}
----
We saw the that the `register()` method is used to associate new message types with IDs <<MessageFactory,above>>.

There is a `DefaultMessageFactory` which is set in the transport (`TP`). If more control over the creation of custom
messages is desired, a custom implementation of `MessageFactory` can be written and registered in the transport, using
`TP.setMessageFactory(MessageFactory mf)`.

An example for why we might want to provide our own `MessageFactory` is that we have control over the creation of
messages; e.g. to create an `NioMessage` with a *direct* `ByteBuffer`, we may want to use a _pool_ of off-heap memory
rather than calling `ByteBuffer.allocateDirect()` for each message, which is slow.


[[BytesMessage]]
==== BytesMessage
This is the equivalent to the 4.x `Message`, and contains a byte array, offset and length. There are methods to get and
Expand Down Expand Up @@ -365,8 +339,7 @@ The methods of `NioMessage` are:
|==========================

NOTE: The envisioned use case for `useDirectMemory()` is when we send an `NioMessage` with a direct `ByteBuffer`, but
don't need the `ByteBuffer` to be created in off-heap memory at the receiver, when on-heap will do. +
The alternative is to provide a custom <<MessageFactory,`MessageFactory`>>.
don't need the `ByteBuffer` to be created in off-heap memory at the receiver, when on-heap will do.



Expand Down
5 changes: 1 addition & 4 deletions src/org/jgroups/BatchMessage.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

package org.jgroups;


Expand Down Expand Up @@ -32,8 +31,6 @@ public class BatchMessage extends BaseMessage implements Iterable<Message> {
protected Address orig_src;


protected static final MessageFactory mf=new DefaultMessageFactory();

public BatchMessage() {
}

Expand Down Expand Up @@ -155,7 +152,7 @@ public void readPayload(DataInput in) throws IOException, ClassNotFoundException
msgs=new Message[index]; // a bit of additional space should we add byte arrays
for(int i=0; i < index; i++) {
short type=in.readShort();
msgs[i]=mf.create(type).setDest(dest()).setSrc(orig_src);
msgs[i]=MessageFactory.create(type).setDest(dest()).setSrc(orig_src);
msgs[i].readFrom(in);
}
}
Expand Down
4 changes: 1 addition & 3 deletions src/org/jgroups/CompositeMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ public class CompositeMessage extends BaseMessage implements Iterable<Message> {
protected boolean collapse; // send as a BytesMessage when true


protected static final MessageFactory mf=new DefaultMessageFactory();

public CompositeMessage() {
}

Expand Down Expand Up @@ -147,7 +145,7 @@ public void readPayload(DataInput in) throws IOException, ClassNotFoundException
msgs=new Message[index]; // a bit of additional space should we add byte arrays
for(int i=0; i < index; i++) {
short type=in.readShort();
msgs[i]=mf.create(type);
msgs[i]=MessageFactory.create(type);
msgs[i].readFrom(in);
}
}
Expand Down
46 changes: 0 additions & 46 deletions src/org/jgroups/DefaultMessageFactory.java

This file was deleted.

41 changes: 35 additions & 6 deletions src/org/jgroups/MessageFactory.java
Original file line number Diff line number Diff line change
@@ -1,27 +1,56 @@
package org.jgroups;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

/**
* Factory to create messages. Uses an array for message IDs less then 32, and a hashmap for
* types above 32
* @author Bela Ban
* @since 5.0
*/
public interface MessageFactory {


public class MessageFactory {
protected static final byte MIN_TYPE=32;
protected static final Supplier<? extends Message>[] creators=new Supplier[MIN_TYPE];
protected static Map<Short,Supplier<? extends Message>> map=new HashMap<>();
static {
creators[Message.BYTES_MSG]=BytesMessage::new;
creators[Message.NIO_MSG]=NioMessage::new;
creators[Message.EMPTY_MSG]=EmptyMessage::new;
creators[Message.OBJ_MSG]=ObjectMessage::new;
creators[Message.LONG_MSG]=LongMessage::new;
creators[Message.COMPOSITE_MSG]=CompositeMessage::new;
creators[Message.FRAG_MSG]=FragmentedMessage::new;
creators[Message.EARLYBATCH_MSG]=BatchMessage::new;
}

/**
* Creates a message based on the given ID
* @param id The ID
* @param type The ID
* @param <T> The type of the message
* @return A message
*/
<T extends Message> T create(short id);
public static <T extends Message> T create(short type) {
Supplier<? extends Message> creator=type < MIN_TYPE? creators[type] : map.get(type);
if(creator == null)
throw new IllegalArgumentException("no creator found for type " + type);
return (T)creator.get();
}

/**
* Registers a new creator of messages
* @param type The type associated with the new payload. Needs to be the same in all nodes of the same cluster, and
* needs to be available (ie., not taken by JGroups or other applications).
* @param generator The creator of the payload associated with the given type
*/
<M extends MessageFactory> M register(short type, Supplier<? extends Message> generator);
public static void register(short type, Supplier<? extends Message> generator) {
Objects.requireNonNull(generator, "the creator must be non-null");
if(type < MIN_TYPE)
throw new IllegalArgumentException(String.format("type (%d) must be >= 32", type));
if(map.containsKey(type))
throw new IllegalArgumentException(String.format("type %d is already taken", type));
map.put(type, generator);
}
}
8 changes: 3 additions & 5 deletions src/org/jgroups/protocols/COMPRESS.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public class COMPRESS extends Protocol {

protected BlockingQueue<Deflater> deflater_pool;
protected BlockingQueue<Inflater> inflater_pool;
protected MessageFactory msg_factory;
protected final LongAdder num_compressions=new LongAdder(), num_decompressions=new LongAdder();


Expand Down Expand Up @@ -77,7 +76,6 @@ public void init() throws Exception {
inflater_pool=new ArrayBlockingQueue<>(pool_size);
for(int i=0; i < pool_size; i++)
inflater_pool.add(new Inflater());
msg_factory=getTransport().getMessageFactory();
}

public void destroy() {
Expand Down Expand Up @@ -192,7 +190,7 @@ protected Message uncompress(Message msg, int original_size, boolean needs_deser
inflater.inflate(uncompressed_payload);
// we need to copy: https://issues.redhat.com/browse/JGRP-867
if(needs_deserialization) {
return messageFromByteArray(uncompressed_payload, msg_factory);
return messageFromByteArray(uncompressed_payload);
}
else
return msg.copy(false, true).setArray(uncompressed_payload, 0, uncompressed_payload.length);
Expand Down Expand Up @@ -221,9 +219,9 @@ protected static ByteArray messageToByteArray(Message msg) {
}
}

protected static Message messageFromByteArray(byte[] uncompressed_payload, MessageFactory msg_factory) {
protected static Message messageFromByteArray(byte[] uncompressed_payload) {
try {
return Util.messageFromBuffer(uncompressed_payload, 0, uncompressed_payload.length, msg_factory);
return Util.messageFromBuffer(uncompressed_payload, 0, uncompressed_payload.length);
}
catch(Exception ex) {
throw new RuntimeException("failed unmarshalling message", ex);
Expand Down
7 changes: 1 addition & 6 deletions src/org/jgroups/protocols/Encrypt.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ public abstract class Encrypt<E extends KeyStore.Entry> extends Protocol {
// SecureRandom instance for generating IV's
protected SecureRandom secure_random = new SecureRandom();

protected MessageFactory msg_factory;


/**
* Sets the key store entry used to configure this protocol.
Expand All @@ -96,7 +94,6 @@ public abstract class Encrypt<E extends KeyStore.Entry> extends Protocol {
public SecureRandom secureRandom() {return this.secure_random;}
/** Allows callers to replace secure_random with impl of their choice, e.g. for performance reasons. */
public <T extends Encrypt<E>> T secureRandom(SecureRandom sr) {this.secure_random = sr; return (T)this;}
public <T extends Encrypt<E>> T msgFactory(MessageFactory f) {this.msg_factory=f; return (T)this;}
@ManagedAttribute public String version() {return Util.byteArrayToHexString(sym_version);}


Expand All @@ -116,8 +113,6 @@ public void init() throws Exception {
key_map=new BoundedHashMap<>(key_map_max_size);
initSymCiphers(sym_algorithm, secret_key);
TP transport=getTransport();
if(transport != null)
msg_factory=transport.getMessageFactory();
}


Expand Down Expand Up @@ -322,7 +317,7 @@ protected Message _decrypt(final Cipher cipher, Key key, Message msg, EncryptHea
decrypted_msg=cipher.doFinal(msg.getArray(), msg.getOffset(), msg.getLength());
}
if(hdr.needsDeserialization())
return Util.messageFromBuffer(decrypted_msg, 0, decrypted_msg.length, msg_factory);
return Util.messageFromBuffer(decrypted_msg, 0, decrypted_msg.length);
else
return msg.setArray(decrypted_msg, 0, decrypted_msg.length);
}
Expand Down
4 changes: 1 addition & 3 deletions src/org/jgroups/protocols/FRAG.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public class FRAG extends Fragmentation {
protected final FragmentationList fragment_list=new FragmentationList();
protected final AtomicInteger curr_id=new AtomicInteger(1);
protected final List<Address> members=new ArrayList<>(11);
protected MessageFactory msg_factory;
protected final Predicate<Message> HAS_FRAG_HEADER=msg -> msg.getHeader(id) != null;


Expand All @@ -58,7 +57,6 @@ public class FRAG extends Fragmentation {

public void init() throws Exception {
super.init();
msg_factory=getTransport().getMessageFactory();
Map<String,Object> info=new HashMap<>(1);
info.put("frag_size", frag_size);
down_prot.down(new Event(Event.CONFIG, info));
Expand Down Expand Up @@ -229,7 +227,7 @@ private Message unfragment(Message msg, FragHeader hdr) {
return null;

try {
Message assembled_msg=Util.messageFromBuffer(buf, 0, buf.length, msg_factory);
Message assembled_msg=Util.messageFromBuffer(buf, 0, buf.length);
assembled_msg.setSrc(sender); // needed ? YES, because fragments have a null src !!
if(log.isTraceEnabled()) log.trace("assembled_msg is " + assembled_msg);
num_received_msgs++;
Expand Down
10 changes: 3 additions & 7 deletions src/org/jgroups/protocols/FRAG2.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ public class FRAG2 extends Fragmentation {
protected final AtomicLong curr_id=new AtomicLong(1);

protected final List<Address> members=new ArrayList<>(11);
protected MessageFactory msg_factory;

protected final AverageMinMax avg_size_down=new AverageMinMax();
protected final AverageMinMax avg_size_up=new AverageMinMax();
Expand Down Expand Up @@ -75,7 +74,6 @@ public void init() throws Exception {
throw new IllegalArgumentException("frag_size (" + frag_size + ") has to be < TP.max_bundle_size (" +
max_bundle_size + ")");
}
msg_factory=transport.getMessageFactory();
Map<String,Object> info=new HashMap<>(1);
info.put("frag_size", frag_size);
down_prot.down(new Event(Event.CONFIG, info));
Expand Down Expand Up @@ -261,7 +259,7 @@ protected Message unfragment(Message msg, FragHeader hdr) {

FragEntry entry=frag_table.get(hdr.id);
if(entry == null) {
entry=new FragEntry(hdr.num_frags, hdr.needs_deserialization, msg_factory);
entry=new FragEntry(hdr.num_frags, hdr.needs_deserialization);
FragEntry tmp=frag_table.putIfAbsent(hdr.id, entry);
if(tmp != null)
entry=tmp;
Expand Down Expand Up @@ -314,7 +312,7 @@ protected Message assembleMessage(Message[] fragments, boolean needs_deserializa
index+=length;
}
if(needs_deserialization)
retval=Util.messageFromBuffer(combined_buffer, 0, combined_buffer.length, msg_factory);
retval=Util.messageFromBuffer(combined_buffer, 0, combined_buffer.length);
else
retval.setArray(combined_buffer, 0, combined_buffer.length);
return retval;
Expand All @@ -332,18 +330,16 @@ protected static class FragEntry {
protected final Message[] fragments;
protected int number_of_frags_recvd;
protected final boolean needs_deserialization;
protected final MessageFactory msg_factory;
protected final Lock lock=new ReentrantLock();


/**
* Creates a new entry
* @param tot_frags the number of fragments to expect for this message
*/
protected FragEntry(int tot_frags, boolean needs_deserialization, MessageFactory mf) {
protected FragEntry(int tot_frags, boolean needs_deserialization) {
fragments=new Message[tot_frags];
this.needs_deserialization=needs_deserialization;
this.msg_factory=mf;
}


Expand Down
5 changes: 1 addition & 4 deletions src/org/jgroups/protocols/FRAG3.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ public class FRAG3 extends Fragmentation {

protected final List<Address> members=new ArrayList<>(11);

protected MessageFactory msg_factory;

protected final AverageMinMax avg_size_down=new AverageMinMax();
protected final AverageMinMax avg_size_up=new AverageMinMax();

Expand All @@ -70,7 +68,6 @@ public void init() throws Exception {
if(frag_size >= max_bundle_size)
throw new IllegalArgumentException("frag_size (" + frag_size + ") has to be < TP.max_bundle_size (" +
max_bundle_size + ")");
msg_factory=transport.getMessageFactory();
Map<String,Object> info=new HashMap<>(1);
info.put("frag_size", frag_size);
down_prot.down(new Event(Event.CONFIG, info));
Expand Down Expand Up @@ -351,7 +348,7 @@ protected boolean isComplete() {
* @return the complete message in one buffer
*/
protected Message assembleMessage() throws Exception {
return needs_deserialization? Util.messageFromBuffer(buffer, 0, buffer.length, msg_factory)
return needs_deserialization? Util.messageFromBuffer(buffer, 0, buffer.length)
: msg.setArray(buffer, 0, buffer.length);
}

Expand Down
Loading

0 comments on commit 84a36ca

Please sign in to comment.