diff --git a/src/neo/Network/P2P/TaskManager.cs b/src/neo/Network/P2P/TaskManager.cs index 0fab0e446e..c1590657a4 100644 --- a/src/neo/Network/P2P/TaskManager.cs +++ b/src/neo/Network/P2P/TaskManager.cs @@ -35,7 +35,7 @@ private class Timer { } /// private readonly HashSetCache knownHashes; private readonly Dictionary globalInvTasks = new Dictionary(); - private readonly Dictionary globalIndexTasks = new Dictionary(); + private readonly Dictionary> globalIndexTasks = new Dictionary>(); private readonly Dictionary sessions = new Dictionary(); private readonly ICancelable timer = Context.System.Scheduler.ScheduleTellRepeatedlyCancelable(TimerInterval, TimerInterval, Context.Self, new Timer(), ActorRefs.NoSender); @@ -225,14 +225,14 @@ private void DecrementGlobalTask(UInt256 hash) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void DecrementGlobalTask(uint index) + private void DecrementGlobalTask(uint index, IActorRef actor) { if (globalIndexTasks.TryGetValue(index, out var value)) { - if (value == 1) - globalIndexTasks.Remove(index); - else - globalIndexTasks[index] = value - 1; + if (value.Remove(actor)) + { + if (value.Count == 0) globalIndexTasks.Remove(index); + } } } @@ -252,17 +252,19 @@ private bool IncrementGlobalTask(UInt256 hash) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private bool IncrementGlobalTask(uint index) + private bool IncrementGlobalTask(uint index, IActorRef actor, uint count) { if (!globalIndexTasks.TryGetValue(index, out var value)) { - globalIndexTasks[index] = 1; + value = new Dictionary(); + value[actor] = count; + globalIndexTasks[index] = value; return true; } - if (value >= MaxConncurrentTasks) + if (value.Count >= MaxConncurrentTasks) return false; - globalIndexTasks[index] = value + 1; + globalIndexTasks[index][actor] = count; return true; } @@ -273,13 +275,13 @@ private void OnTerminated(IActorRef actor) foreach (UInt256 hash in session.InvTasks.Keys) DecrementGlobalTask(hash); foreach (uint index in session.IndexTasks.Keys) - DecrementGlobalTask(index); + DecrementGlobalTask(index, actor); sessions.Remove(actor); } private void OnTimer() { - foreach (TaskSession session in sessions.Values) + foreach (var (actor, session) in sessions) { foreach (var (hash, time) in session.InvTasks.ToArray()) if (TimeProvider.Current.UtcNow - time > TaskTimeout) @@ -291,7 +293,7 @@ private void OnTimer() if (TimeProvider.Current.UtcNow - time > TaskTimeout) { if (session.IndexTasks.Remove(index)) - DecrementGlobalTask(index); + DecrementGlobalTask(index, actor); } } foreach (var (actor, session) in sessions) @@ -311,8 +313,6 @@ public static Props Props(NeoSystem system) private void RequestTasks(IActorRef remoteNode, TaskSession session) { - if (session.HasTask) return; - DataCache snapshot = Blockchain.Singleton.View; // If there are pending tasks of InventoryType.Block we should process them @@ -351,7 +351,14 @@ private void RequestTasks(IActorRef remoteNode, TaskSession session) else if (currentHeight < session.LastBlockIndex) { uint startHeight = currentHeight; - while (globalIndexTasks.ContainsKey(++startHeight)) { } + foreach (var (index, pair) in globalIndexTasks) + { + foreach (uint _count in pair.Values) + { + uint maxRequiredHeight = index + _count; + startHeight = startHeight < maxRequiredHeight ? maxRequiredHeight : startHeight; + } + } if (startHeight > session.LastBlockIndex) return; uint endHeight = startHeight; while (!globalIndexTasks.ContainsKey(++endHeight) && endHeight <= session.LastBlockIndex) { } @@ -359,7 +366,7 @@ private void RequestTasks(IActorRef remoteNode, TaskSession session) for (uint i = 0; i < count; i++) { session.IndexTasks[startHeight + i] = TimeProvider.Current.UtcNow; - IncrementGlobalTask(startHeight + i); + IncrementGlobalTask(startHeight + i, remoteNode, count); } remoteNode.Tell(Message.Create(MessageCommand.GetBlockByIndex, GetBlockByIndexPayload.Create(startHeight, (short)count))); }