Skip to content

Commit

Permalink
Use enum value to store selected service (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh authored Nov 8, 2023
1 parent ff60d1e commit cfa5ff7
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 233 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ee.carlrobert.codegpt.completions.you.YouUserManager;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
Expand Down Expand Up @@ -80,12 +81,12 @@ private EventSource startCall(
var requestProvider = new CompletionRequestProvider(conversation);

try {
if (settings.isUseLlamaService()) {
if (settings.getSelectedService() == ServiceType.LLAMA_CPP) {
return CompletionClientProvider.getLlamaClient()
.getChatCompletion(requestProvider.buildLlamaCompletionRequest(message), eventListener);
}

if (settings.isUseYouService()) {
if (settings.getSelectedService() == ServiceType.YOU) {
var sessionId = "";
var accessToken = "";
var youUserManager = YouUserManager.getInstance();
Expand All @@ -103,7 +104,7 @@ private EventSource startCall(
.getChatCompletion(request, eventListener);
}

if (settings.isUseAzureService()) {
if (settings.getSelectedService() == ServiceType.AZURE) {
var azureSettings = AzureSettingsState.getInstance();
return CompletionClientProvider.getAzureClient().getChatCompletion(
requestProvider.buildOpenAIChatCompletionRequest(
Expand Down Expand Up @@ -151,7 +152,7 @@ protected Void doInBackground() {
conversation,
message,
isRetry,
settings.isUseYouService() ?
settings.getSelectedService() == ServiceType.YOU ?
new YouRequestCompletionEventListener() :
new BaseCompletionEventListener());
} catch (TotalUsageExceededException e) {
Expand Down Expand Up @@ -212,20 +213,10 @@ public void onSerpResults(List<YouSerpResult> results) {
}

private void sendInfo(SettingsState settings) {
var service = "openai";
if (settings.isUseAzureService()) {
service = "azure";
}
if (settings.isUseYouService()) {
service = "you";
}
if (settings.isUseLlamaService()) {
service = "llama";
}
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", conversation.getId().toString())
.property("model", conversation.getModel())
.property("service", service)
.property("service", settings.getSelectedService().getCode().toLowerCase())
.send();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ee.carlrobert.codegpt.conversations.ConversationsState;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
import ee.carlrobert.codegpt.settings.state.YouSettingsState;
Expand Down Expand Up @@ -149,7 +150,7 @@ private List<OpenAIChatCompletionMessage> buildMessages(
messages.add(new OpenAIChatCompletionMessage("user", message.getPrompt()));
}

if (SettingsState.getInstance().isUseYouService()) {
if (SettingsState.getInstance().getSelectedService() == ServiceType.YOU) {
return messages;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
Expand Down Expand Up @@ -43,17 +44,9 @@ public Conversation createConversation(String clientCode) {
var conversation = new Conversation();
conversation.setId(UUID.randomUUID());
conversation.setClientCode(clientCode);
if (settings.isUseYouService()) {
conversation.setModel("YouCode");
} else if (settings.isUseAzureService()) {
conversation.setModel(AzureSettingsState.getInstance().getModel());
} else if (settings.isUseOpenAIService()) {
conversation.setModel(OpenAISettingsState.getInstance().getModel());
} else {
conversation.setModel(LlamaSettingsState.getInstance().getHuggingFaceModel().getCode());
}
conversation.setCreatedOn(LocalDateTime.now());
conversation.setUpdatedOn(LocalDateTime.now());
conversation.setModel(getModelForSelectedService(settings.getSelectedService()));
return conversation;
}

Expand Down Expand Up @@ -121,22 +114,9 @@ public void saveConversation(Conversation conversation) {
conversationState.setCurrentConversation(conversation);
}

private String getClientCode() {
var settings = SettingsState.getInstance();
if (settings.isUseOpenAIService()) {
return "chat.completion";
}
if (settings.isUseAzureService()) {
return "azure.chat.completion";
}
if (settings.isUseLlamaService()) {
return "llama.chat.completion";
}
return "you.chat.completion";
}

public Conversation startConversation() {
var conversation = createConversation(getClientCode());
var completionCode = SettingsState.getInstance().getSelectedService().getCompletionCode();
var conversation = createConversation(completionCode);
conversationState.setCurrentConversation(conversation);
addConversation(conversation);
return conversation;
Expand All @@ -147,32 +127,6 @@ public void clearAll() {
conversationState.setCurrentConversation(null);
}

public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}

public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}

private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}

public void deleteConversation(Conversation conversation) {
var iterator = conversationState.getConversationsMapping()
.get(conversation.getClientCode())
Expand Down Expand Up @@ -205,4 +159,48 @@ public void discardTokenLimits(Conversation conversation) {
conversation.discardTokenLimits();
saveConversation(conversation);
}

public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}

public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}

private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}

private static String getModelForSelectedService(ServiceType serviceType) {
switch (serviceType) {
case OPENAI:
return OpenAISettingsState.getInstance().getModel();
case AZURE:
return AzureSettingsState.getInstance().getModel();
case YOU:
return "YouCode";
case LLAMA_CPP:
var llamaSettings = LlamaSettingsState.getInstance();
return llamaSettings.isUseCustomModel() ?
llamaSettings.getCustomLlamaModelPath() :
llamaSettings.getHuggingFaceModel().getCode();
default:
throw new RuntimeException("Could not find corresponding service mapping");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public SettingsComponent(Disposable parentDisposable, SettingsState settings) {
cards.add(serviceSelectionForm.getLlamaServiceSectionPanel(), ServiceType.LLAMA_CPP.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values())
.filter(it -> !"LLAMA_CPP".equals(it.getCode()) || SystemInfoRt.isUnix)
.filter(it -> ServiceType.LLAMA_CPP != it || SystemInfoRt.isUnix)
.collect(toList()));
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
serviceComboBox.setSelectedItem(ServiceType.OPENAI);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
package ee.carlrobert.codegpt.settings;

import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;

import com.intellij.openapi.Disposable;
import com.intellij.openapi.options.Configurable;
import com.intellij.openapi.util.Disposer;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.conversations.ConversationsState;
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
Expand Down Expand Up @@ -98,13 +92,7 @@ public void apply() {
.setAzureActiveDirectoryToken(serviceSelectionForm.getAzureActiveDirectoryToken());

settings.setDisplayName(settingsComponent.getDisplayName());
// TODO: Store as single enum value
settings.setUseOpenAIService(settingsComponent.getSelectedService() == OPENAI);
settings.setUseAzureService(settingsComponent.getSelectedService() == ServiceType.AZURE);
settings.setUseYouService(settingsComponent.getSelectedService() == ServiceType.YOU);
YouSettingsState.getInstance()
.setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults());
settings.setUseLlamaService(settingsComponent.getSelectedService() == ServiceType.LLAMA_CPP);
settings.setSelectedService(settingsComponent.getSelectedService());

var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm();
llamaSettings.setCustomLlamaModelPath(llamaModelPreferencesForm.getCustomLlamaModelPath());
Expand All @@ -116,12 +104,14 @@ public void apply() {

openAISettings.apply(serviceSelectionForm);
azureSettings.apply(serviceSelectionForm);
YouSettingsState.getInstance()
.setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults());

if (serviceChanged || modelChanged) {
resetActiveTab();
if (serviceChanged) {
TelemetryAction.SETTINGS_CHANGED.createActionMessage()
.property("service", getServiceCode())
.property("service", settingsComponent.getSelectedService().getCode().toLowerCase())
.send();
}
}
Expand All @@ -137,20 +127,8 @@ public void reset() {

// settingsComponent.setEmail(settings.getEmail());
settingsComponent.setDisplayName(settings.getDisplayName());
settingsComponent.setSelectedService(settings.getSelectedService());

// TODO
if (settings.isUseOpenAIService()) {
settingsComponent.setSelectedService(OPENAI);
}
if (settings.isUseAzureService()) {
settingsComponent.setSelectedService(ServiceType.AZURE);
}
if (settings.isUseYouService()) {
settingsComponent.setSelectedService(ServiceType.YOU);
}
if (settings.isUseLlamaService()) {
settingsComponent.setSelectedService(ServiceType.LLAMA_CPP);
}
var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm();
llamaModelPreferencesForm.setSelectedModel(llamaSettings.getHuggingFaceModel());
llamaModelPreferencesForm.setCustomLlamaModelPath(llamaSettings.getCustomLlamaModelPath());
Expand All @@ -174,10 +152,7 @@ public void disposeUIResources() {
}

private boolean isServiceChanged(SettingsState settings) {
return (settingsComponent.getSelectedService() == OPENAI) != settings.isUseOpenAIService() ||
(settingsComponent.getSelectedService() == AZURE) != settings.isUseAzureService() ||
(settingsComponent.getSelectedService() == YOU) != settings.isUseYouService() ||
(settingsComponent.getSelectedService() == LLAMA_CPP) != settings.isUseLlamaService();
return settingsComponent.getSelectedService() != settings.getSelectedService();
}

private void resetActiveTab() {
Expand All @@ -189,20 +164,4 @@ private void resetActiveTab() {

project.getService(StandardChatToolWindowContentManager.class).resetActiveTab();
}

private String getServiceCode() {
if (settingsComponent.getSelectedService() == OPENAI) {
return "openai";
}
if (settingsComponent.getSelectedService() == AZURE) {
return "azure";
}
if (settingsComponent.getSelectedService() == YOU) {
return "you";
}
if (settingsComponent.getSelectedService() == LLAMA_CPP) {
return "llama.cpp";
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import ee.carlrobert.codegpt.CodeGPTBundle;

public enum ServiceType {
OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title")),
AZURE("AZURE", CodeGPTBundle.get("service.azure.title")),
YOU("YOU", CodeGPTBundle.get("service.you.title")),
LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title"));
OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title"), "chat.completion"),
AZURE("AZURE", CodeGPTBundle.get("service.azure.title"), "azure.chat.completion"),
YOU("YOU", CodeGPTBundle.get("service.you.title"), "you.chat.completion"),
LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title"), "llama.chat.completion");

private final String code;
private final String label;
private final String completionCode;

ServiceType(String code, String label) {
ServiceType(String code, String label, String completionCode) {
this.code = code;
this.label = label;
this.completionCode = completionCode;
}

public String getCode() {
Expand All @@ -24,6 +26,10 @@ public String getLabel() {
return label;
}

public String getCompletionCode() {
return completionCode;
}

@Override
public String toString() {
return label;
Expand Down
Loading

0 comments on commit cfa5ff7

Please sign in to comment.