From d553069d7455e98ad5e4909bffb57b0589e7e371 Mon Sep 17 00:00:00 2001 From: JT Date: Thu, 16 Nov 2023 20:34:40 -0800 Subject: [PATCH] always upgrade torch+cu118 to torch+cu121 for invoke --- .../Models/Packages/InvokeAI.cs | 83 ++++++------------- 1 file changed, 26 insertions(+), 57 deletions(-) diff --git a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs index 8840f8b92..f54d7e68a 100644 --- a/StabilityMatrix.Core/Models/Packages/InvokeAI.cs +++ b/StabilityMatrix.Core/Models/Packages/InvokeAI.cs @@ -188,73 +188,42 @@ public override async Task InstallPackage( var pipCommandArgs = "-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu"; - var installTorch21 = versionOptions.IsLatest; - - if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag) && !versionOptions.IsLatest) - { - if ( - Version.TryParse(versionOptions.VersionTag, out var version) - && version >= new Version(3, 4) - ) - { - installTorch21 = true; - } - } - switch (torchVersion) { // If has Nvidia Gpu, install CUDA version case TorchVersion.Cuda: - if (installTorch21) + progress?.Report( + new ProgressReport(-1f, "Installing PyTorch for CUDA", isIndeterminate: true) + ); + + var args = new List(); + if (exists) { - progress?.Report( - new ProgressReport( - -1f, - "Installing PyTorch for CUDA", - isIndeterminate: true - ) + var pipPackages = await venvRunner.PipList().ConfigureAwait(false); + var hasCuda121 = pipPackages.Any( + p => p.Name == "torch" && p.Version.Contains("cu121") ); - - var args = new List(); - if (exists) + if (!hasCuda121) { - var pipPackages = await venvRunner.PipList().ConfigureAwait(false); - var hasCuda121 = pipPackages.Any( - p => p.Name == "torch" && p.Version.Contains("cu121") - ); - if (!hasCuda121) - { - args.Add("--upgrade"); - args.Add("--force-reinstall"); - } + args.Add("--upgrade"); + args.Add("--force-reinstall"); } - - await venvRunner - .PipInstall( - new PipInstallArgs( - args.Any() ? args.ToArray() : Array.Empty() - ) - .WithTorch("==2.1.0") - .WithTorchVision("==0.16.0") - .WithXFormers("==0.0.22post7") - .WithTorchExtraIndex("cu121"), - onConsoleOutput - ) - .ConfigureAwait(false); - - Logger.Info("Starting InvokeAI install (CUDA)..."); - pipCommandArgs = - "-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121"; - } - else - { - await InstallCudaTorch(venvRunner, progress, onConsoleOutput) - .ConfigureAwait(false); - Logger.Info("Starting InvokeAI install (CUDA)..."); - pipCommandArgs = - "-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118"; } + await venvRunner + .PipInstall( + new PipInstallArgs(args.Any() ? args.ToArray() : Array.Empty()) + .WithTorch("==2.1.0") + .WithTorchVision("==0.16.0") + .WithXFormers("==0.0.22post7") + .WithTorchExtraIndex("cu121"), + onConsoleOutput + ) + .ConfigureAwait(false); + + Logger.Info("Starting InvokeAI install (CUDA)..."); + pipCommandArgs = + "-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121"; break; // For AMD, Install ROCm version case TorchVersion.Rocm: