Skip to content

Commit

Permalink
Add custom model support for OpenAI translator
Browse files Browse the repository at this point in the history
  • Loading branch information
YiiGuxing committed Oct 22, 2024
1 parent ea22104 commit 95a7807
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
package cn.yiiguxing.plugin.translate.trans.openai

enum class AzureServiceVersion(val value: String) {
V2023_05_15("2023-05-15"),
V2024_06_01("2024-06-01"),
V2024_02_01("2024-02-01"),
V2024_05_01_PREVIEW("2024-05-01-preview");
V2023_05_15("2023-05-15"),
V2024_09_01_PREVIEW("2024-09-01-preview");

companion object {
fun previewVersions() = AzureServiceVersion.values().filter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package cn.yiiguxing.plugin.translate.trans.openai

import cn.yiiguxing.plugin.translate.message

/**
* See: [OpenAiService Models](https://platform.openai.com/docs/models)
*/
Expand All @@ -11,10 +13,12 @@ enum class OpenAiModel(val value: String, val modelName: String) {
GPT_4_TURBO("gpt-4-turbo", "GPT-4 Turbo"),
GPT_4("gpt-4", "GPT-4"),
GPT_3_5_TURBO("gpt-3.5-turbo", "GPT-3.5 Turbo"),
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106", "GPT-3.5 Turbo 1106"),

TTS_1("tts-1", "TTS-1"),
TTS_1_HD("tts-1-hd", "TTS-1 HD");
TTS_1_HD("tts-1-hd", "TTS-1 HD"),

/** Model customized by the user. */
CUSTOM("custom", message("openai.model.custom"));

companion object {
fun gptModels(): List<OpenAiModel> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ interface OpenAiService {
sealed interface Options {
val endpoint: String?
val ttsModel: OpenAiModel
val customTTSModel: String?
val ttsVoice: OpenAiTtsVoice
val ttsSpeed: Int
}

interface OpenAIOptions : Options {
val model: OpenAiModel
val customModel: String?
}

interface AzureOptions : Options {
Expand Down Expand Up @@ -74,16 +76,24 @@ class OpenAI(private val options: OpenAiService.OpenAIOptions) : OpenAiService {
}

override fun chatCompletion(messages: List<ChatMessage>): ChatCompletion {
val model = when (options.model) {
OpenAiModel.CUSTOM -> options.customModel
else -> options.model.value
}
val request = chatCompletionRequest {
model = options.model.value
this.model = model
this.messages = messages
}
return OpenAiHttp.post(getApiUrl(OPEN_AI_API_PATH), request) { auth() }
}

override fun speech(text: String, indicator: ProgressIndicator?): ByteArray {
val model = when (options.ttsModel) {
OpenAiModel.CUSTOM -> options.customTTSModel
else -> options.ttsModel.value
}
val request = SpeechRequest(
module = options.ttsModel.value,
model = model,
input = text,
voice = options.ttsVoice.value,
speed = OpenAiTTSSpeed.get(options.ttsSpeed)
Expand All @@ -110,8 +120,12 @@ class Azure(private val options: OpenAiService.AzureOptions) : OpenAiService {
}

override fun speech(text: String, indicator: ProgressIndicator?): ByteArray {
val model = when (options.ttsModel) {
OpenAiModel.CUSTOM -> options.customTTSModel
else -> options.ttsModel.value
}
val request = SpeechRequest(
module = options.ttsModel.value,
model = model,
input = text,
voice = options.ttsVoice.value,
speed = OpenAiTTSSpeed.get(options.ttsSpeed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class OpenAiSettings : BaseState(), PersistentStateComponent<OpenAiSettings> {
@get:OptionTag("TTS_MODEL")
var ttsModel: OpenAiModel by enum(OpenAiModel.TTS_1)

@get:OptionTag("CUSTOM_TTS_MODEL")
var customTTSModel: String? by string()

@get:OptionTag("TTS_VOICE")
var ttsVoice: OpenAiTtsVoice by enum(OpenAiTtsVoice.ALLOY)

Expand All @@ -62,15 +65,18 @@ class OpenAiSettings : BaseState(), PersistentStateComponent<OpenAiSettings> {
class OpenAi : CommonState(), OpenAiService.OpenAIOptions {
@get:OptionTag("MODEL")
override var model: OpenAiModel by enum(OpenAiModel.GPT_4O_MINI)

@get:OptionTag("CUSTOM_MODEL")
override var customModel: String? by string()
}

@Tag("azure")
class Azure : CommonState(), OpenAiService.AzureOptions {
@get:OptionTag("API_VERSION")
override var apiVersion: AzureServiceVersion by enum(AzureServiceVersion.V2024_02_01)
override var apiVersion: AzureServiceVersion by enum(AzureServiceVersion.V2024_06_01)

@get:OptionTag("TTS_API_VERSION")
override var ttsApiVersion: AzureServiceVersion by enum(AzureServiceVersion.V2024_05_01_PREVIEW)
override var ttsApiVersion: AzureServiceVersion by enum(AzureServiceVersion.V2024_09_01_PREVIEW)

@get:OptionTag("DEPLOYMENT_ID")
override var deployment: String? by string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.google.gson.annotations.SerializedName
* [Documentation](https://platform.openai.com/docs/api-reference/audio/createSpeech)
*/
data class SpeechRequest(
@SerializedName("model") val module: String,
@SerializedName("model") val model: String?,
@SerializedName("input") val input: String,
@SerializedName("voice") val voice: String,
@SerializedName("speed") val speed: Float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,28 @@ class OpenAISettingsDialog(private val configType: ConfigType) : DialogWrapper(f
ConfigType.TRANSLATOR -> openAiState.model = model
ConfigType.TTS -> commonStates.ttsModel = model
}

with(ui.customModelField) {
val oldState = isVisible
isVisible = model == OpenAiModel.CUSTOM
if (oldState != isVisible) {
pack()
}
}
}
}

ui.customModelField.document.addDocumentListener(object : DocumentAdapter() {
override fun textChanged(e: DocumentEvent) {
val text = ui.customModelField.text.takeUnless { it.isNullOrBlank() }?.trim()
when (configType) {
ConfigType.TRANSLATOR -> openAiState.customModel = text
ConfigType.TTS -> commonStates.customTTSModel = text
}
verify(ui.customModelField)
}
})

ui.azureDeploymentField.document.addDocumentListener(object : DocumentAdapter() {
override fun textChanged(e: DocumentEvent) {
azureCommonState.deployment = ui.azureDeploymentField.text.takeUnless { it.isNullOrBlank() }?.trim()
Expand All @@ -157,6 +176,19 @@ class OpenAISettingsDialog(private val configType: ConfigType) : DialogWrapper(f
}

private fun initValidators() {
installValidator(ui.customModelField) {
val customModel = it.text
when {
ui.modelComboBox.selected == OpenAiModel.CUSTOM && customModel.isNullOrBlank() -> ValidationInfo(
message("openai.settings.dialog.error.missing.custom.model"),
it
)

else -> null
}
}


installValidator(ui.apiKeyField) {
val password = it.password
when {
Expand Down Expand Up @@ -217,17 +249,21 @@ class OpenAISettingsDialog(private val configType: ConfigType) : DialogWrapper(f
} else {
ui.apiEndpointField.emptyText.text = DEFAULT_OPEN_AI_API_ENDPOINT
}
ui.setOpenAiFormComponentsVisible(!isAzure)
ui.setAzureFormComponentsVisible(isAzure)

apiEndpoint = commonStates.endpoint
ui.apiKeyField.text = apiKeys[newProvider]
if (configType == ConfigType.TTS) {
if (configType == ConfigType.TRANSLATOR) {
ui.customModelField.text = openAiState.customModel
} else {
ui.customModelField.text = commonStates.customTTSModel
ui.ttsSpeedSlicer.value = commonStates.ttsSpeed
ui.modelComboBox.selected = commonStates.ttsModel
ui.ttsVoiceComboBox.selected = commonStates.ttsVoice
}

ui.setOpenAiFormComponentsVisible(!isAzure)
ui.setAzureFormComponentsVisible(isAzure)

invokeLater(expired = { isDisposed }) {
verify()
if (repack) {
Expand All @@ -253,7 +289,7 @@ class OpenAISettingsDialog(private val configType: ConfigType) : DialogWrapper(f
private fun verify(): Boolean {
var valid = true
var focusTarget: JComponent? = null
listOf(ui.apiKeyField, ui.apiEndpointField, ui.azureDeploymentField).forEach {
listOf(ui.customModelField, ui.apiKeyField, ui.apiEndpointField, ui.azureDeploymentField).forEach {
verify(it)?.let { info ->
// 校验不通过的聚焦优先级最高
if (valid && it.isShowing) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ internal interface OpenAISettingsUI {

val modelComboBox: ComboBox<OpenAiModel>

val customModelField: JBTextField

val azureApiVersionComboBox: ComboBox<AzureServiceVersion>

val ttsVoiceComboBox: ComboBox<OpenAiTtsVoice>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ internal class OpenAISettingsUiImpl(private val type: ConfigType) : OpenAISettin
ConfigType.TRANSLATOR -> OpenAiModel.gptModels()
ConfigType.TTS -> OpenAiModel.ttsModels()
}
model = CollectionComboBoxModel(models)
model = CollectionComboBoxModel(models + OpenAiModel.CUSTOM)
renderer = SimpleListCellRenderer.create { label, model, _ ->
label.text = model.modelName
}
}
override val customModelField: JBTextField = JBTextField()

private val azureApiVersionLabel =
JLabel(message("openai.settings.dialog.label.api.version")).apply { isVisible = false }
Expand Down Expand Up @@ -153,7 +154,6 @@ internal class OpenAISettingsUiImpl(private val type: ConfigType) : OpenAISettin
private fun layout() {
val isTTS = type == ConfigType.TTS


val comboBoxCC = UI.wrap()
.sizeGroupX("combo-box")
.shrink(1f)
Expand All @@ -178,6 +178,7 @@ internal class OpenAISettingsUiImpl(private val type: ConfigType) : OpenAISettin
form.add(providerComboBox, comboBoxCC)
form.add(modelLabel, labelCC)
form.add(modelComboBox, comboBoxCC)
form.add(customModelField, UI.fillX().wrap().cell(1, 3))

form.add(azureDeploymentLabel, labelCC)
form.add(azureDeploymentField, UI.fillX())
Expand Down Expand Up @@ -212,7 +213,11 @@ internal class OpenAISettingsUiImpl(private val type: ConfigType) : OpenAISettin
if (type == ConfigType.TRANSLATOR) {
modelLabel.isVisible = visible
modelComboBox.isVisible = visible
} else {
modelLabel.isVisible = true
modelComboBox.isVisible = true
}
customModelField.isVisible = modelComboBox.isVisible && modelComboBox.selectedItem == OpenAiModel.CUSTOM
apiKeyHelpLabel.isVisible = visible
endpointHelpSpace.isVisible = visible
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/TranslationBundle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ tooltip.wordbook.storage.path=You can now change the storage path for your word
and sync between multiple devices using sync services<br/>\
like iCloud Drive, Google Drive, Onedrive, and Dropbox,<br/>\
<a href="#SETTINGS">Configure now</a>.
openai.model.custom=Custom
openai.settings.dialog.title=OpenAI Translator Settings
openai.settings.dialog.title.tts=OpenAI TTS Settings
openai.settings.dialog.label.provider=Provider:
Expand All @@ -334,6 +335,7 @@ openai.settings.dialog.label.api.key=API key:
openai.settings.dialog.label.endpoint=Endpoint:
openai.settings.dialog.label.deployment=Deployment:
openai.settings.dialog.error.invalid.endpoint=Invalid endpoint.
openai.settings.dialog.error.missing.custom.model=Missing model.
openai.settings.dialog.error.missing.api.key=API key is missing.
openai.settings.dialog.error.missing.deployment=Deployment is missing.
openai.settings.dialog.error.missing.endpoint=Endpoint is missing.
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/TranslationBundle_ja.properties
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ tooltip.wordbook.storage.path=ワードブックの保存場所を変更し、\
iCloud Drive、Google Drive、One Drive、Dropbox などの<br/>\
同期サービスを使用して複数のデバイス間で同期できるようになりました。\
<a href="#SETTINGS">今すぐ設定する</a>
openai.model.custom=カスタマイズ
openai.settings.dialog.title=OpenAI 翻訳設定
openai.settings.dialog.title.tts=OpenAI TTS 設定
openai.settings.dialog.label.provider=プロバイダー:
Expand All @@ -334,6 +335,7 @@ openai.settings.dialog.label.api.key=API キー:
openai.settings.dialog.label.endpoint=エンドポイント:
openai.settings.dialog.label.deployment=デプロイメント名:
openai.settings.dialog.error.invalid.endpoint=無効なエンドポイント。
openai.settings.dialog.error.missing.custom.model=ミッシングモデル
openai.settings.dialog.error.missing.api.key=API キーがありません。
openai.settings.dialog.error.missing.deployment=デプロイ名がありません。
openai.settings.dialog.error.missing.endpoint=エンドポイントがありません。
Expand Down
2 changes: 2 additions & 0 deletions src/main/resources/messages/TranslationBundle_ko.properties
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ tooltip.wordbook.storage.path=이제 단어장이 저장되는 위치를 변경
iCloud Drive, Google Drive, One Drive 및<br/>\
Dropbox 같은 동기화 서비스를 사용하여 여러 장치에서 동기화할 수 있습니다.<br/>\
<a href="#SETTINGS">즉시 설정</a>.
openai.model.custom=사용자 정의
openai.settings.dialog.title=OpenAI 번역 설정
openai.settings.dialog.title.tts=OpenAI TTS 설정
openai.settings.dialog.label.provider=공급자:
Expand All @@ -334,6 +335,7 @@ openai.settings.dialog.label.api.key=API 키:
openai.settings.dialog.label.endpoint=엔드포인트:
openai.settings.dialog.label.deployment=배포 이름:
openai.settings.dialog.error.invalid.endpoint=유효하지 않은 엔드포인트.
openai.settings.dialog.error.missing.custom.model=모델이 누락되었습니다.
openai.settings.dialog.error.missing.api.key=API 키가 없습니다.
openai.settings.dialog.error.missing.deployment=배포가 누락되었습니다.
openai.settings.dialog.error.missing.endpoint=엔드포인트가 없습니다.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ tooltip.wordbook.storage.path=现在,您可以更改单词本的存储路径
以及使用 iCloud Drive、Google Drive、<br/>\
One Drive 和 Dropbox 等同步服务在多个设备之间同步数据。\
<a href="#SETTINGS">去配置</a>
openai.model.custom=自定义
openai.settings.dialog.title=OpenAI 翻译配置
openai.settings.dialog.title.tts=OpenAI TTS 配置
openai.settings.dialog.label.provider=服务商:
Expand All @@ -334,6 +335,7 @@ openai.settings.dialog.label.api.key=API 密钥:
openai.settings.dialog.label.endpoint=API 端点:
openai.settings.dialog.label.deployment=部署名称:
openai.settings.dialog.error.invalid.endpoint=无效的 API 端点
openai.settings.dialog.error.missing.custom.model=缺少模型
openai.settings.dialog.error.missing.api.key=缺少 API 密钥
openai.settings.dialog.error.missing.deployment=缺少模型部署名称
openai.settings.dialog.error.missing.endpoint=缺少 API 端点
Expand Down

0 comments on commit 95a7807

Please sign in to comment.