diff --git a/bifromq-dist/bifromq-dist-client/src/main/java/com/baidu/bifromq/dist/client/DistClient.java b/bifromq-dist/bifromq-dist-client/src/main/java/com/baidu/bifromq/dist/client/DistClient.java index 0d06296bc..de953b121 100644 --- a/bifromq-dist/bifromq-dist-client/src/main/java/com/baidu/bifromq/dist/client/DistClient.java +++ b/bifromq-dist/bifromq-dist-client/src/main/java/com/baidu/bifromq/dist/client/DistClient.java @@ -14,6 +14,7 @@ package com.baidu.bifromq.dist.client; import com.baidu.bifromq.baserpc.IRPCClient; +import com.baidu.bifromq.basescheduler.exception.BackPressureException; import com.baidu.bifromq.dist.RPCBluePrint; import com.baidu.bifromq.dist.client.scheduler.DistServerCall; import com.baidu.bifromq.dist.client.scheduler.DistServerCallScheduler; @@ -55,7 +56,11 @@ public CompletableFuture pub(long reqId, String topic, Message messa return reqScheduler.schedule(new DistServerCall(publisher, topic, message)) .exceptionally(e -> { log.debug("Failed to pub", e); - return DistResult.ERROR; + if (e instanceof BackPressureException || e.getCause() instanceof BackPressureException) { + return DistResult.BACK_PRESSURE_REJECTED; + } else { + return DistResult.ERROR; + } }); } diff --git a/bifromq-dist/bifromq-dist-coproc-proto/src/main/java/com/baidu/bifromq/dist/util/MessageUtil.java b/bifromq-dist/bifromq-dist-coproc-proto/src/main/java/com/baidu/bifromq/dist/util/MessageUtil.java deleted file mode 100644 index 0a3dd34a5..000000000 --- a/bifromq-dist/bifromq-dist-coproc-proto/src/main/java/com/baidu/bifromq/dist/util/MessageUtil.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2023. The BifroMQ Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. 
- */ - -package com.baidu.bifromq.dist.util; - -import com.baidu.bifromq.dist.rpc.proto.BatchDistRequest; -import com.baidu.bifromq.dist.rpc.proto.DistServiceROCoProcInput; - -public class MessageUtil { - - public static DistServiceROCoProcInput buildBatchDistRequest(BatchDistRequest request) { - return DistServiceROCoProcInput.newBuilder() - .setBatchDist(request) - .build(); - } -} diff --git a/bifromq-dist/bifromq-dist-coproc-proto/src/main/proto/distservice/DistCoProc.proto b/bifromq-dist/bifromq-dist-coproc-proto/src/main/proto/distservice/DistCoProc.proto index b279e8491..a6b51b3fd 100644 --- a/bifromq-dist/bifromq-dist-coproc-proto/src/main/proto/distservice/DistCoProc.proto +++ b/bifromq-dist/bifromq-dist-coproc-proto/src/main/proto/distservice/DistCoProc.proto @@ -50,17 +50,40 @@ message DistPack{ repeated commontype.TopicMessagePack msgPack = 2; // topic messages packs shares same tenantId } +// deprecated since 3.1.0, will be removed in 5.0.0 message BatchDistRequest { uint64 reqId = 1; repeated DistPack distPack = 2; // sorted by tenantId and topic string orderKey = 3; } +// deprecated since 3.1.0, will be removed in 5.0.0 message BatchDistReply { uint64 reqId = 1; map<string, TopicFanout> result = 2; } +message TenantDistRequest{ + uint64 reqId = 1; + string tenantId = 2; + repeated commontype.TopicMessagePack msgPack = 3; + string orderKey = 4; +} + +message TenantDistReply { + enum Code{ + OK = 0; + ERROR = 1; + } + + message Result{ + Code code = 1; + uint32 fanout = 2; + } + uint64 reqId = 1; + map<string, Result> results = 2; // key: topic +} + message DistServiceRWCoProcInput{ oneof type{ BatchMatchRequest batchMatch = 1; @@ -78,12 +101,14 @@ message DistServiceRWCoProcOutput{ message DistServiceROCoProcInput{ oneof Input{ BatchDistRequest batchDist = 1; + TenantDistRequest tenantDist = 2; } } message DistServiceROCoProcOutput{ oneof Output{ BatchDistReply batchDist = 1; + TenantDistReply tenantDist = 2; } } diff --git a/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/DistResponsePipeline.java b/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/DistResponsePipeline.java index c4fe7e449..5db20f311 100644 --- a/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/DistResponsePipeline.java +++ b/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/DistResponsePipeline.java @@ -56,7 +56,7 @@ class DistResponsePipeline extends ResponsePipeline { protected CompletableFuture handleRequest(String tenantId, DistRequest request) { return distCallScheduler.schedule(new DistWorkerCall(tenantId, request.getMessagesList(), callQueueIdx, tenantFanouts.get(tenantId).estimate())) - .handle((v, e) -> { + .handle((fanOutByTopic, e) -> { DistReply.Builder replyBuilder = DistReply.newBuilder().setReqId(request.getReqId()); if (e != null) { if (e instanceof BackPressureException || e.getCause() instanceof BackPressureException) { @@ -85,20 +85,22 @@ protected CompletableFuture handleRequest(String tenantId, DistReques .code(RPC_FAILURE)); } } else { - tenantFanouts.get(tenantId).log(v.values().stream().reduce(0, Integer::sum) / v.size()); + // TODO: exclude fanout = -1?
+ tenantFanouts.get(tenantId).log(fanOutByTopic.values().stream().reduce(0, Integer::sum) / + fanOutByTopic.size()); for (PublisherMessagePack publisherMsgPack : request.getMessagesList()) { DistReply.Result.Builder resultBuilder = DistReply.Result.newBuilder(); for (PublisherMessagePack.TopicPack topicPack : publisherMsgPack.getMessagePackList()) { - int fanout = v.get(topicPack.getTopic()); - resultBuilder.putTopic(topicPack.getTopic(), - fanout > 0 ? DistReply.Code.OK : DistReply.Code.NO_MATCH); + int fanout = fanOutByTopic.get(topicPack.getTopic()); + resultBuilder.putTopic(topicPack.getTopic(), fanout > 0 ? DistReply.Code.OK : + (fanout == 0 ? DistReply.Code.NO_MATCH : DistReply.Code.ERROR)); } replyBuilder.addResults(resultBuilder.build()); } eventCollector.report(getLocal(Disted.class) .reqId(request.getReqId()) .messages(request.getMessagesList()) - .fanout(v.values().stream().reduce(0, Integer::sum))); + .fanout(fanOutByTopic.values().stream().reduce(0, Integer::sum))); } return replyBuilder.build(); }); diff --git a/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/scheduler/DistCallScheduler.java b/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/scheduler/DistCallScheduler.java index de70b19a6..9a9e6c4a7 100644 --- a/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/scheduler/DistCallScheduler.java +++ b/bifromq-dist/bifromq-dist-server/src/main/java/com/baidu/bifromq/dist/server/scheduler/DistCallScheduler.java @@ -15,7 +15,6 @@ import static com.baidu.bifromq.dist.entity.EntityUtil.matchRecordKeyPrefix; import static com.baidu.bifromq.dist.entity.EntityUtil.tenantUpperBound; -import static com.baidu.bifromq.dist.util.MessageUtil.buildBatchDistRequest; import static com.baidu.bifromq.sysprops.BifroMQSysProp.DATA_PLANE_BURST_LATENCY_MS; import static com.baidu.bifromq.sysprops.BifroMQSysProp.DATA_PLANE_TOLERABLE_LATENCY_MS; import static com.baidu.bifromq.sysprops.BifroMQSysProp.DIST_WORKER_FANOUT_SPLIT_THRESHOLD; @@ -32,9 +31,9 @@ import com.baidu.bifromq.basescheduler.CallTask; import com.baidu.bifromq.basescheduler.IBatchCall; import com.baidu.bifromq.basescheduler.ICallScheduler; -import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; -import com.baidu.bifromq.dist.rpc.proto.BatchDistRequest; -import com.baidu.bifromq.dist.rpc.proto.DistPack; +import com.baidu.bifromq.dist.rpc.proto.DistServiceROCoProcInput; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistRequest; import com.baidu.bifromq.type.ClientInfo; import com.baidu.bifromq.type.Message; import com.baidu.bifromq.type.TopicMessagePack; @@ -53,8 +52,12 @@ import lombok.extern.slf4j.Slf4j; @Slf4j -public class DistCallScheduler extends BatchCallScheduler, Integer> +public class DistCallScheduler + extends BatchCallScheduler, DistCallScheduler.BatcherKey> implements IDistCallScheduler { + public record BatcherKey(String tenantId, Integer callQueueIdx) { + } + private final IBaseKVStoreClient distWorkerClient; private final Function tenantFanoutGetter; private final int fanoutSplitThreshold = DIST_WORKER_FANOUT_SPLIT_THRESHOLD.get(); @@ -69,10 +72,10 @@ public DistCallScheduler(ICallScheduler reqScheduler, } @Override - protected Batcher, Integer> newBatcher(String name, - long tolerableLatencyNanos, - long burstLatencyNanos, - Integer batchKey) { + protected Batcher, BatcherKey> newBatcher(String name, + long tolerableLatencyNanos, + long burstLatencyNanos, + BatcherKey batchKey) 
{ return new DistWorkerCallBatcher(batchKey, name, tolerableLatencyNanos, burstLatencyNanos, fanoutSplitThreshold, distWorkerClient, @@ -80,17 +83,19 @@ protected Batcher, Integer> newBatcher(Stri } @Override - protected Optional find(DistWorkerCall request) { - return Optional.of(request.callQueueIdx); + protected Optional find(DistWorkerCall request) { + return Optional.of(new BatcherKey(request.tenantId, request.callQueueIdx)); } - private static class DistWorkerCallBatcher extends Batcher, Integer> { + private static class DistWorkerCallBatcher + extends Batcher, DistCallScheduler.BatcherKey> { private final IBaseKVStoreClient distWorkerClient; private final String orderKey = UUID.randomUUID().toString(); private final Function tenantFanoutGetter; private final int fanoutSplitThreshold; - protected DistWorkerCallBatcher(Integer batcherKey, String name, + protected DistWorkerCallBatcher(DistCallScheduler.BatcherKey batcherKey, + String name, long tolerableLatencyNanos, long burstLatencyNanos, int fanoutSplitThreshold, @@ -103,13 +108,14 @@ protected DistWorkerCallBatcher(Integer batcherKey, String name, } @Override - protected IBatchCall, Integer> newBatch() { + protected IBatchCall, DistCallScheduler.BatcherKey> newBatch() { return new BatchDistCall(); } - private class BatchDistCall implements IBatchCall, Integer> { - private final Queue, Integer>> tasks = new ArrayDeque<>(128); - private Map>>> batch = new HashMap<>(128); + private class BatchDistCall implements IBatchCall, BatcherKey> { + private final Queue, BatcherKey>> tasks = + new ArrayDeque<>(128); + private Map>> batch = new HashMap<>(128); @Override public void reset() { @@ -117,12 +123,10 @@ public void reset() { } @Override - public void add(CallTask, Integer> callTask) { - Map>> clientMsgsByTopic = - batch.computeIfAbsent(callTask.call.tenantId, k -> new HashMap<>()); + public void add(CallTask, BatcherKey> callTask) { callTask.call.publisherMsgPacks.forEach(senderMsgPack -> senderMsgPack.getMessagePackList().forEach(topicMsgs -> - clientMsgsByTopic.computeIfAbsent(topicMsgs.getTopic(), k -> new HashMap<>()) + batch.computeIfAbsent(topicMsgs.getTopic(), k -> new HashMap<>()) .compute(senderMsgPack.getPublisher(), (k, v) -> { if (v == null) { v = topicMsgs.getMessageList(); @@ -136,44 +140,42 @@ public void add(CallTask, Integer> callTask @Override public CompletableFuture execute() { - Map> distPacksByRangeReplica = new HashMap<>(); - batch.forEach((tenantId, topicMap) -> { - DistPack.Builder distPackBuilder = DistPack.newBuilder().setTenantId(tenantId); - topicMap.forEach((topic, senderMap) -> { - TopicMessagePack.Builder topicMsgPackBuilder = TopicMessagePack.newBuilder().setTopic(topic); - senderMap.forEach((sender, msgs) -> - topicMsgPackBuilder.addMessage(TopicMessagePack.PublisherPack - .newBuilder() - .setPublisher(sender) - .addAllMessage(msgs) - .build())); - distPackBuilder.addMsgPack(topicMsgPackBuilder.build()); - }); - DistPack distPack = distPackBuilder.build(); + String tenantId = batcherKey.tenantId(); + Map> msgPacksByRangeReplica = new HashMap<>(); + batch.forEach((topic, senderMap) -> { + TopicMessagePack.Builder topicMsgPackBuilder = TopicMessagePack.newBuilder().setTopic(topic); + senderMap.forEach((sender, msgs) -> + topicMsgPackBuilder.addMessage(TopicMessagePack.PublisherPack + .newBuilder() + .setPublisher(sender) + .addAllMessage(msgs) + .build())); + TopicMessagePack topicMsgPack = topicMsgPackBuilder.build(); int fanoutScale = tenantFanoutGetter.apply(tenantId); List ranges = 
distWorkerClient.findByBoundary(Boundary.newBuilder() .setStartKey(matchRecordKeyPrefix(tenantId)) .setEndKey(tenantUpperBound(tenantId)) .build()); if (fanoutScale > fanoutSplitThreshold) { - ranges.forEach(range -> distPacksByRangeReplica.computeIfAbsent( + ranges.forEach(range -> msgPacksByRangeReplica.computeIfAbsent( new KVRangeReplica(range.id, range.ver, range.leader), - k -> new LinkedList<>()).add(distPack)); + k -> new LinkedList<>()).add(topicMsgPack)); } else { - ranges.forEach(range -> distPacksByRangeReplica.computeIfAbsent( + ranges.forEach(range -> msgPacksByRangeReplica.computeIfAbsent( new KVRangeReplica(range.id, range.ver, range.randomReplica()), - k -> new LinkedList<>()).add(distPack)); + k -> new LinkedList<>()).add(topicMsgPack)); } }); long reqId = System.nanoTime(); @SuppressWarnings("unchecked") - CompletableFuture[] distReplyFutures = distPacksByRangeReplica.entrySet().stream() + CompletableFuture[] distReplyFutures = msgPacksByRangeReplica.entrySet().stream() .map(entry -> { KVRangeReplica rangeReplica = entry.getKey(); - BatchDistRequest batchDist = BatchDistRequest.newBuilder() + TenantDistRequest tenantDistRequest = TenantDistRequest.newBuilder() .setReqId(reqId) - .addAllDistPack(entry.getValue()) + .setTenantId(tenantId) + .addAllMsgPack(entry.getValue()) .setOrderKey(orderKey) .build(); return distWorkerClient.query(rangeReplica.storeId, KVRangeRORequest.newBuilder() @@ -181,16 +183,18 @@ public CompletableFuture execute() { .setVer(rangeReplica.ver) .setKvRangeId(rangeReplica.id) .setRoCoProc(ROCoProcInput.newBuilder() - .setDistService(buildBatchDistRequest(batchDist)) + .setDistService(DistServiceROCoProcInput.newBuilder() + .setTenantDist(tenantDistRequest) + .build()) .build()) - .build(), batchDist.getOrderKey()) + .build(), tenantDistRequest.getOrderKey()) .thenApply(v -> { if (v.getCode() == ReplyCode.Ok) { - BatchDistReply batchDistReply = v.getRoCoProcResult() + TenantDistReply tenantDistReply = v.getRoCoProcResult() .getDistService() - .getBatchDist(); - assert batchDistReply.getReqId() == reqId; - return batchDistReply; + .getTenantDist(); + assert tenantDistReply.getReqId() == reqId; + return tenantDistReply; } log.warn("Failed to exec ro co-proc[code={}]", v.getCode()); throw new RuntimeException("Failed to exec rw co-proc"); @@ -199,32 +203,34 @@ public CompletableFuture execute() { .toArray(CompletableFuture[]::new); return CompletableFuture.allOf(distReplyFutures) .handle((v, e) -> { - CallTask, Integer> task; + CallTask, DistCallScheduler.BatcherKey> task; if (e != null) { while ((task = tasks.poll()) != null) { task.callResult.completeExceptionally(e); } } else { // aggregate fanout from each reply - Map> topicFanoutByTenant = new HashMap<>(); - for (CompletableFuture replyFuture : distReplyFutures) { - BatchDistReply reply = replyFuture.join(); - reply.getResultMap().forEach((tenantId, topicFanout) -> { - topicFanoutByTenant.computeIfAbsent(tenantId, k -> new HashMap<>()); - topicFanout.getFanoutMap() - .forEach((topic, fanout) -> topicFanoutByTenant.get(tenantId) - .compute(topic, (k, val) -> { - if (val == null) { - val = 0; + Map allTopicFanouts = new HashMap<>(); + for (CompletableFuture replyFuture : distReplyFutures) { + TenantDistReply reply = replyFuture.join(); + reply.getResultsMap() + .forEach((topic, result) -> allTopicFanouts.compute(topic, (k, f) -> { + if (f == null) { + f = 0; + } + switch (result.getCode()) { + case OK -> { + if (f >= 0) { + f += result.getFanout(); } - val += fanout; - return val; - })); - }); + } + 
// -1 stands for dist error + case ERROR -> f = -1; + } + return f; + })); } while ((task = tasks.poll()) != null) { - Map allTopicFanouts = - topicFanoutByTenant.get(task.call.tenantId); Map topicFanouts = new HashMap<>(); task.call.publisherMsgPacks.forEach(clientMessagePack -> clientMessagePack.getMessagePackList().forEach(topicMessagePack -> diff --git a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutor.java b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutor.java index ca22f5600..26a63c387 100644 --- a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutor.java +++ b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutor.java @@ -14,6 +14,7 @@ package com.baidu.bifromq.dist.worker; import static com.baidu.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal; +import static com.baidu.bifromq.plugin.subbroker.DeliveryResult.Code.ERROR; import com.baidu.bifromq.baseenv.EnvProvider; import com.baidu.bifromq.deliverer.DeliveryCall; @@ -28,6 +29,7 @@ import com.baidu.bifromq.type.TopicMessagePack; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedTransferQueue; @@ -58,13 +60,16 @@ public DeliverExecutor(int id, EnvProvider.INSTANCE.newThreadFactory("deliver-executor-" + id)), "deliver-executor-" + id); } - public void submit(NormalMatching route, TopicMessagePack msgPack, boolean inline) { + public CompletableFuture submit(NormalMatching route, TopicMessagePack msgPack, + boolean inline) { + CompletableFuture onDone = new CompletableFuture<>(); if (inline) { - send(route, msgPack); + send(route, msgPack, onDone); } else { - tasks.add(new SendTask(route, msgPack)); + tasks.add(new SendTask(route, msgPack, onDone)); scheduleSend(); } + return onDone; } public void shutdown() { @@ -80,7 +85,7 @@ private void scheduleSend() { private void sendAll() { SendTask task; while ((task = tasks.poll()) != null) { - send(task.route, task.msgPack); + send(task.route, task.msgPack, task.onDone); } sending.set(false); if (!tasks.isEmpty()) { @@ -88,7 +93,7 @@ private void sendAll() { } } - private void send(NormalMatching matched, TopicMessagePack msgPack) { + private void send(NormalMatching matched, TopicMessagePack msgPack, CompletableFuture onDone) { int subBrokerId = matched.subBrokerId; String delivererKey = matched.delivererKey; MatchInfo sub = matched.matchInfo; @@ -101,6 +106,7 @@ private void send(NormalMatching matched, TopicMessagePack msgPack) { .delivererKey(delivererKey) .subInfo(sub) .messages(msgPack)); + onDone.complete(false); } else { switch (result) { case OK -> eventCollector.report(getLocal(Delivered.class) @@ -129,11 +135,11 @@ private void send(NormalMatching matched, TopicMessagePack msgPack) { .subInfo(sub) .messages(msgPack)); } + onDone.complete(result != ERROR); } }); - } - private record SendTask(NormalMatching route, TopicMessagePack msgPack) { + private record SendTask(NormalMatching route, TopicMessagePack msgPack, CompletableFuture onDone) { } } diff --git a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutorGroup.java b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutorGroup.java 
index 5ad8753f3..f36d58a2b 100644 --- a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutorGroup.java +++ b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DeliverExecutorGroup.java @@ -39,10 +39,12 @@ import com.github.benmanes.caffeine.cache.LoadingCache; import com.github.benmanes.caffeine.cache.RemovalListener; import com.github.benmanes.caffeine.cache.Scheduler; +import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import lombok.extern.slf4j.Slf4j; @@ -97,15 +99,19 @@ public void shutdown() { orderedSharedMatching.invalidateAll(); } - public void submit(List matchedRoutes, TopicMessagePack msgPack) { + public CompletableFuture submit(List matchedRoutes, TopicMessagePack msgPack) { int msgPackSize = SizeUtil.estSizeOf(msgPack); + if (matchedRoutes.isEmpty()) { + return CompletableFuture.completedFuture(true); + } if (matchedRoutes.size() == 1) { Matching matching = matchedRoutes.get(0); - prepareSend(matching, msgPack, true); + CompletableFuture onDone = prepareSend(matching, msgPack, true); if (isSendToInbox(matching)) { ITenantMeter.get(matching.tenantId).recordSummary(MqttPersistentFanOutBytes, msgPackSize); } - } else if (matchedRoutes.size() > 1) { + return onDone; + } else { String tenantId = matchedRoutes.get(0).tenantId; boolean inline = matchedRoutes.size() > inlineFanOutThreshold; boolean hasTFanOutBandwidth = @@ -116,11 +122,12 @@ public void submit(List matchedRoutes, TopicMessagePack msgPack) { boolean hasPFannedOutUnderThrottled = false; // we meter persistent fanout bytes here, since for transient fanout is actually happened in the broker long pFanoutBytes = 0; + List> onDones = new ArrayList<>(matchedRoutes.size()); for (Matching matching : matchedRoutes) { if (isSendToInbox(matching)) { if (hasPFanOutBandwidth || !hasPFannedOutUnderThrottled) { pFanoutBytes += msgPackSize; - prepareSend(matching, msgPack, inline); + onDones.add(prepareSend(matching, msgPack, inline)); if (!hasPFanOutBandwidth) { hasPFannedOutUnderThrottled = true; for (TopicMessagePack.PublisherPack publisherPack : msgPack.getMessageList()) { @@ -132,7 +139,7 @@ public void submit(List matchedRoutes, TopicMessagePack msgPack) { } } } else if (hasTFanOutBandwidth || !hasTFannedOutUnderThrottled) { - prepareSend(matching, msgPack, inline); + onDones.add(prepareSend(matching, msgPack, inline)); if (!hasTFanOutBandwidth) { hasTFannedOutUnderThrottled = true; for (TopicMessagePack.PublisherPack publisherPack : msgPack.getMessageList()) { @@ -148,6 +155,8 @@ public void submit(List matchedRoutes, TopicMessagePack msgPack) { } } ITenantMeter.get(tenantId).recordSummary(MqttPersistentFanOutBytes, pFanoutBytes); + return CompletableFuture.allOf(onDones.toArray(CompletableFuture[]::new)) + .thenApply(v -> onDones.stream().allMatch(CompletableFuture::join)); } } @@ -159,14 +168,14 @@ public void invalidate(ScopedTopic scopedTopic) { orderedSharedMatching.invalidate(new OrderedSharedMatchingKey(scopedTopic.tenantId, escape(scopedTopic.topic))); } - private void prepareSend(Matching matching, TopicMessagePack msgPack, boolean inline) { - switch (matching.type()) { + private CompletableFuture prepareSend(Matching matching, TopicMessagePack msgPack, boolean inline) { + return switch (matching.type()) { case Normal -> send((NormalMatching) 
matching, msgPack, inline); case Group -> { GroupMatching groupMatching = (GroupMatching) matching; if (!groupMatching.ordered) { // pick one route randomly - send(groupMatching.receiverList.get( + yield send(groupMatching.receiverList.get( ThreadLocalRandom.current().nextInt(groupMatching.receiverList.size())), msgPack, inline); } else { // ordered shared subscription @@ -189,17 +198,21 @@ private void prepareSend(Matching matching, TopicMessagePack msgPack, boolean in .setTopic(msgPack.getTopic()) .addMessage(publisherPack); } - orderedRoutes.forEach((route, msgPackBuilder) -> send(route, msgPackBuilder.build(), inline)); + List> onDones = new ArrayList<>(orderedRoutes.size()); + orderedRoutes.forEach( + (route, msgPackBuilder) -> onDones.add(send(route, msgPackBuilder.build(), inline))); + yield CompletableFuture.allOf(onDones.toArray(CompletableFuture[]::new)) + .thenApply(v -> onDones.stream().allMatch(CompletableFuture::join)); } } - } + }; } - private void send(NormalMatching route, TopicMessagePack msgPack, boolean inline) { + private CompletableFuture send(NormalMatching route, TopicMessagePack msgPack, boolean inline) { int idx = route.hashCode() % fanoutExecutors.length; if (idx < 0) { idx += fanoutExecutors.length; } - fanoutExecutors[idx].submit(route, msgPack, inline); + return fanoutExecutors[idx].submit(route, msgPack, inline); } } diff --git a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DistWorkerCoProc.java b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DistWorkerCoProc.java index d2d27ae47..ae363774c 100644 --- a/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DistWorkerCoProc.java +++ b/bifromq-dist/bifromq-dist-worker/src/main/java/com/baidu/bifromq/dist/worker/DistWorkerCoProc.java @@ -60,6 +60,8 @@ import com.baidu.bifromq.dist.rpc.proto.DistServiceRWCoProcInput; import com.baidu.bifromq.dist.rpc.proto.DistServiceRWCoProcOutput; import com.baidu.bifromq.dist.rpc.proto.GroupMatchRecord; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistRequest; import com.baidu.bifromq.dist.rpc.proto.TopicFanout; import com.baidu.bifromq.plugin.eventcollector.IEventCollector; import com.baidu.bifromq.plugin.subbroker.ISubBrokerManager; @@ -131,6 +133,12 @@ public CompletableFuture query(ROCoProcInput input, IKVReader re v -> ROCoProcOutput.newBuilder().setDistService(DistServiceROCoProcOutput.newBuilder() .setBatchDist(v).build()).build()); } + case TENANTDIST -> { + return tenantDist(coProcInput.getTenantDist(), reader) + .thenApply( + v -> ROCoProcOutput.newBuilder().setDistService(DistServiceROCoProcOutput.newBuilder() + .setTenantDist(v).build()).build()); + } default -> { log.error("Unknown co proc type {}", coProcInput.getInputCase()); CompletableFuture f = new CompletableFuture<>(); @@ -423,6 +431,57 @@ private CompletableFuture batchDist(BatchDistRequest request, IK }); } + private CompletableFuture tenantDist(TenantDistRequest request, IKVReader reader) { + List msgPackList = request.getMsgPackList(); + if (msgPackList.isEmpty()) { + return CompletableFuture.completedFuture(TenantDistReply.newBuilder() + .setReqId(request.getReqId()) + .build()); + } + String tenantId = request.getTenantId(); + Boundary boundary = intersect(Boundary.newBuilder() + .setStartKey(matchRecordKeyPrefix(tenantId)) + .setEndKey(tenantUpperBound(tenantId)) + .build(), reader.boundary()); + if (isEmptyRange(boundary)) { + TenantDistReply.Builder 
replyBuilder = TenantDistReply.newBuilder().setReqId(request.getReqId()); + for (TopicMessagePack topicMessagePack : msgPackList) { + replyBuilder.putResults(topicMessagePack.getTopic(), TenantDistReply.Result.newBuilder() + .setCode(TenantDistReply.Code.OK) + .build()); + } + return CompletableFuture.completedFuture(replyBuilder.build()); + } + Map> fanOutByTopics = new HashMap<>(); + for (TopicMessagePack topicMsgPack : msgPackList) { + String topic = topicMsgPack.getTopic(); + ScopedTopic scopedTopic = ScopedTopic.builder() + .tenantId(tenantId) + .topic(topic) + .boundary(reader.boundary()) + .build(); + fanOutByTopics.put(topic, routeCache.get(scopedTopic) + .thenCompose(matchResult -> fanoutExecutorGroup.submit(matchResult.routes, topicMsgPack) + .thenApply(success -> { + if (success) { + return TenantDistReply.Result.newBuilder() + .setCode(TenantDistReply.Code.OK) + .setFanout(matchResult.routes.size()) + .build(); + } else { + return TenantDistReply.Result.newBuilder() + .setCode(TenantDistReply.Code.ERROR) + .build(); + } + }))); + } + return CompletableFuture.allOf(fanOutByTopics.values().toArray(CompletableFuture[]::new)) + .thenApply(v -> TenantDistReply.newBuilder() + .setReqId(request.getReqId()) + .putAllResults(Maps.transformValues(fanOutByTopics, CompletableFuture::join)) + .build()); + } + private void load() { IKVReader reader = readerProvider.get(); IKVIterator itr = reader.iterator(); diff --git a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/BatchDistTest.java b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/BatchDistTest.java index c2c268398..4b46d2779 100644 --- a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/BatchDistTest.java +++ b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/BatchDistTest.java @@ -13,7 +13,6 @@ package com.baidu.bifromq.dist.worker; -import static com.baidu.bifromq.plugin.subbroker.TypeUtil.to; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CLIENT_ADDRESS_KEY; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CLIENT_ID_KEY; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_PROTOCOL_VER_3_1_1_VALUE; @@ -26,25 +25,16 @@ import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; -import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; -import com.baidu.bifromq.plugin.subbroker.DeliveryPack; -import com.baidu.bifromq.plugin.subbroker.DeliveryPackage; -import com.baidu.bifromq.plugin.subbroker.DeliveryReply; -import com.baidu.bifromq.plugin.subbroker.DeliveryRequest; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; import com.baidu.bifromq.plugin.subbroker.DeliveryResult; import com.baidu.bifromq.type.ClientInfo; -import com.baidu.bifromq.type.MatchInfo; import com.baidu.bifromq.type.Message; import com.baidu.bifromq.type.QoS; import com.baidu.bifromq.type.TopicMessagePack; import com.google.protobuf.ByteString; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import lombok.extern.slf4j.Slf4j; -import org.mockito.stubbing.Answer; import org.testng.annotations.Test; @Slf4j @@ -52,10 +42,9 @@ public class BatchDistTest extends DistWorkerTest { @Test(groups = "integration") public void batchDistWithNoSub() { - String topic = "/a/b/c"; ByteString payload = copyFromUtf8("hello"); - BatchDistReply reply = dist(tenantA, + 
TenantDistReply reply = tenantDist(tenantA, List.of(TopicMessagePack.newBuilder() .setTopic("a") .addMessage(toMsg(tenantA, AT_MOST_ONCE, payload)) @@ -68,7 +57,8 @@ public void batchDistWithNoSub() { .setTopic("a/b") .addMessage(toMsg(tenantA, AT_MOST_ONCE, payload)) .build()), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().getOrDefault(topic, 0).intValue(), 0); + assertEquals(reply.getResultsMap().size(), 3); + reply.getResultsMap().forEach((k, v) -> assertEquals(v.getFanout(), 0)); } @Test(groups = "integration") @@ -86,7 +76,7 @@ public void batchDist() { match(tenantA, "/a/3", InboxService, "inbox2", "batch2"); match(tenantA, "/a/4", InboxService, "inbox2", "batch2"); - BatchDistReply reply = dist(tenantA, + TenantDistReply reply = tenantDist(tenantA, List.of( TopicMessagePack.newBuilder() .setTopic("/a/1") @@ -105,10 +95,10 @@ public void batchDist() { .addMessage(toMsg(tenantA, AT_MOST_ONCE, copyFromUtf8("Hello"))) .build()), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/1").intValue(), 1); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/2").intValue(), 2); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/3").intValue(), 1); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/4").intValue(), 1); + assertEquals(reply.getResultsMap().get("/a/1").getFanout(), 1); + assertEquals(reply.getResultsMap().get("/a/2").getFanout(), 2); + assertEquals(reply.getResultsMap().get("/a/3").getFanout(), 1); + assertEquals(reply.getResultsMap().get("/a/4").getFanout(), 1); unmatch(tenantA, "/a/1", MqttBroker, "inbox1", "batch1"); unmatch(tenantA, "/a/2", MqttBroker, "inbox1", "batch1"); diff --git a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS0Test.java b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS0Test.java index 92dd459bf..9c1ea041b 100644 --- a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS0Test.java +++ b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS0Test.java @@ -24,7 +24,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; -import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; import com.baidu.bifromq.plugin.subbroker.DeliveryPack; import com.baidu.bifromq.plugin.subbroker.DeliveryPackage; import com.baidu.bifromq.plugin.subbroker.DeliveryRequest; @@ -67,8 +67,8 @@ public void succeedWithNoSub() { String topic = "/a/b/c"; ByteString payload = copyFromUtf8("hello"); - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, topic, payload, "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().getOrDefault(topic, 0).intValue(), 0); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, topic, payload, "orderKey1"); + assertEquals(reply.getResultsMap().get(topic).getFanout(), 0); } @Test(groups = "integration") @@ -81,8 +81,8 @@ public void testDistCase1() { match(tenantA, "TopicA/#", MqttBroker, "inbox1", "batch1"); - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "TopicB", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().getOrDefault("TopicB", 0).intValue(), 0); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "TopicB", copyFromUtf8("Hello"), "orderKey1"); + 
assertEquals(reply.getResultsMap().get("TopicB").getFanout(), 0); unmatch(tenantA, "TopicA/#", InboxService, "inbox1", "batch1"); } @@ -100,8 +100,8 @@ public void testDistCase2() { match(tenantA, "/#", MqttBroker, "inbox1", "batch1"); match(tenantA, "/#", InboxService, "inbox2", "batch2"); - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/你好/hello/😄", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/你好/hello/😄").intValue(), 3); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/你好/hello/😄", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/你好/hello/😄").getFanout(), 3); ArgumentCaptor msgCap = ArgumentCaptor.forClass(DeliveryRequest.class); verify(writer1, after(1000).atMost(2)).deliver(msgCap.capture()); @@ -158,8 +158,8 @@ public void testDistCase3() { match(tenantA, "/a/b/c", MqttBroker, "inbox1", "batch1"); match(tenantA, "/a/b/c", MqttBroker, "inbox2", "batch1"); - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 2); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 2); ArgumentCaptor list1 = ArgumentCaptor.forClass(DeliveryRequest.class); verify(writer1, after(1000).atMost(2)).deliver(list1.capture()); log.info("Case3: verify writer1, list size is {}", list1.getAllValues().size()); @@ -199,8 +199,8 @@ public void testDistCase4() { match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox1", "batch1"); match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox2", "batch2"); for (int i = 0; i < 10; i++) { - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c"), 1); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 1); } ArgumentCaptor list1 = ArgumentCaptor.forClass(DeliveryRequest.class); @@ -255,8 +255,8 @@ public void testDistCase5() { match(tenantA, "$oshare/group//a/b/c", MqttBroker, "inbox2", "batch2"); for (int i = 0; i < 10; i++) { - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 1); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 1); } ArgumentCaptor list1 = ArgumentCaptor.forClass(DeliveryRequest.class); @@ -303,8 +303,8 @@ public void testDistCase6() { match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox2", "batch2"); match(tenantA, "$oshare/group//a/b/c", MqttBroker, "inbox3", "batch3"); for (int i = 0; i < 1; ++i) { - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 3); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 3); } verify(writer1, timeout(1000).times(1)).deliver(any()); @@ -326,8 +326,8 @@ public void testDistCase7() { match(tenantA, "/a/b/c", 
MqttBroker, "inbox1", "batch1"); match(tenantB, "#", MqttBroker, "inbox1", "batch1"); - BatchDistReply reply = dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 1); + TenantDistReply reply = tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 1); ArgumentCaptor list1 = ArgumentCaptor.forClass(DeliveryRequest.class); verify(writer1, timeout(1000).times(1)).deliver(list1.capture()); @@ -366,7 +366,7 @@ public void testRouteRefresh() { // sub: inbox1 -> [(/a/b/c, qos0)] // expected behavior: inbox1 gets 1 message match(tenantA, "/a/b/c", MqttBroker, "inbox1", "batch1"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer1, timeout(1000).times(1)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -376,7 +376,7 @@ public void testRouteRefresh() { // sub: no sub // expected behavior: inbox1 gets no messages unmatch(tenantA, "/a/b/c", MqttBroker, "inbox1", "batch1"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer1, timeout(1000).times(0)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -386,7 +386,7 @@ public void testRouteRefresh() { // sub: inbox2 -> [($share/group/a/b/c, qos0)] // expected behavior: inbox2 gets 1 message match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox2", "batch2"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer2, timeout(1000).times(1)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -397,7 +397,7 @@ public void testRouteRefresh() { // expected behavior: inbox2 gets no messages and inbox3 gets 1 unmatch(tenantA, "$share/group//a/b/c", MqttBroker, "inbox2", "batch2"); match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox3", "batch3"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer2, timeout(1000).times(0)).deliver(any()); verify(writer3, timeout(1000).times(1)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -409,7 +409,7 @@ public void testRouteRefresh() { // expected behavior: inbox2 gets 1 message and inbox3 gets none match(tenantA, "$oshare/group//a/b/c", MqttBroker, "inbox2", "batch2"); unmatch(tenantA, "$share/group//a/b/c", MqttBroker, "inbox3", "batch3"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer2, timeout(1000).times(1)).deliver(any()); verify(writer3, timeout(1000).times(0)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -421,7 +421,7 @@ public void testRouteRefresh() { // expected behavior: inbox2 gets no messages and inbox3 gets 1 unmatch(tenantA, "$oshare/group//a/b/c", MqttBroker, "inbox2", "batch2"); match(tenantA, "$oshare/group//a/b/c", MqttBroker, "inbox3", "batch3"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer2, 
timeout(1000).times(0)).deliver(any()); verify(writer3, timeout(1000).times(1)).deliver(any()); @@ -439,7 +439,7 @@ public void testRouteRefreshWithWildcardTopic() throws InterruptedException { when(mqttBroker.open("batch2")).thenReturn(writer2); when(mqttBroker.open("batch3")).thenReturn(writer3); match(tenantA, "/a/b/c", MqttBroker, "inbox1", "batch1"); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer1, timeout(1000).times(1)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -450,7 +450,7 @@ public void testRouteRefreshWithWildcardTopic() throws InterruptedException { // expected behavior: inbox1 gets 1 message and inbox2 gets 1 either match(tenantA, "/#", MqttBroker, "inbox2", "batch2"); Thread.sleep(1100); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer1, timeout(1000).times(1)).deliver(any()); verify(writer2, timeout(1000).times(1)).deliver(any()); clearInvocations(writer1, writer2, writer3); @@ -466,7 +466,7 @@ public void testRouteRefreshWithWildcardTopic() throws InterruptedException { match(tenantA, "$oshare/group/#", MqttBroker, "inbox3", "batch3"); // wait for cache refresh after writing Thread.sleep(1100); - dist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + tenantDist(tenantA, AT_MOST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); verify(writer1, timeout(1000).times(1)).deliver(any()); verify(writer2, timeout(1000).times(1)).deliver(any()); verify(writer3, timeout(1000).atLeastOnce()).deliver(any()); diff --git a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS1Test.java b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS1Test.java index 244048d77..cd4d1b6e8 100644 --- a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS1Test.java +++ b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS1Test.java @@ -25,7 +25,7 @@ import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; -import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; import com.baidu.bifromq.plugin.eventcollector.EventType; import com.baidu.bifromq.plugin.subbroker.DeliveryPack; import com.baidu.bifromq.plugin.subbroker.DeliveryRequest; @@ -53,8 +53,8 @@ public void succeedWithNoSub() { String topic = "/a/b/c"; ByteString payload = copyFromUtf8("hello"); - BatchDistReply reply = dist(tenantA, AT_LEAST_ONCE, topic, payload, "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().getOrDefault(topic, 0).intValue(), 0); + TenantDistReply reply = tenantDist(tenantA, AT_LEAST_ONCE, topic, payload, "orderKey1"); + assertEquals(reply.getResultsMap().get(topic).getFanout(), 0); } @Test(groups = "integration") @@ -76,8 +76,8 @@ public void testDistCase9() { match(tenantA, "/a/b/c", MqttBroker, "inbox1", "server1"); for (int i = 0; i < 10; i++) { - BatchDistReply reply = dist(tenantA, AT_LEAST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 1); + TenantDistReply reply = tenantDist(tenantA, AT_LEAST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + 
assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 1); } ArgumentCaptor messageListCap = ArgumentCaptor.forClass(DeliveryRequest.class); @@ -121,8 +121,8 @@ public void testDistCase10() { match(tenantA, "$share/group//a/b/c", MqttBroker, "inbox1", "server1"); for (int i = 0; i < 10; i++) { - BatchDistReply reply = dist(tenantA, AT_LEAST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 1); + TenantDistReply reply = tenantDist(tenantA, AT_LEAST_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 1); } diff --git a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS2Test.java b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS2Test.java index ba73a1485..00d569fdd 100644 --- a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS2Test.java +++ b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistQoS2Test.java @@ -21,7 +21,7 @@ import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; -import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; import com.baidu.bifromq.plugin.subbroker.DeliveryPack; import com.baidu.bifromq.plugin.subbroker.DeliveryRequest; import com.baidu.bifromq.plugin.subbroker.DeliveryResult; @@ -48,8 +48,8 @@ public void succeedWithNoSub() { String topic = "/a/b/c"; ByteString payload = copyFromUtf8("hello"); - BatchDistReply reply = dist(tenantA, EXACTLY_ONCE, topic, payload, "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().getOrDefault(topic, 0).intValue(), 0); + TenantDistReply reply = tenantDist(tenantA, EXACTLY_ONCE, topic, payload, "orderKey1"); + assertEquals(reply.getResultsMap().get(topic).getFanout(), 0); } @Test(groups = "integration") @@ -68,8 +68,8 @@ public void distQoS2ToVariousSubQoS() { match(tenantA, "/a/b/c", MqttBroker, "inbox1", "server1"); match(tenantA, "/#", MqttBroker, "inbox1", "server1"); match(tenantA, "/#", MqttBroker, "inbox2", "server2"); - BatchDistReply reply = dist(tenantA, EXACTLY_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); - assertEquals(reply.getResultMap().get(tenantA).getFanoutMap().get("/a/b/c").intValue(), 3); + TenantDistReply reply = tenantDist(tenantA, EXACTLY_ONCE, "/a/b/c", copyFromUtf8("Hello"), "orderKey1"); + assertEquals(reply.getResultsMap().get("/a/b/c").getFanout(), 3); ArgumentCaptor msgCap = ArgumentCaptor.forClass(DeliveryRequest.class); diff --git a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistWorkerTest.java b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistWorkerTest.java index 4582fac31..13e473c6e 100644 --- a/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistWorkerTest.java +++ b/bifromq-dist/bifromq-dist-worker/src/test/java/com/baidu/bifromq/dist/worker/DistWorkerTest.java @@ -50,16 +50,16 @@ import com.baidu.bifromq.dist.client.IDistClient; import com.baidu.bifromq.dist.entity.EntityUtil; import com.baidu.bifromq.dist.rpc.proto.BatchDistReply; -import com.baidu.bifromq.dist.rpc.proto.BatchDistRequest; import com.baidu.bifromq.dist.rpc.proto.BatchMatchReply; import com.baidu.bifromq.dist.rpc.proto.BatchMatchRequest; import com.baidu.bifromq.dist.rpc.proto.BatchUnmatchReply; import 
com.baidu.bifromq.dist.rpc.proto.BatchUnmatchRequest; -import com.baidu.bifromq.dist.rpc.proto.DistPack; +import com.baidu.bifromq.dist.rpc.proto.DistServiceROCoProcInput; import com.baidu.bifromq.dist.rpc.proto.DistServiceROCoProcOutput; import com.baidu.bifromq.dist.rpc.proto.DistServiceRWCoProcInput; +import com.baidu.bifromq.dist.rpc.proto.TenantDistReply; +import com.baidu.bifromq.dist.rpc.proto.TenantDistRequest; import com.baidu.bifromq.dist.rpc.proto.TenantOption; -import com.baidu.bifromq.dist.util.MessageUtil; import com.baidu.bifromq.plugin.eventcollector.IEventCollector; import com.baidu.bifromq.plugin.subbroker.DeliveryPack; import com.baidu.bifromq.plugin.subbroker.DeliveryPackage; @@ -329,19 +329,20 @@ protected BatchUnmatchReply.Result unmatch(String tenantId, String topicFilter, return batchUnmatchReply.getResultsMap().get(scopedTopicFilter); } - protected BatchDistReply dist(String tenantId, List msgs, String orderKey) { + protected TenantDistReply tenantDist(String tenantId, List msgs, String orderKey) { long reqId = ThreadLocalRandom.current().nextInt(); KVRangeSetting s = storeClient.findByKey(EntityUtil.matchRecordKeyPrefix(tenantId)).get(); - BatchDistRequest request = BatchDistRequest.newBuilder() + TenantDistRequest request = TenantDistRequest.newBuilder() .setReqId(reqId) - .addDistPack(DistPack.newBuilder() - .setTenantId(tenantId) - .addAllMsgPack(msgs) - .build()) + .setTenantId(tenantId) + .addAllMsgPack(msgs) .setOrderKey(orderKey) .build(); + ROCoProcInput input = ROCoProcInput.newBuilder() - .setDistService(MessageUtil.buildBatchDistRequest(request)) + .setDistService(DistServiceROCoProcInput.newBuilder() + .setTenantDist(request) + .build()) .build(); KVRangeROReply reply = storeClient.query(s.leader, KVRangeRORequest.newBuilder() .setReqId(reqId) @@ -352,13 +353,13 @@ protected BatchDistReply dist(String tenantId, List msgs, Stri assertEquals(reply.getReqId(), reqId); assertEquals(reply.getCode(), ReplyCode.Ok); DistServiceROCoProcOutput output = reply.getRoCoProcResult().getDistService(); - assertTrue(output.hasBatchDist()); - assertEquals(output.getBatchDist().getReqId(), reqId); - return output.getBatchDist(); + assertTrue(output.hasTenantDist()); + assertEquals(output.getTenantDist().getReqId(), reqId); + return output.getTenantDist(); } - protected BatchDistReply dist(String tenantId, QoS qos, String topic, ByteString payload, String orderKey) { - return dist(tenantId, List.of(TopicMessagePack.newBuilder() + protected TenantDistReply tenantDist(String tenantId, QoS qos, String topic, ByteString payload, String orderKey) { + return tenantDist(tenantId, List.of(TopicMessagePack.newBuilder() .setTopic(topic) .addMessage(TopicMessagePack.PublisherPack.newBuilder() .setPublisher(ClientInfo.newBuilder() diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTSessionHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTSessionHandler.java index 03d3b3ae3..682c2f5e7 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTSessionHandler.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTSessionHandler.java @@ -1274,6 +1274,8 @@ private CompletableFuture doPub(long reqId, } } default -> { + // TODO: support limit retry + msgIdGenerator.markDrain(topic, message.getMessageId()); switch (message.getPubQoS()) { case AT_MOST_ONCE -> eventCollector.report(getLocal(QoS0DistError.class) .reqId(reqId) diff 
--git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageIdGenerator.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageIdGenerator.java index 3ce23b408..a107c2605 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageIdGenerator.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageIdGenerator.java @@ -42,16 +42,23 @@ public TopicMessageIdGenerator(Duration syncWindowInterval, int maxActiveTopics, public long nextMessageId(String topic, long nowMillis) { checker.updateNowMillis(nowMillis); - return topicMessageIdCache.computeIfAbsent(topic, k -> new AtomicLong(0xFFFFFFFE00000000L)) - .updateAndGet(msgId -> { - long currentSWS = syncWindowSequence(nowMillis, syncWindowIntervalMillis); - long lastSWS = syncWindowSequence(msgId); - if (currentSWS == lastSWS || currentSWS == lastSWS + 1) { - return messageId(currentSWS, messageSequence(msgId) + 1); - } else { - return messageId(currentSWS, 0); - } - }); + return topicMessageIdCache.computeIfAbsent(topic, k -> new MessageIdGenerator(syncWindowIntervalMillis)) + .next(nowMillis); + } + + public void markDrain(String topic, long failedMessageId) { + MessageIdGenerator idGenerator = topicMessageIdCache.get(topic); + if (idGenerator != null) { + idGenerator.markDrain(failedMessageId); + } + } + + public long drainMessageId(String topic) { + MessageIdGenerator idGenerator = topicMessageIdCache.get(topic); + if (idGenerator != null) { + return idGenerator.drainMessageId; + } + return 0; } private static class PrematureEvictionChecker implements Predicate { @@ -72,7 +79,49 @@ public boolean test(Long messageId) { } } - private static class TopicMessageIdCache extends LinkedHashMap { + private static class MessageIdGenerator { + private final long syncWindowIntervalMillis; + private final AtomicLong messageId; + private long drainMessageId = 0; // the messages before this id has been drained + private boolean drainMarked = false; // whether the drain flag should be set for next message id + + private MessageIdGenerator(long syncWindowIntervalMillis) { + this.syncWindowIntervalMillis = syncWindowIntervalMillis; + this.messageId = new AtomicLong(0xFFFFFFFE00000000L); + } + + public void markDrain(long failedMessageId) { + if (failedMessageId > drainMessageId) { + drainMarked = true; + } + } + + public long next(long nowMillis) { + return messageId.updateAndGet(msgId -> { + long currentSWS = syncWindowSequence(nowMillis, syncWindowIntervalMillis); + long lastSWS = syncWindowSequence(msgId); + if (currentSWS == lastSWS || currentSWS == lastSWS + 1) { + if (drainMarked) { + drainMessageId = messageId(currentSWS, messageSequence(msgId) + 1, true); + drainMarked = false; + return drainMessageId; + } else { + return messageId(currentSWS, messageSequence(msgId) + 1); + } + } else { + if (drainMarked) { + drainMessageId = messageId(currentSWS, 0, true); + drainMarked = false; + return drainMessageId; + } else { + return messageId(currentSWS, 0); + } + } + }); + } + } + + private static class TopicMessageIdCache extends LinkedHashMap { private final ITenantMeter meter; private final Predicate isPrematureEviction; private final int maxSize; @@ -87,9 +136,9 @@ private TopicMessageIdCache(ITenantMeter meter, } @Override - protected boolean removeEldestEntry(Map.Entry eldest) { + protected boolean removeEldestEntry(Map.Entry eldest) { if (size() > maxSize) { - if 
(isPrematureEviction.test(eldest.getValue().get())) { + if (isPrematureEviction.test(eldest.getValue().messageId.get())) { meter.recordCount(TenantMetric.MqttTopicSeqAbortCount); } return true; diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageOrderingSender.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageOrderingSender.java index 69c6a312f..0e3fde99e 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageOrderingSender.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TopicMessageOrderingSender.java @@ -16,6 +16,7 @@ import static com.baidu.bifromq.metrics.TenantMetric.MqttOutOfOrderSendBytes; import static com.baidu.bifromq.metrics.TenantMetric.MqttReorderBytes; import static com.baidu.bifromq.metrics.TenantMetric.MqttTopicSorterAbortCount; +import static com.baidu.bifromq.mqtt.utils.MessageIdUtil.isDrainFlagSet; import static com.baidu.bifromq.mqtt.utils.MessageIdUtil.isSuccessive; import static com.baidu.bifromq.mqtt.utils.MessageIdUtil.previousMessageId; @@ -129,14 +130,21 @@ boolean submit(long inboxSeqNo, MQTTSessionHandler.SubMessage subMessage) { tailMsgId = msgId; meter.recordSummary(MqttReorderBytes, subMessage.estBytes()); sortingBuffer.put(msgId, new SortingMessage(inboxSeqNo, subMessage)); + if (isDrainFlagSet(msgId)) { + drain(); + } } } else if (msgId > tailMsgId) { // out of order happens meter.recordSummary(MqttReorderBytes, subMessage.estBytes()); sortingBuffer.put(msgId, new SortingMessage(inboxSeqNo, subMessage)); tailMsgId = msgId; - if (timeout == null) { - timeout = executor.schedule(this::drain, syncWindowIntervalMillis, TimeUnit.MILLISECONDS); + if (isDrainFlagSet(msgId)) { + drain(); + } else { + if (timeout == null) { + timeout = executor.schedule(this::drain, syncWindowIntervalMillis, TimeUnit.MILLISECONDS); + } } } else { // tailMsgSeq <= msgSeq diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ProtocolHelper.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ProtocolHelper.java index 22c8ee189..482f645e7 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ProtocolHelper.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ProtocolHelper.java @@ -358,7 +358,7 @@ public ProtocolResponse onQoS0DistDenied(String topic, Message distMessage, Chec public ProtocolResponse onQoS0PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) { if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) { - return goAway(getLocal(ServerBusy.class) + return responseNothing(getLocal(ServerBusy.class) .reason("Too many qos0 publish") .clientInfo(clientInfo)); } else { @@ -379,7 +379,7 @@ public ProtocolResponse onQoS1DistDenied(String topic, int packetId, Message dis public ProtocolResponse onQoS1PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) { if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) { - return goAway(getLocal(ServerBusy.class) + return responseNothing(getLocal(ServerBusy.class) .reason("Too many qos1 publish") .clientInfo(clientInfo)); } else { @@ -419,7 +419,7 @@ public ProtocolResponse 
     public ProtocolResponse onQoS2PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) {
         if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED
             || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) {
-            return goAway(getLocal(ServerBusy.class)
+            return responseNothing(getLocal(ServerBusy.class)
                 .reason("Too many qos2 publish")
                 .clientInfo(clientInfo));
         } else {
diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ProtocolHelper.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ProtocolHelper.java
index 10fd598f2..7d267f270 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ProtocolHelper.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ProtocolHelper.java
@@ -16,6 +16,7 @@
 import static com.baidu.bifromq.mqtt.handler.record.ProtocolResponse.farewell;
 import static com.baidu.bifromq.mqtt.handler.record.ProtocolResponse.farewellNow;
 import static com.baidu.bifromq.mqtt.handler.record.ProtocolResponse.response;
+import static com.baidu.bifromq.mqtt.handler.record.ProtocolResponse.responseNothing;
 import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.isUTF8Payload;
 import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.messageExpiryInterval;
 import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.receiveMaximum;
@@ -756,16 +757,11 @@ public ProtocolResponse onQoS0DistDenied(String topic, Message distMessage, Chec
     public ProtocolResponse onQoS0PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) {
         if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED
             || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) {
-            return farewell(MQTT5MessageBuilders.disconnect()
-                .reasonCode(MQTT5DisconnectReasonCode.ServerBusy)
-                .reasonString("Too many QoS0 publish")
-                .userProps(userProps)
-                .build(),
-                getLocal(ServerBusy.class)
-                    .reason("Too many QoS0 publish")
-                    .clientInfo(clientInfo));
+            return responseNothing(getLocal(ServerBusy.class)
+                .reason("Too many QoS0 publish")
+                .clientInfo(clientInfo));
         } else {
-            return ProtocolResponse.responseNothing();
+            return responseNothing();
         }
     }
 
@@ -796,10 +792,7 @@ public ProtocolResponse onQoS1DistDenied(String topic, int packetId, Message dis
     public ProtocolResponse onQoS1PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) {
         if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED
             || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) {
-            return farewell(MQTT5MessageBuilders.disconnect()
-                .reasonCode(MQTT5DisconnectReasonCode.ServerBusy)
-                .reasonString("Too many QoS1 publish")
-                .build(),
+            return responseNothing(
                 getLocal(ServerBusy.class)
                     .reason("Too many QoS1 publish")
                     .clientInfo(clientInfo));
@@ -878,13 +871,9 @@ public ProtocolResponse onQoS2DistDenied(String topic, int packetId, Message dis
     public ProtocolResponse onQoS2PubHandled(PubResult result, MqttPublishMessage message, UserProperties userProps) {
         if (result.distResult() == DistResult.BACK_PRESSURE_REJECTED
             || result.retainResult() == RetainReply.Result.BACK_PRESSURE_REJECTED) {
-            return farewell(MQTT5MessageBuilders.disconnect()
-                .reasonCode(MQTT5DisconnectReasonCode.ServerBusy)
-                .reasonString("Too many QoS2 publish")
-                .build(),
-                getLocal(ServerBusy.class)
-                    .reason("Too many QoS2 publish")
QoS2 publish") - .clientInfo(clientInfo)); + return responseNothing(getLocal(ServerBusy.class) + .reason("Too many QoS2 publish") + .clientInfo(clientInfo)); } int packetId = message.variableHeader().packetId(); Event[] debugEvents; diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/MessageIdUtil.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/MessageIdUtil.java index 946f030e2..95ca0fa85 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/MessageIdUtil.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/MessageIdUtil.java @@ -15,7 +15,8 @@ public class MessageIdUtil { - private static final long MESSAGE_SEQ_MASK = 0xFFFFFFFFL; + private static final long DRAIN_FLAG_MASK = 0x80000000L; + private static final long MESSAGE_SEQ_MASK = 0x7FFFFFFFL; private static final int INTEGER_BITS = 32; public static long syncWindowSequence(long nowMillis, long syncWindowIntervalMillis) { @@ -26,6 +27,14 @@ public static long messageId(long syncWindowSequence, long messageSequence) { return (syncWindowSequence << INTEGER_BITS) | messageSequence; } + public static long messageId(long syncWindowSequence, long messageSequence, boolean drainFlag) { + return (syncWindowSequence << INTEGER_BITS) | messageSequence | (drainFlag ? DRAIN_FLAG_MASK : 0); + } + + public static boolean isDrainFlagSet(long messageId) { + return (messageId & DRAIN_FLAG_MASK) != 0; + } + public static long syncWindowSequence(long messageId) { return messageId >> INTEGER_BITS; } @@ -45,7 +54,7 @@ public static boolean isSuccessive(long messageId, long successorMessageId) { } if (syncWindowSequence == successorSyncWindowSequence || syncWindowSequence + 1 == successorSyncWindowSequence) { - return messageSequence(messageId) + 1 == messageSequence(successorMessageId); + return (messageSequence(messageId) + 1) % DRAIN_FLAG_MASK == messageSequence(successorMessageId); } return messageSequence(successorMessageId) == 0; } diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java index 4b54e61fd..15406ab2a 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java @@ -369,6 +369,22 @@ public void handleQoS0PubDistError() { verifyEvent(QOS0_DIST_ERROR); } + @Test + public void handleQoS0PubDistBackPressured() { + mockCheckPermission(true); + mockDistBackPressure(); + assertTrue(channel.isOpen()); + + MqttPublishMessage message = MQTTMessageUtils.publishQoS0Message(topic, 123); + channel.writeInbound(message); + channel.advanceTimeBy(6, TimeUnit.SECONDS); + channel.runScheduledPendingTasks(); + // the channel is still open, but the message is dropped + assertTrue(channel.isOpen()); + verifyEvent(QOS0_DIST_ERROR, SERVER_BUSY); + } + + @Test public void handleQoS1Pub() { mockCheckPermission(true); @@ -397,7 +413,7 @@ public void handleQoS1PubDistError() { } @Test - public void handleQoS1PubDistRejected() { + public void handleQoS1PubDistBackPressured() { mockCheckPermission(true); mockDistBackPressure(); assertTrue(channel.isOpen()); @@ -406,7 +422,8 @@ public void handleQoS1PubDistRejected() { 
         channel.writeInbound(message);
         channel.advanceTimeBy(6, TimeUnit.SECONDS);
         channel.runScheduledPendingTasks();
-        assertFalse(channel.isOpen());
+        // the channel is still open, but the message is dropped
+        assertTrue(channel.isOpen());
         verifyEvent(QOS1_DIST_ERROR, SERVER_BUSY);
     }
 
@@ -452,7 +469,7 @@ public void handleQoS2PubDistError() {
     }
 
     @Test
-    public void handleQoS2PubDistRejected() {
+    public void handleQoS2PubDistBackPressured() {
         mockCheckPermission(true);
         mockDistBackPressure();
         assertTrue(channel.isOpen());
@@ -461,7 +478,8 @@ public void handleQoS2PubDistRejected() {
         channel.writeInbound(message);
         channel.advanceTimeBy(6, TimeUnit.SECONDS);
         channel.runScheduledPendingTasks();
-        assertFalse(channel.isOpen());
+        // the channel is still open, but the message is dropped
+        assertTrue(channel.isOpen());
         verifyEvent(QOS2_DIST_ERROR, SERVER_BUSY);
     }
 
diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java
index b29ab84dd..3ae209c33 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java
@@ -63,9 +63,9 @@
 import static org.testng.Assert.assertTrue;
 
 import com.baidu.bifromq.dist.client.DistResult;
+import com.baidu.bifromq.mqtt.handler.BaseSessionHandlerTest;
 import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
 import com.baidu.bifromq.mqtt.handler.TenantSettings;
-import com.baidu.bifromq.mqtt.handler.BaseSessionHandlerTest;
 import com.baidu.bifromq.mqtt.handler.v3.MQTT3TransientSessionHandler;
 import com.baidu.bifromq.mqtt.handler.v5.reason.MQTT5SubAckReasonCode;
 import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
@@ -133,7 +133,7 @@ public void setup(Method method) {
         // common mocks
         mockSettings();
         MqttProperties mqttProperties = new MqttProperties();
-        mqttProperties.add(new MqttProperties.IntegerProperty(TOPIC_ALIAS_MAXIMUM.value(), 10));;
+        mqttProperties.add(new MqttProperties.IntegerProperty(TOPIC_ALIAS_MAXIMUM.value(), 10));
         ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() {
             @Override
             public void channelActive(ChannelHandlerContext ctx) throws Exception {
@@ -331,6 +331,22 @@ public void handleQoS0PubDistError() {
         verifyEvent(QOS0_DIST_ERROR);
     }
 
+    @Test
+    public void handleQoS0PubDistBackPressured() {
+        mockCheckPermission(true);
+        mockDistBackPressure();
+        assertTrue(channel.isOpen());
+
+        MqttPublishMessage message = MQTTMessageUtils.publishQoS0Message(topic, 123);
+        channel.writeInbound(message);
+        channel.advanceTimeBy(6, TimeUnit.SECONDS);
+        channel.runScheduledPendingTasks();
+        channel.runPendingTasks();
+        assertTrue(channel.isOpen());
+        verifyEvent(QOS0_DIST_ERROR, SERVER_BUSY);
+    }
+
+
     @Test
     public void handleQoS1Pub() {
         mockCheckPermission(true);
@@ -359,7 +375,7 @@ public void handleQoS1PubDistError() {
     }
 
     @Test
-    public void handleQoS1PubDistRejected() {
+    public void handleQoS1PubDistBackPressured() {
         mockCheckPermission(true);
         mockDistBackPressure();
         assertTrue(channel.isOpen());
@@ -369,7 +385,8 @@ public void handleQoS1PubDistRejected() {
         channel.advanceTimeBy(6, TimeUnit.SECONDS);
         channel.runScheduledPendingTasks();
         channel.runPendingTasks();
-        assertFalse(channel.isOpen());
+        // the channel is still open, but message is dropped
+        assertTrue(channel.isOpen());
         verifyEvent(QOS1_DIST_ERROR, SERVER_BUSY);
     }
 
@@ -415,7 +432,7 @@ public void handleQoS2PubDistError() {
     }
 
     @Test
-    public void handleQoS2PubDistRejected() {
+    public void handleQoS2PubDistBackPressured() {
         mockCheckPermission(true);
         mockDistBackPressure();
         assertTrue(channel.isOpen());
@@ -425,7 +442,7 @@ public void handleQoS2PubDistRejected() {
         channel.advanceTimeBy(6, TimeUnit.SECONDS);
         channel.runScheduledPendingTasks();
         channel.runPendingTasks();
-        assertFalse(channel.isOpen());
+        assertTrue(channel.isOpen());
         verifyEvent(QOS2_DIST_ERROR, SERVER_BUSY);
     }
 
@@ -591,7 +608,8 @@ public void qos1PubAndAck() {
         channel.runPendingTasks();
 
         int messageCount = 3;
-        transientSessionHandler.publish(matchInfo(topicFilter), s2cMQTT5MessageList(topic, messageCount, QoS.AT_LEAST_ONCE));
+        transientSessionHandler.publish(matchInfo(topicFilter),
+            s2cMQTT5MessageList(topic, messageCount, QoS.AT_LEAST_ONCE));
         channel.runPendingTasks();
         // s2c pub received and ack
         for (int i = 0; i < messageCount; i++) {