diff --git a/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java index bd1157b25..3734fd94a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java @@ -4,6 +4,8 @@ import java.net.MalformedURLException; import java.net.URL; +import java.util.List; +import java.util.stream.IntStream; public enum HuggingFaceModel { @@ -89,21 +91,25 @@ public String getCode() { return name(); } - public String getFileName() { + public List getFileNames() { if ("TheBloke".equals(user)) { - return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization)); + return List.of(modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization))); } - // TODO: Download all 10 files ;( - return modelName.toLowerCase().replace("-gguf", "-00001-of-00010.gguf"); + if ("phymbert".equals(user)) { + return IntStream.range(1, 11).mapToObj(i -> modelName + .replace("-gguf", "-000%02d-of-00010.gguf".formatted(i))).toList(); + } + return List.of(modelName); } - public URL getFileURL() { - try { - return new URL( - "https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName())); - } catch (MalformedURLException ex) { - throw new RuntimeException(ex); - } + public List getFileURLs() { + return getFileNames().stream().map(file -> { + try { + return new URL("https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), file)); + } catch (MalformedURLException ex) { + throw new RuntimeException(ex); + } + }).toList(); } public URL getHuggingFaceURL() { diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.java deleted file mode 100644 index 87f6c2361..000000000 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.java +++ /dev/null @@ -1,106 +0,0 @@ -package ee.carlrobert.codegpt.settings.service.llama.form; - -import static java.lang.String.format; - -import com.intellij.openapi.actionSystem.AnAction; -import com.intellij.openapi.actionSystem.AnActionEvent; -import com.intellij.openapi.diagnostic.Logger; -import com.intellij.openapi.progress.ProgressIndicator; -import com.intellij.openapi.progress.ProgressManager; -import com.intellij.openapi.progress.Task; -import com.intellij.openapi.project.Project; -import ee.carlrobert.codegpt.CodeGPTBundle; -import ee.carlrobert.codegpt.completions.HuggingFaceModel; -import ee.carlrobert.codegpt.util.DownloadingUtil; -import ee.carlrobert.codegpt.util.file.FileUtil; -import java.io.IOException; -import java.net.URL; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; -import javax.swing.DefaultComboBoxModel; -import org.jetbrains.annotations.NotNull; - -public class DownloadModelAction extends AnAction { - - private static final Logger LOG = Logger.getInstance(DownloadModelAction.class); - - private final Consumer onDownload; - private final Runnable onDownloaded; - private final Consumer onFailed; - private final Consumer onUpdateProgress; - private final DefaultComboBoxModel comboBoxModel; - - public DownloadModelAction( - Consumer onDownload, - Runnable onDownloaded, - Consumer onFailed, - Consumer onUpdateProgress, - DefaultComboBoxModel comboBoxModel) { - this.onDownload = onDownload; - this.onDownloaded = onDownloaded; - this.onFailed = onFailed; - this.onUpdateProgress = onUpdateProgress; - this.comboBoxModel = comboBoxModel; - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - ProgressManager.getInstance().run(new DownloadBackgroundTask(e.getProject())); - } - - class DownloadBackgroundTask extends Task.Backgroundable { - - DownloadBackgroundTask(Project project) { - super( - project, - CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"), - true); - } - - @Override - public void run(@NotNull ProgressIndicator indicator) { - var model = (HuggingFaceModel) comboBoxModel.getSelectedItem(); - URL url = model.getFileURL(); - ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor(); - ScheduledFuture progressUpdateScheduler = null; - - try { - onDownload.accept(indicator); - - indicator.setIndeterminate(false); - indicator.setText(format( - CodeGPTBundle.get( - "settingsConfigurable.service.llama.progress.downloadingModelIndicator.text"), - model.getFileName())); - - long fileSize = url.openConnection().getContentLengthLong(); - long[] bytesRead = {0}; - long startTime = System.currentTimeMillis(); - - progressUpdateScheduler = executorService.scheduleAtFixedRate(() -> - onUpdateProgress.accept(DownloadingUtil.getFormattedDownloadProgress( - startTime, - fileSize, - bytesRead[0])), - 0, 1, TimeUnit.SECONDS); - FileUtil.copyFileWithProgress(model.getFileName(), url, bytesRead, fileSize, indicator); - } catch (IOException ex) { - LOG.error("Unable to open connection", ex); - onFailed.accept(ex); - } finally { - if (progressUpdateScheduler != null) { - progressUpdateScheduler.cancel(true); - } - executorService.shutdown(); - } - } - - @Override - public void onSuccess() { - onDownloaded.run(); - } - } -} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java index 424286730..4b859c269 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java @@ -195,7 +195,8 @@ public InfillPromptTemplate getInfillPromptTemplate() { public String getActualModelPath() { return isUseCustomLlamaModel() ? getCustomLlamaModelPath() - : CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName(); + : CodeGPTPlugin.getLlamaModelsPath() + File.separator + + getSelectedModel().getFileNames().get(0); } private JPanel createFormPanelCards() { @@ -394,8 +395,9 @@ private TextFieldWithBrowseButton createBrowsableCustomModelTextField(boolean en } private boolean isModelExists(HuggingFaceModel model) { - return FileUtil.exists( - CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName()); + return model.getFileNames().stream().allMatch(filename -> + FileUtil.exists(CodeGPTPlugin.getLlamaModelsPath() + File.separator + filename) + ); } private AnActionLink createCancelDownloadLink( diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java index cbd2bae33..3fafd5e5a 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaServerPreferencesForm.java @@ -290,7 +290,7 @@ private boolean validateSelectedModel() { private boolean isModelExists(HuggingFaceModel model) { return FileUtil.exists( - CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName()); + CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileNames()); } private void enableForm(JButton serverButton, ServerProgressPanel progressPanel) { diff --git a/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java b/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java deleted file mode 100644 index f7c1fee6b..000000000 --- a/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java +++ /dev/null @@ -1,40 +0,0 @@ -package ee.carlrobert.codegpt.util; - -import static java.lang.String.format; - -import ee.carlrobert.codegpt.util.file.FileUtil; - -public class DownloadingUtil { - - private DownloadingUtil() { - } - - private static final int BYTES_IN_MB = 1024 * 1024; - - public static String getFormattedDownloadProgress(long startTime, long fileSize, long bytesRead) { - long timeElapsed = System.currentTimeMillis() - startTime; - - double speed = ((double) bytesRead / timeElapsed) * 1000 / BYTES_IN_MB; - double percent = (double) bytesRead / fileSize * 100; - double downloadedMB = (double) bytesRead / BYTES_IN_MB; - double totalMB = (double) fileSize / BYTES_IN_MB; - double remainingMB = totalMB - downloadedMB; - - return format( - "%s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s", - FileUtil.convertFileSize((long) downloadedMB * BYTES_IN_MB), - FileUtil.convertFileSize((long) totalMB * BYTES_IN_MB), - percent, - speed, - getTimeLeftFormattedString(speed, remainingMB)); - } - - private static String getTimeLeftFormattedString(double speed, double remainingMB) { - double timeLeftSec = speed > 0 ? remainingMB / speed : 0; - long hours = (long) (timeLeftSec / 3600); - long minutes = (long) ((timeLeftSec % 3600) / 60); - long seconds = (long) (timeLeftSec % 60); - - return format("%02d:%02d:%02d", hours, minutes, seconds); - } -} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.kt new file mode 100644 index 000000000..5fd824125 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.kt @@ -0,0 +1,100 @@ +package ee.carlrobert.codegpt.settings.service.llama.form + +import com.intellij.openapi.actionSystem.AnAction +import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.progress.ProgressIndicator +import com.intellij.openapi.progress.ProgressManager +import com.intellij.openapi.progress.Task +import com.intellij.openapi.project.Project +import ee.carlrobert.codegpt.CodeGPTBundle +import ee.carlrobert.codegpt.completions.HuggingFaceModel +import ee.carlrobert.codegpt.util.DownloadingUtil +import ee.carlrobert.codegpt.util.file.FileUtil.copyFileWithProgress +import java.io.IOException +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit +import java.util.function.Consumer +import javax.swing.DefaultComboBoxModel + +class DownloadModelAction( + private val onDownload: Consumer, + private val onDownloaded: Runnable, + private val onFailed: Consumer, + private val onUpdateProgress: Consumer, + private val comboBoxModel: DefaultComboBoxModel +) : AnAction() { + + override fun actionPerformed(e: AnActionEvent) { + ProgressManager.getInstance().run(DownloadBackgroundTask(e.project)) + } + + internal inner class DownloadBackgroundTask(project: Project?) : Task.Backgroundable( + project, + CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"), + true + ) { + override fun run(indicator: ProgressIndicator) { + val model = comboBoxModel.selectedItem as HuggingFaceModel + val urls = model.fileURLs + val numberOfFiles = urls.size + var errorOccured = false + for (i in 1..numberOfFiles + 1) { + if (errorOccured || indicator.isCanceled) { + break + } + val executorService = Executors.newSingleThreadScheduledExecutor() + var progressUpdateScheduler: ScheduledFuture<*>? = null + val url = urls[i - 1] + + try { + onDownload.accept(indicator) + + indicator.isIndeterminate = false + indicator.text = String.format( + CodeGPTBundle.get( + "settingsConfigurable.service.llama.progress.downloadingModelIndicator.text" + ), + model.fileNames[i - 1] + ) + + val fileSize = url.openConnection().contentLengthLong + val bytesRead = longArrayOf(0) + val startTime = System.currentTimeMillis() + + progressUpdateScheduler = executorService.scheduleAtFixedRate( + { + onUpdateProgress.accept( + DownloadingUtil.getFormattedDownloadProgress( + i, + numberOfFiles, + startTime, + fileSize, + bytesRead[0] + ) + ) + }, + 0, 1, TimeUnit.SECONDS + ) + copyFileWithProgress(model.fileNames[i - 1], url, bytesRead, fileSize, indicator) + } catch (ex: IOException) { + LOG.error("Unable to download", ex, url.toString()) + onFailed.accept(ex) + errorOccured = true + } finally { + progressUpdateScheduler?.cancel(true) + executorService.shutdown() + } + } + } + + override fun onSuccess() { + onDownloaded.run() + } + } + + companion object { + private val LOG = Logger.getInstance(DownloadModelAction::class.java) + } +} diff --git a/src/main/kotlin/ee/carlrobert/codegpt/util/DownloadingUtil.kt b/src/main/kotlin/ee/carlrobert/codegpt/util/DownloadingUtil.kt new file mode 100644 index 000000000..efcee2f30 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/util/DownloadingUtil.kt @@ -0,0 +1,40 @@ +package ee.carlrobert.codegpt.util + +import ee.carlrobert.codegpt.util.file.FileUtil.convertFileSize + +object DownloadingUtil { + private const val BYTES_IN_MB = 1024 * 1024 + + fun getFormattedDownloadProgress( + fileNumber: Int, fileCount: Int, startTime: Long, + fileSize: Long, bytesRead: Long + ): String { + val timeElapsed = System.currentTimeMillis() - startTime + + val speed = (bytesRead.toDouble() / timeElapsed) * 1000 / BYTES_IN_MB + val percent = bytesRead.toDouble() / fileSize * 100 + val downloadedMB = bytesRead.toDouble() / BYTES_IN_MB + val totalMB = fileSize.toDouble() / BYTES_IN_MB + val remainingMB = totalMB - downloadedMB + + return String.format( + "File %d/%d: %s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s", + fileNumber, + fileCount, + convertFileSize(downloadedMB.toLong() * BYTES_IN_MB), + convertFileSize(totalMB.toLong() * BYTES_IN_MB), + percent, + speed, + getTimeLeftFormattedString(speed, remainingMB) + ) + } + + private fun getTimeLeftFormattedString(speed: Double, remainingMB: Double): String { + val timeLeftSec = if (speed > 0) remainingMB / speed else 0.0 + val hours = (timeLeftSec / 3600).toLong() + val minutes = ((timeLeftSec % 3600) / 60).toLong() + val seconds = (timeLeftSec % 60).toLong() + + return String.format("%02d:%02d:%02d", hours, minutes, seconds) + } +}