From f3fc74141048dd03867fc943d3736930f68e083d Mon Sep 17 00:00:00 2001 From: Nuclearist Date: Fri, 2 Feb 2024 21:14:46 +0300 Subject: [PATCH] Implemented thread-based downloading --- src/CDNClient.cs | 682 ++++++++++++------------------------- src/Manifest/DepotDelta.cs | 2 +- 2 files changed, 227 insertions(+), 457 deletions(-) diff --git a/src/CDNClient.cs b/src/CDNClient.cs index bfa2696..777ddd2 100644 --- a/src/CDNClient.cs +++ b/src/CDNClient.cs @@ -27,24 +27,17 @@ public class CDNClient /// CM client used to get CDN server list and manifest request codes. public required CM.CMClient CmClient { get; init; } /// - /// Number of servers that clients will simultaneously use when downloading depot content. The default value is . - /// The product of and is the number of simultaneous download tasks, scale it - /// accordingly to your network bandwidth and CPU capabilities. + /// Number of servers that clients will simultaneously use when downloading depot content. The default value is Math.Max(Environment.ProcessorCount * 6, 50). + /// It is also the number of download threads, scale it accordingly to your network bandwidth and CPU capabilities. /// - public static int NumDownloadServers { get; set; } = Environment.ProcessorCount; - /// - /// Number of simultaneous download tasks created per server. The default value is 4. - /// The product of and is the number of simultaneous download tasks, scale it - /// accordingly to your network bandwidth and CPU capabilities. - /// - public static int NumRequestsPerServer { get; set; } = 4; + public static int NumDownloadServers { get; set; } = Math.Max(Environment.ProcessorCount * 6, 50); /// Gets CDN server list if necessary. private void CheckServerList() { - if (_servers.Length >= NumDownloadServers) + if (_servers.Length > NumDownloadServers) return; - var servers = new List(NumDownloadServers); - while (servers.Count < NumDownloadServers) + var servers = new List(NumDownloadServers + 1); + while (servers.Count < NumDownloadServers + 1) servers.AddRange(Array.FindAll(CmClient.GetCDNServers(), s => s.Type is "SteamCache" or "CDN" && s.HttpsSupport is "mandatory" or "optional")); servers.Sort((left, right) => { @@ -61,290 +54,130 @@ private void CheckServerList() for (int i = 0; i < servers.Count; i++) _servers[i] = new(string.Concat("https://", servers[i].Host)); } - /// Downloads, decrypts, decompresses and writes chunk specified in the context. - /// An object. - private static async Task AcquireChunk(object? arg) + private static void DownloadThreadProcedure(object? arg) { - var context = (AcquisitionTaskContext)arg!; - byte[] buffer = context.Buffer; - int compressedSize = context.CompressedSize; - int uncompressedSize = context.UncompressedSize; - var cancellationToken = context.CancellationToken; - var aes = context.Aes; - Exception? exception = null; - var progress = context.Progress; + var context = (DownloadThreadContext)arg!; + byte[] buffer = GC.AllocateUninitializedArray(0x400000); + using var aes = Aes.Create(); + aes.Key = DepotDecryptionKeys[context.SharedContext.DepotId]; + var lzmaDecoder = new Utils.LZMA.Decoder(); + var httpClient = context.SharedContext.HttpClients[context.Index]; + var requestUri = new Uri($"depot/{context.SharedContext.DepotId}/chunk/0000000000000000000000000000000000000000", UriKind.Relative); + ref byte uriGid = ref Unsafe.As(ref MemoryMarshal.GetReference(requestUri.ToString().AsSpan()[^40..])); + var token = context.Cts.Token; var downloadBuffer = new Memory(buffer, 0, 0x200000); - for (int i = 0; i < 5; i++) //5 attempts, after which task fails + var bufferSpan = new Span(buffer); + var ivSpan = (ReadOnlySpan)bufferSpan[..16]; + var decryptedIvSpan = bufferSpan.Slice(0x3FFFF0, 16); + var decryptedDataSpan = new Span(buffer, 0x200000, 0x1FFFF0); + bool resumedContext = context.ChunkContext.FilePath is not null; + int fallbackServerIndex = NumDownloadServers; + for (;;) { - cancellationToken.ThrowIfCancellationRequested(); - try - { - //Download encrypted chunk data - var request = new HttpRequestMessage(HttpMethod.Get, context.RequestUri) { Version = HttpVersion.Version20 }; - using var response = await context.HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); - using var content = response.EnsureSuccessStatusCode().Content; - using var stream = await content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - int bytesRead; - using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cts.CancelAfter(60000); - try { bytesRead = await stream.ReadAtLeastAsync(downloadBuffer, compressedSize, false, cts.Token).ConfigureAwait(false); } - catch (OperationCanceledException oce) - { - if (oce.CancellationToken == cancellationToken) - throw; - throw new TimeoutException(); - } - catch (AggregateException ae) when (ae.InnerException is OperationCanceledException oce) - { - if (oce.CancellationToken == cancellationToken) - throw; - throw new TimeoutException(); - } - if (bytesRead != compressedSize) - { - exception = new InvalidDataException($"Downloaded chunk data size doesn't match expected [URL: {context.HttpClient.BaseAddress}/{request.RequestUri}]"); - continue; - } - //Decrypt the data - aes.DecryptEcb(new ReadOnlySpan(buffer, 0, 16), new Span(buffer, 0x3FFFF0, 16), PaddingMode.None); - int decryptedDataSize = aes.DecryptCbc(new ReadOnlySpan(buffer, 16, compressedSize - 16), new ReadOnlySpan(buffer, 0x3FFFF0, 16), new Span(buffer, 0x200000, 0x1FFFF0)); - //Decompress the data - if (!context.LzmaDecoder.Decode(new ReadOnlySpan(buffer, 0x200000, decryptedDataSize), new Span(buffer, 0, uncompressedSize))) - { - exception = new InvalidDataException("LZMA decoding failed"); - continue; - } - if (Adler.ComputeChecksum(new ReadOnlySpan(buffer, 0, uncompressedSize)) != context.Checksum) - { - exception = new InvalidDataException("Adler checksum mismatch"); - continue; - } - exception = null; - } - catch (OperationCanceledException) { throw; } - catch (AggregateException ae) when (ae.InnerException is OperationCanceledException) { throw ae.InnerException; } - catch (Exception e) - { - exception = e; - continue; - } - } - if (exception is not null) - throw exception; - //Write acquired chunk data to the file - var handle = context.FileHandle; - await RandomAccess.WriteAsync(handle.Handle, new ReadOnlyMemory(buffer, 0, uncompressedSize), context.FileOffset, cancellationToken).ConfigureAwait(false); - handle.Release(); - progress.SubmitChunk(compressedSize); - } - /// Downloads depot content chunks. - /// State of the item. - /// The target manifest. - /// Delta object that lists data to be downloaded. - /// Token to monitor for cancellation requests. - internal void DownloadContent(ItemState state, DepotManifest manifest, DepotDelta delta, CancellationToken cancellationToken) - { - using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - var contexts = new AcquisitionTaskContext[NumDownloadServers * NumRequestsPerServer]; - var tasks = new Task?[contexts.Length]; - int numResumedContexts = 0; - Exception? exception = null; - string? chunkBufferFilePath = null; - LimitedUseFileHandle? chunkBufferFileHandle = null; - string baseRequestUrl = $"depot/{state.Id.DepotId}/chunk/"; - void downloadDir(in DirectoryEntry.AcquisitionEntry dir, string path, int recursionLevel) - { - int index; - if (state.ProgressIndexStack.Count > recursionLevel) - index = state.ProgressIndexStack[recursionLevel]; + if (resumedContext) + resumedContext = false; else { - state.ProgressIndexStack.Add(0); - index = 0; + context.ChunkContext = context.SharedContext.GetNextChunk(context.ChunkContext.CompressedSize); + if (context.ChunkContext.FilePath is null) + return; } - for (; index < dir.Files.Count; index++) + Unsafe.CopyBlockUnaligned(ref uriGid, ref Unsafe.As(ref MemoryMarshal.GetReference(context.ChunkContext.Gid.ToString().AsSpan())), 80); + int compressedSize = context.ChunkContext.CompressedSize; + int uncompressedSize = context.ChunkContext.UncompressedSize; + Exception? exception = null; + var dataSpan = (ReadOnlySpan)bufferSpan[16..compressedSize]; + var uncompressedDataSpan = bufferSpan[..uncompressedSize]; + for (int i = 0; i < 5; i++) //5 attempts, after which the thread fails { - var acquisitonFile = dir.Files[index]; - var file = manifest.FileBuffer[acquisitonFile.Index]; - if (file.Size is 0) - continue; - if (linkedCts.IsCancellationRequested) - { - state.ProgressIndexStack[recursionLevel] = index; + if (token.IsCancellationRequested) return; - } - int chunkRecLevel = recursionLevel + 1; - int chunkIndex; - if (state.ProgressIndexStack.Count > chunkRecLevel) - chunkIndex = state.ProgressIndexStack[chunkRecLevel]; - else - { - state.ProgressIndexStack.Add(0); - chunkIndex = 0; - } - if (acquisitonFile.Chunks.Count is 0) + try { - string filePath = Path.Join(path, file.Name); - var chunks = file.Chunks; - LimitedUseFileHandle? handle; - if (numResumedContexts > 0) + //Download encrypted chunk data + var request = new HttpRequestMessage(HttpMethod.Get, requestUri) { Version = HttpVersion.Version20 }; + using var response = httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, token).GetAwaiter().GetResult(); + using var content = response.EnsureSuccessStatusCode().Content; + using var stream = content.ReadAsStream(token); + int bytesRead; + using var cts = CancellationTokenSource.CreateLinkedTokenSource(token); + cts.CancelAfter(60000); + try { - handle = null; - for (int i = 0; i < numResumedContexts; i++) - if (contexts[i].FilePath == filePath) - { - handle = contexts[i].FileHandle; - break; - } - numResumedContexts = 0; - handle ??= new(File.OpenHandle(filePath, FileMode.OpenOrCreate, FileAccess.Write, options: FileOptions.RandomAccess | FileOptions.Asynchronous), chunks.Count); + var task = stream.ReadAtLeastAsync(downloadBuffer, compressedSize, false, cts.Token); + if (task.IsCompletedSuccessfully) + bytesRead = task.GetAwaiter().GetResult(); + else + bytesRead = task.AsTask().GetAwaiter().GetResult(); } - else - handle = new(File.OpenHandle(filePath, FileMode.OpenOrCreate, FileAccess.Write, options: FileOptions.RandomAccess | FileOptions.Asynchronous), chunks.Count); - for (; chunkIndex < chunks.Count; chunkIndex++) + catch (OperationCanceledException oce) { - if (linkedCts.IsCancellationRequested) - { - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; + if (oce.CancellationToken == token) return; - } - var chunk = chunks[chunkIndex]; - int contextIndex = -1; - for (int i = 0; i < contexts.Length; i++) - { - var task = tasks[i]; - if (task is null) - { - contextIndex = i; - break; - } - if (task.IsCompleted) - { - if (task.IsFaulted) - { - exception = task.Exception; - linkedCts.Cancel(); - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - contextIndex = i; - break; - } - } - if (contextIndex < 0) - { - try { contextIndex = Task.WaitAny(tasks!, linkedCts.Token); } - catch (OperationCanceledException) - { - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - var task = tasks[contextIndex]!; - if (task.IsFaulted) - { - exception = task.Exception; - linkedCts.Cancel(); - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - } - var context = contexts[contextIndex]; - context.CompressedSize = chunk.CompressedSize; - context.UncompressedSize = chunk.UncompressedSize; - context.Checksum = chunk.Checksum; - context.FileOffset = chunk.Offset; - context.FilePath = filePath; - context.RequestUri = new(string.Concat(baseRequestUrl, chunk.Gid.ToString()), UriKind.Relative); - context.FileHandle = handle; - tasks[contextIndex] = Task.Factory.StartNew(AcquireChunk, context, TaskCreationOptions.DenyChildAttach).Result; + exception = new TimeoutException(); + continue; } + if (bytesRead != compressedSize) + { + exception = new InvalidDataException($"Downloaded chunk data size doesn't match expected [URL: {httpClient.BaseAddress}/{request.RequestUri}]"); + continue; + } + //Decrypt the data + aes.DecryptEcb(ivSpan, decryptedIvSpan, PaddingMode.None); + int decryptedDataSize = aes.DecryptCbc(dataSpan, decryptedIvSpan, decryptedDataSpan); + //Decompress the data + if (!lzmaDecoder.Decode(decryptedDataSpan[..decryptedDataSize], uncompressedDataSpan)) + { + exception = new InvalidDataException("LZMA decoding failed"); + continue; + } + if (Adler.ComputeChecksum(uncompressedDataSpan) != context.ChunkContext.Checksum) + { + exception = new InvalidDataException("Adler checksum mismatch"); + continue; + } + exception = null; } - else + catch (OperationCanceledException) { return; } + catch (HttpRequestException hre) when (hre.StatusCode > HttpStatusCode.InternalServerError) { - var acquisitionChunks = acquisitonFile.Chunks; - for (; chunkIndex < acquisitionChunks.Count; chunkIndex++) + if (fallbackServerIndex is 0) { - if (linkedCts.IsCancellationRequested) - { - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - var acquisitionChunk = acquisitionChunks[chunkIndex]; - var chunk = manifest.ChunkBuffer[acquisitionChunk.Index]; - int contextIndex = -1; - for (int i = 0; i < contexts.Length; i++) - { - var task = tasks[i]; - if (task is null) - { - contextIndex = i; - break; - } - if (task.IsCompleted) - { - if (task.IsFaulted) - { - exception = task.Exception; - linkedCts.Cancel(); - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - contextIndex = i; - break; - } - } - if (contextIndex < 0) - { - try - { contextIndex = Task.WaitAny(tasks!, linkedCts.Token); } - catch (OperationCanceledException) - { - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - var task = tasks[contextIndex]!; - if (task.IsFaulted) - { - exception = task.Exception; - linkedCts.Cancel(); - state.ProgressIndexStack[chunkRecLevel] = chunkIndex; - state.ProgressIndexStack[recursionLevel] = index; - return; - } - } - var context = contexts[contextIndex]; - context.CompressedSize = chunk.CompressedSize; - context.UncompressedSize = chunk.UncompressedSize; - context.Checksum = chunk.Checksum; - context.FileOffset = acquisitionChunk.Offset; - context.FilePath = chunkBufferFilePath!; - context.RequestUri = new(string.Concat(baseRequestUrl, chunk.Gid.ToString()), UriKind.Relative); - context.FileHandle = chunkBufferFileHandle!; - tasks[contextIndex] = Task.Factory.StartNew(AcquireChunk, context, TaskCreationOptions.DenyChildAttach).Result; + httpClient = context.SharedContext.HttpClients[context.Index]; + fallbackServerIndex = NumDownloadServers; + } + else + { + httpClient = context.SharedContext.HttpClients[fallbackServerIndex]; + if (++fallbackServerIndex == context.SharedContext.HttpClients.Length) + fallbackServerIndex = 0; } } - state.ProgressIndexStack.RemoveAt(chunkRecLevel); - } - index -= dir.Files.Count; - for (; index < dir.Subdirectories.Count; index++) - { - var subdir = dir.Subdirectories[index]; - downloadDir(in subdir, Path.Join(path, manifest.DirectoryBuffer[subdir.Index].Name), recursionLevel + 1); - if (linkedCts.IsCancellationRequested) + catch (Exception e) { - state.ProgressIndexStack[recursionLevel] = dir.Files.Count + index; - return; + exception = e; + continue; } } - state.ProgressIndexStack.RemoveAt(recursionLevel); + if (exception is not null) + { + context.SharedContext.Exception = exception; + context.Cts.Cancel(); + return; + } + //Write acquired chunk data to the file + var handle = context.ChunkContext.FileHandle; + RandomAccess.Write(handle.Handle, uncompressedDataSpan, context.ChunkContext.FileOffset); + handle.Release(); } + } + /// Downloads depot content chunks. + /// State of the item. + /// The target manifest. + /// Delta object that lists data to be downloaded. + /// Token to monitor for cancellation requests. + internal void DownloadContent(ItemState state, DepotManifest manifest, DepotDelta delta, CancellationToken cancellationToken) + { if (state.Status is not ItemState.ItemStatus.Downloading) { state.Status = ItemState.ItemStatus.Downloading; @@ -352,102 +185,72 @@ void downloadDir(in DirectoryEntry.AcquisitionEntry dir, string path, int recurs state.DisplayProgress = 0; } StatusUpdated?.Invoke(Status.Downloading); - if (!DepotDecryptionKeys.TryGetValue(state.Id.DepotId, out var decryptionKey)) + if (!DepotDecryptionKeys.ContainsKey(state.Id.DepotId)) throw new SteamException(SteamException.ErrorType.DepotDecryptionKeyMissing); CheckServerList(); - var threadSafeProgress = new ThreadSafeProgress(ProgressUpdated, state); - var httpClients = new HttpClient[NumDownloadServers]; - for (int i = 0; i < httpClients.Length; i++) - httpClients[i] = new() + var sharedContext = new DownloadContext(state, manifest, delta, DownloadsDirectory!, ProgressUpdated) { HttpClients = new HttpClient[_servers.Length] }; + for (int i = 0; i < _servers.Length; i++) + sharedContext.HttpClients[i] = new() { BaseAddress = _servers[i], DefaultRequestVersion = HttpVersion.Version20, Timeout = TimeSpan.FromSeconds(10) }; + var contexts = new DownloadThreadContext[NumDownloadServers]; + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); for (int i = 0; i < contexts.Length; i++) - { - contexts[i] = new(threadSafeProgress, httpClients[i % httpClients.Length], linkedCts.Token); - contexts[i].Aes.Key = decryptionKey; - } - string dwContextsFilePath = Path.Join(DownloadsDirectory!, $"{state.Id}.scdwcontexts"); - ProgressInitiated?.Invoke(ProgressType.Binary, delta.DownloadSize, state.DisplayProgress); - if (delta.ChunkBufferFileSize > 0) - { - chunkBufferFilePath = Path.Join(DownloadsDirectory!, $"{state.Id}.scchunkbuffer"); - chunkBufferFileHandle = new(File.OpenHandle(chunkBufferFilePath, FileMode.OpenOrCreate, FileAccess.Write, options: FileOptions.RandomAccess | FileOptions.Asynchronous), int.MaxValue); - } - if (File.Exists(dwContextsFilePath)) + contexts[i] = new() + { + Index = i, + Cts = linkedCts, + SharedContext = sharedContext + }; + string chunkContextsFilePath = Path.Join(DownloadsDirectory!, $"{state.Id}.scchcontexts"); + if (File.Exists(chunkContextsFilePath)) { Span buffer; - using (var fileHandle = File.OpenHandle(dwContextsFilePath)) + using (var fileHandle = File.OpenHandle(chunkContextsFilePath)) { buffer = GC.AllocateUninitializedArray((int)RandomAccess.GetLength(fileHandle)); RandomAccess.Read(fileHandle, buffer, 0); } - numResumedContexts = Unsafe.As(ref MemoryMarshal.GetReference(buffer)); - nint offset = 8; - int stringOffset = 8 + numResumedContexts * 32; - for (int i = 0; i < numResumedContexts; i++) - { - int numChunks = contexts[i].LoadFromBuffer(buffer, ref offset, ref stringOffset); - bool fileCreated = false; - if (chunkBufferFilePath is not null && contexts[i].FilePath == chunkBufferFilePath) - { - contexts[i].FileHandle = chunkBufferFileHandle!; - fileCreated = true; - } - else - for (int j = 0; j < i; j++) - if (contexts[j].FilePath == contexts[i].FilePath) - { - contexts[i].FileHandle = contexts[j].FileHandle; - fileCreated = true; - break; - } - if (!fileCreated) - contexts[i].FileHandle = new(File.OpenHandle(contexts[i].FilePath, FileMode.OpenOrCreate, FileAccess.Write, options: FileOptions.RandomAccess | FileOptions.Asynchronous), numChunks); - } - for (int i = 0; i < numResumedContexts; i++) - tasks[i] = Task.Factory.StartNew(AcquireChunk, contexts[i], TaskCreationOptions.DenyChildAttach); - } - downloadDir(in delta.AcquisitionTree, Path.Join(DownloadsDirectory!, state.Id.ToString()), 0); - foreach (var task in tasks) - { - if (task is null) - continue; - if (!task.IsCompleted) - Task.WaitAny([ task ], CancellationToken.None); - if (task.IsFaulted) - exception = task.Exception; + nint offset = 0; + int pathOffset = contexts.Length * 48; + for (int i = 0; i < contexts.Length; i++) + contexts[i].ChunkContext = ChunkContext.LoadFromBuffer(buffer, ref offset, ref pathOffset); } - chunkBufferFileHandle?.Handle?.Dispose(); - foreach (var context in contexts) - context.Dispose(); + ProgressInitiated?.Invoke(ProgressType.Binary, delta.DownloadSize, state.DisplayProgress); + var threads = new Thread[contexts.Length]; + for (int i = 0; i < threads.Length; i++) + threads[i] = new(DownloadThreadProcedure); + for (int i = 0; i < threads.Length; i++) + threads[i].UnsafeStart(contexts[i]); + foreach (var thread in threads) + thread.Join(); + sharedContext.CurrentFileHandle?.Handle.Close(); + sharedContext.ChunkBufferFileHandle?.Handle.Close(); + foreach (var client in sharedContext.HttpClients) + client.Dispose(); if (linkedCts.IsCancellationRequested) { state.SaveToFile(); - int numContextsToSave = 0; int contextsFileSize = 0; - for (int i = 0; i < contexts.Length; i++) - if (!(tasks[i]?.IsCompletedSuccessfully ?? true)) - { - var context = contexts[i]; - numContextsToSave++; - contextsFileSize += 32 + Encoding.UTF8.GetByteCount(context.FilePath) + Encoding.UTF8.GetByteCount(context.FilePath); - } - if (numContextsToSave > 0) + foreach (var context in contexts) + contextsFileSize += 48 + Encoding.UTF8.GetByteCount(context.ChunkContext.FilePath); + Span buffer = new byte[contextsFileSize]; + nint offset = 0; + int pathOffset = contexts.Length * 48; + foreach (var context in contexts) + context.ChunkContext.WriteToBuffer(buffer, ref offset, ref pathOffset); + using var fileHandle = File.OpenHandle(chunkContextsFilePath, FileMode.Create, FileAccess.Write, preallocationSize: buffer.Length); + RandomAccess.Write(fileHandle, buffer, 0); + var exception = sharedContext.Exception; + throw exception switch { - Span buffer = new byte[contextsFileSize + 8]; - Unsafe.As(ref MemoryMarshal.GetReference(buffer)) = numContextsToSave; - nint offset = 8; - int stringOffset = 8 + numContextsToSave * 32; - for (int i = 0; i < contexts.Length; i++) - if (!(tasks[i]?.IsCompletedSuccessfully ?? true)) - contexts[i].WriteToBuffer(buffer, ref offset, ref stringOffset); - using var fileHandle = File.OpenHandle(dwContextsFilePath, FileMode.Create, FileAccess.Write, preallocationSize: buffer.Length); - RandomAccess.Write(fileHandle, buffer, 0); - } - throw exception is null ? new OperationCanceledException(linkedCts.Token) : exception is SteamException ? exception : new SteamException(SteamException.ErrorType.DownloadFailed, exception); + null => new OperationCanceledException(linkedCts.Token), + SteamException => exception, + _ => new SteamException(SteamException.ErrorType.DownloadFailed, exception) + }; } } /// Preallocates all files for the download on the disk. @@ -661,106 +464,71 @@ public DepotPatch GetPatch(uint appId, ItemIdentifier item, DepotManifest source public event ProgressUpdatedHandler? ProgressUpdated; /// Called when client status is updated. public event StatusUpdatedHandler? StatusUpdated; - /// Persistent context for chunk acquisitions tasks. - /// Progress wrapper. - /// HTTP client with base address set to server to download from. - /// Token to monitor for cancellation requests. - private class AcquisitionTaskContext(ThreadSafeProgress progress, HttpClient httpClient, CancellationToken cancellationToken) : IDisposable + /// Context containing all the data needed to download a chunk. + private readonly struct ChunkContext { - /// Buffer for storing downloaded data and intermediate decrypted and decompressed data. - public byte[] Buffer { get; } = GC.AllocateUninitializedArray(0x400000); /// Size of LZMA-compressed chunk data. - public int CompressedSize { get; internal set; } - /// Size of uncompressed chunk data. If -1, chunk won't be decompressed. - public int UncompressedSize { get; internal set; } + public required int CompressedSize { get; init; } + /// Size of uncompressed chunk data. + public required int UncompressedSize { get; init; } /// Adler checksum of chunk data. - public uint Checksum { get; internal set; } + public required uint Checksum { get; init; } /// Offset of chunk data from the beginning of containing file. - public long FileOffset { get; internal set; } - /// Path to the file to write chunk to. - public string FilePath { get; internal set; } = string.Empty; - /// AES decryptor. - public Aes Aes { get; } = Aes.Create(); - /// Token to monitor for cancellation requests. - public CancellationToken CancellationToken { get; } = cancellationToken; - /// LZMA decoder. - public Utils.LZMA.Decoder LzmaDecoder { get; } = new(); - /// HTTP client used to download chunk data. - public HttpClient HttpClient { get; } = httpClient; - /// Handle of the file to write chunk to. - public LimitedUseFileHandle FileHandle { get; internal set; } = null!; - /// Progress wrapper. - public ThreadSafeProgress Progress { get; } = progress; - /// Relative chunk URL. - public Uri RequestUri { get; internal set; } = null!; - public void Dispose() - { - Aes.Dispose(); - HttpClient.Dispose(); - FileHandle?.Handle?.Dispose(); - } + public required long FileOffset { get; init; } + /// Path to the file to download chunk to. + public required string FilePath { get; init; } + /// Handle for the file to download chunk to. + public required LimitedUseFileHandle FileHandle { get; init; } + /// GID of the chunk. + public required SHA1Hash Gid { get; init; } /// Writes context data to a buffer. /// Buffer to write data to. /// Offset into to write context data to. - /// Offset into to write UTF-8 encoded strings to. - public void WriteToBuffer(Span buffer, ref nint offset, ref int stringOffset) + /// Offset into to write file path to. + public void WriteToBuffer(Span buffer, ref nint offset, ref int pathOffset) { ref byte bufferRef = ref MemoryMarshal.GetReference(buffer); nint entryOffset = offset; - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset)) = CompressedSize; - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 4)) = UncompressedSize; - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 8)) = Checksum; - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 12)) = FileHandle.ChunksLeft; - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 16)) = FileOffset; - int stringLength = Encoding.UTF8.GetBytes(FilePath, buffer[stringOffset..]); - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 24)) = stringLength; - stringOffset += stringLength; - stringLength = Encoding.UTF8.GetBytes(RequestUri.ToString(), buffer[stringOffset..]); - Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 28)) = stringLength; - stringOffset += stringLength; - offset += 32; + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset)) = Gid; + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 24)) = CompressedSize; + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 28)) = UncompressedSize; + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 32)) = Checksum; + int pathLength = Encoding.UTF8.GetBytes(FilePath, buffer[pathOffset..]); + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 36)) = pathLength; + pathOffset += pathLength; + Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 40)) = FileOffset; + offset += 48; } /// Loads context data from a buffer. /// Buffer to read data from. /// Offset into to read context data from. - /// Offset into to read UTF-8 encoded strings from. - public int LoadFromBuffer(ReadOnlySpan buffer, ref nint offset, ref int stringOffset) + /// Offset into to read file path from. + /// Loaded chunk context. + public static ChunkContext LoadFromBuffer(ReadOnlySpan buffer, ref nint offset, ref int pathOffset) { ref byte bufferRef = ref MemoryMarshal.GetReference(buffer); nint entryOffset = offset; - CompressedSize = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset)); - UncompressedSize = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 4)); - Checksum = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 8)); - int chunksLeft = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 12)); - FileOffset = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 16)); - int stringLength = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 24)); - FilePath = Encoding.UTF8.GetString(buffer.Slice(stringOffset, stringLength)); - stringOffset += stringLength; - stringLength = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 28)); - RequestUri = new(Encoding.UTF8.GetString(buffer.Slice(stringOffset, stringLength)), UriKind.Relative); - stringOffset += stringLength; - offset += 32; - return chunksLeft; + var gid = new SHA1Hash(buffer.Slice(offset.ToInt32(), 20)); + int compressedSize = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 24)); + int uncompressedSize = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 28)); + uint checksum = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 32)); + int pathLength = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 36)); + string filePath = Encoding.UTF8.GetString(buffer.Slice(pathOffset, pathLength)); + pathOffset += pathLength; + long fileOffset = Unsafe.As(ref Unsafe.AddByteOffset(ref bufferRef, entryOffset + 40)); + offset += 48; + return new() + { + CompressedSize = compressedSize, + UncompressedSize = uncompressedSize, + Checksum = checksum, + FileOffset = fileOffset, + FilePath = filePath, + FileHandle = new(File.OpenHandle(filePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, FileOptions.RandomAccess | FileOptions.Asynchronous), 1), + Gid = gid + }; } } - /// Context containing all the data needed to download a chunk. - private readonly struct ChunkContext - { - /// Size of LZMA-compressed chunk data. - public required int CompressedSize { get; init; } - /// Size of uncompressed chunk data. - public required int UncompressedSize { get; init; } - /// Adler checksum of chunk data. - public required uint Checksum { get; init; } - /// Offset of chunk data from the beginning of containing file. - public required long FileOffset { get; init; } - /// Path to the file to download chunk to. - public required string FilePath { get; init; } - /// Handle for the file to download chunk to. - public required LimitedUseFileHandle FileHandle { get; init; } - /// GID of the chunk. - public required SHA1Hash Gid { get; init; } - } /// File handle wrapper that releases the handle after the last chunk has been written to the file. /// File handle. /// The number of chunks that will be written to the file. @@ -798,7 +566,7 @@ static int getDirTreeDepth(in DirectoryEntry.AcquisitionEntry dir) } int dirTreeDepth = getDirTreeDepth(in delta.AcquisitionTree); _pathTree = new string[dirTreeDepth + 1]; - _pathTree[0] = basePath; + _pathTree[0] = Path.Join(basePath, state.Id.ToString()); _currentDirTree = new DirectoryEntry.AcquisitionEntry[dirTreeDepth]; _currentDirTree[0] = delta.AcquisitionTree; var indexStack = state.ProgressIndexStack; @@ -841,20 +609,20 @@ bool findFirstChunk(in DirectoryEntry.AcquisitionEntry dir) if (delta.ChunkBufferFileSize > 0) { _chunkBufferFilePath = Path.Join(basePath, $"{state.Id}.scchunkbuffer"); - _chunkBufferFileHandle = new(File.OpenHandle(_chunkBufferFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, options: FileOptions.RandomAccess | FileOptions.Asynchronous), int.MaxValue); + ChunkBufferFileHandle = new(File.OpenHandle(_chunkBufferFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, options: FileOptions.RandomAccess | FileOptions.Asynchronous), int.MaxValue); } if (state.ProgressIndexStack[^1] > 0) { if (_currentFile.Chunks.Count is 0) { _currentFilePath = Path.Join(_pathTree); - _currentFileHandle = new(File.OpenHandle(_currentFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, FileOptions.RandomAccess | FileOptions.Asynchronous), + CurrentFileHandle = new(File.OpenHandle(_currentFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, FileOptions.RandomAccess | FileOptions.Asynchronous), (_currentFile.Chunks.Count is 0 ? manifest.FileBuffer[_currentFile.Index].Chunks.Count : _currentFile.Chunks.Count) - state.ProgressIndexStack[^1]); } else { _currentFilePath = _chunkBufferFilePath!; - _currentFileHandle = _chunkBufferFileHandle!; + CurrentFileHandle = ChunkBufferFileHandle!; } } _progressUpdatedHandler = handler; @@ -863,8 +631,6 @@ bool findFirstChunk(in DirectoryEntry.AcquisitionEntry dir) private string? _currentFilePath; /// Entry for the currently selected file. private FileEntry.AcquisitionEntry _currentFile; - /// Handle for the currently selected file. - private LimitedUseFileHandle? _currentFileHandle; /// Path to the chunk buffer file. private readonly string? _chunkBufferFilePath; /// Array of directory and file names to compose path from. @@ -875,10 +641,18 @@ bool findFirstChunk(in DirectoryEntry.AcquisitionEntry dir) private readonly DirectoryEntry.AcquisitionEntry[] _currentDirTree; /// Item state. private readonly ItemState _state; - /// Handle for the chunk buffer file. - private readonly LimitedUseFileHandle? _chunkBufferFileHandle; /// Called when progress value is updated. private readonly ProgressUpdatedHandler? _progressUpdatedHandler; + /// Gets item depot ID. + public uint DepotId => _state.Id.DepotId; + /// Exception thrown by one of the download threads. + public Exception? Exception { get; set; } + /// HTTP clients for all CDN servers. + public required HttpClient[] HttpClients { get; init; } + /// Handle for the chunk buffer file. + public LimitedUseFileHandle? ChunkBufferFileHandle { get; } + /// Handle for the currently selected file. + public LimitedUseFileHandle? CurrentFileHandle { get; private set; } /// Submits progress for the previous chunk, gets context for the next chunk or if the are no more chunks and moves index stack to the next chunk. public ChunkContext GetNextChunk(long previousChunkSize) { @@ -898,12 +672,12 @@ public ChunkContext GetNextChunk(long previousChunkSize) if (_currentFile.Chunks.Count is 0) { _currentFilePath = Path.Join(_pathTree); - _currentFileHandle = new(File.OpenHandle(_currentFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, FileOptions.RandomAccess | FileOptions.Asynchronous), _manifest.FileBuffer[_currentFile.Index].Chunks.Count); + CurrentFileHandle = new(File.OpenHandle(_currentFilePath, FileMode.OpenOrCreate, FileAccess.Write, FileShare.ReadWrite, FileOptions.RandomAccess | FileOptions.Asynchronous), _manifest.FileBuffer[_currentFile.Index].Chunks.Count); } else { _currentFilePath = _chunkBufferFilePath!; - _currentFileHandle = _chunkBufferFileHandle!; + CurrentFileHandle = ChunkBufferFileHandle!; } } ChunkEntry chunk; @@ -989,26 +763,22 @@ bool findNextChunk(in DirectoryEntry.AcquisitionEntry dir, int recursionLevel) Checksum = chunk.Checksum, FileOffset = chunkBufferOffset >= 0 ? chunkBufferOffset : chunk.Offset, FilePath = _currentFilePath!, - FileHandle = _currentFileHandle!, + FileHandle = CurrentFileHandle!, Gid = chunk.Gid }; } } } - /// Thread-safe wrapper for updating progress value. - /// Event handler called when progress is updated. - /// Depot state object that holds progress value. - private class ThreadSafeProgress(ProgressUpdatedHandler? handler, ItemState state) + /// Individual context for download threads. + private class DownloadThreadContext { - /// Updates progress value by adding chunk size to it. - /// Size of LZMA-compressed chunk data. - public void SubmitChunk(int chunkSize) - { - lock (this) - { - state.DisplayProgress += chunkSize; - handler?.Invoke(state.DisplayProgress); - } - } + /// Thread index, used to select download server. + public required int Index { get; init; } + /// Cancellation token source for all download threads. + public required CancellationTokenSource Cts { get; init; } + /// Current chunk context. + public ChunkContext ChunkContext { get; set; } + /// Shared download context. + public required DownloadContext SharedContext { get; init; } } } \ No newline at end of file diff --git a/src/Manifest/DepotDelta.cs b/src/Manifest/DepotDelta.cs index de88731..69a5642 100644 --- a/src/Manifest/DepotDelta.cs +++ b/src/Manifest/DepotDelta.cs @@ -47,7 +47,7 @@ void countAcq(in DirectoryEntry dir, DirectoryEntry.AcquisitionStaging acquisiti } } foreach (var subdir in acquisitionDir.Subdirectories) - countAcq(dir.Subdirectories[subdir.Index], subdir); + countAcq(in manifest.DirectoryBuffer[subdir.Index], subdir); } countAcq(in manifest.Root, acquisitionTree); ChunkBufferFileSize = chunkBufferFileSize;