Skip to content

Commit

Permalink
feat: support virtual message model designation (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Mar 16, 2024
1 parent eafdf00 commit 642fc46
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 14 deletions.
15 changes: 8 additions & 7 deletions adapter/midjourney/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,30 @@ func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, c
return err
}

return c.CallbackActions(form, callback)
return c.CallbackActions(props, form, callback)
}

func toVirtualMessage(message string) string {
return "https://chatnio.virtual" + strings.Replace(message, " ", "-", -1)
func toVirtualMessage(message string, model string) string {
prompt := strings.Replace(message, " ", "-", -1)
return fmt.Sprintf("https://chatnio.virtual%s::%s", prompt, model)
}

func (c *ChatInstance) CallbackActions(form *StorageForm, callback globals.Hook) error {
func (c *ChatInstance) CallbackActions(props *adaptercommon.ChatProps, form *StorageForm, callback globals.Hook) error {
if form.Action == UpscaleAction {
return nil
}

actions := utils.Range(1, maxActions+1)

upscale := strings.Join(utils.Each(actions, func(index int) string {
return fmt.Sprintf("[U%d](%s)", index, toVirtualMessage(fmt.Sprintf("/UPSCALE %s %d", form.Task, index)))
return fmt.Sprintf("[U%d](%s)", index, toVirtualMessage(fmt.Sprintf("/UPSCALE %s %d", form.Task, index), props.OriginalModel))
}), " ")

variation := strings.Join(utils.Each(actions, func(index int) string {
return fmt.Sprintf("[V%d](%s)", index, toVirtualMessage(fmt.Sprintf("/VARIATION %s %d", form.Task, index)))
return fmt.Sprintf("[V%d](%s)", index, toVirtualMessage(fmt.Sprintf("/VARIATION %s %d", form.Task, index), props.OriginalModel))
}), " ")

reroll := fmt.Sprintf("[REROLL](%s)", toVirtualMessage(fmt.Sprintf("/REROLL %s", form.Task)))
reroll := fmt.Sprintf("[REROLL](%s)", toVirtualMessage(fmt.Sprintf("/REROLL %s", form.Task), props.OriginalModel))

return callback(&globals.Chunk{
Content: fmt.Sprintf("\n\n%s\n\n%s\n\n%s\n", upscale, variation, reroll),
Expand Down
4 changes: 2 additions & 2 deletions app/src/components/markdown/Link.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ export default function ({ href, children }: LinkProps) {
const url: string = href?.toString() || "";

if (url.startsWith("https://chatnio.virtual")) {
const message = url.slice(23).replace(/-/g, " ");
const prefix = message.split(" ")[0];
const message = url.slice(23);
const prefix = message.split("-")[0];

return (
<VirtualMessage message={message} prefix={prefix}>
Expand Down
22 changes: 19 additions & 3 deletions app/src/components/markdown/VirtualMessage.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import React, { useState } from "react";
import { useTranslation } from "react-i18next";
import { useMessageActions, useWorking } from "@/store/chat.ts";
import {
useConversationActions,
useMessageActions,
useWorking,
} from "@/store/chat.ts";
import {
Dialog,
DialogContent,
Expand Down Expand Up @@ -87,15 +91,24 @@ type VirtualMessageProps = {
children: React.ReactNode;
};

function parseMessage(message: string): { prompt: string; model: string } {
const [prompt, ...rest] = message.split("::");
const model = rest.join(" ");
return { prompt: prompt.replace(/-/g, " "), model };
}

export function VirtualMessage({
message,
prefix,
children,
}: VirtualMessageProps) {
const { t } = useTranslation();
const { selected } = useConversationActions();
const { send: sendAction } = useMessageActions();
const working = useWorking();

const { prompt, model } = parseMessage(message);

return (
<Dialog>
<DialogTrigger asChild>
Expand All @@ -113,7 +126,7 @@ export function VirtualMessage({
<DialogDescription className={`pb-2`}>
{t("chat.send-message-desc")}
</DialogDescription>
<VirtualPrompt message={message} prefix={prefix}>
<VirtualPrompt message={prompt} prefix={prefix}>
{children}
</VirtualPrompt>
</DialogHeader>
Expand All @@ -123,7 +136,10 @@ export function VirtualMessage({
</DialogClose>
<DialogClose
disabled={working}
onClick={async () => await sendAction(message)}
onClick={async () => {
selected(model);
await sendAction(prompt, model);
}}
asChild
>
<Button variant={`default`}>
Expand Down
4 changes: 2 additions & 2 deletions app/src/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ export function useMessageActions() {
const repetition_penalty = useSelector(repetitionPenaltySelector);

return {
send: async (message: string) => {
send: async (message: string, using_model?: string) => {
if (current === -1 && conversations[-1].messages.length === 0) {
// preflight history if it's a new conversation
dispatch(preflightHistory(message));
Expand All @@ -533,7 +533,7 @@ export function useMessageActions() {
type: "chat",
message,
web,
model,
model: using_model || model,
context: history,
ignore_context: !context,
max_tokens,
Expand Down

0 comments on commit 642fc46

Please sign in to comment.