Skip to content

Commit

Permalink
always upgrade torch+cu118 to torch+cu121 for invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
mohnjiles committed Nov 17, 2023
1 parent d751efc commit d553069
Showing 1 changed file with 26 additions and 57 deletions.
83 changes: 26 additions & 57 deletions StabilityMatrix.Core/Models/Packages/InvokeAI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,73 +188,42 @@ public override async Task InstallPackage(
var pipCommandArgs =
"-e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu";

var installTorch21 = versionOptions.IsLatest;

if (!string.IsNullOrWhiteSpace(versionOptions.VersionTag) && !versionOptions.IsLatest)
{
if (
Version.TryParse(versionOptions.VersionTag, out var version)
&& version >= new Version(3, 4)
)
{
installTorch21 = true;
}
}

switch (torchVersion)
{
// If has Nvidia Gpu, install CUDA version
case TorchVersion.Cuda:
if (installTorch21)
progress?.Report(
new ProgressReport(-1f, "Installing PyTorch for CUDA", isIndeterminate: true)
);

var args = new List<Argument>();
if (exists)
{
progress?.Report(
new ProgressReport(
-1f,
"Installing PyTorch for CUDA",
isIndeterminate: true
)
var pipPackages = await venvRunner.PipList().ConfigureAwait(false);
var hasCuda121 = pipPackages.Any(
p => p.Name == "torch" && p.Version.Contains("cu121")
);

var args = new List<Argument>();
if (exists)
if (!hasCuda121)
{
var pipPackages = await venvRunner.PipList().ConfigureAwait(false);
var hasCuda121 = pipPackages.Any(
p => p.Name == "torch" && p.Version.Contains("cu121")
);
if (!hasCuda121)
{
args.Add("--upgrade");
args.Add("--force-reinstall");
}
args.Add("--upgrade");
args.Add("--force-reinstall");
}

await venvRunner
.PipInstall(
new PipInstallArgs(
args.Any() ? args.ToArray() : Array.Empty<Argument>()
)
.WithTorch("==2.1.0")
.WithTorchVision("==0.16.0")
.WithXFormers("==0.0.22post7")
.WithTorchExtraIndex("cu121"),
onConsoleOutput
)
.ConfigureAwait(false);

Logger.Info("Starting InvokeAI install (CUDA)...");
pipCommandArgs =
"-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121";
}
else
{
await InstallCudaTorch(venvRunner, progress, onConsoleOutput)
.ConfigureAwait(false);
Logger.Info("Starting InvokeAI install (CUDA)...");
pipCommandArgs =
"-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118";
}

await venvRunner
.PipInstall(
new PipInstallArgs(args.Any() ? args.ToArray() : Array.Empty<Argument>())
.WithTorch("==2.1.0")
.WithTorchVision("==0.16.0")
.WithXFormers("==0.0.22post7")
.WithTorchExtraIndex("cu121"),
onConsoleOutput
)
.ConfigureAwait(false);

Logger.Info("Starting InvokeAI install (CUDA)...");
pipCommandArgs =
"-e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121";
break;
// For AMD, Install ROCm version
case TorchVersion.Rocm:
Expand Down

0 comments on commit d553069

Please sign in to comment.