-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
closes #2931 (slightly changed goal based on advice from the discord, generate full messages, not 'x' tokens. Full messages are more useful data) - [x] Create draft selection UI - [x] Draft inference - [x] Option to regenerate drafts and serve 3 new ones - [x] Remember last viewed sibling message - [x] Store selected draft training data for RLHF - ~~[ ] Disable drafts when queue is too long / server is under load~~ (Suggested to leave to next PR in the discord) - [x] Draft markdown rendering - [x] 'Used plugin' UI for drafts - [x] Resolve merge conflicts --------- Co-authored-by: notmd <[email protected]> Co-authored-by: notmd <[email protected]>
- Loading branch information
1 parent
6e593ec
commit 70f30a6
Showing
24 changed files
with
685 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
.../alembic/versions/2023_05_29_1551-5ed411a331f4_add_active_thread_tail_messsage_id_and_.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""add_active_thread_tail_messsage_id_and_message_eval | ||
Revision ID: 5ed411a331f4 | ||
Revises: 5b4211625a9f | ||
Create Date: 2023-05-29 15:51:41.857262 | ||
""" | ||
import sqlalchemy as sa | ||
import sqlmodel | ||
from alembic import op | ||
from sqlalchemy.dialects import postgresql | ||
|
||
# revision identifiers, used by Alembic. | ||
revision = "5ed411a331f4" | ||
down_revision = "5b4211625a9f" | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"message_evaluation", | ||
sa.Column("inferior_message_ids", postgresql.JSONB(astext_type=sa.Text()), nullable=True), | ||
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("chat_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("selected_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["chat_id"], | ||
["chat.id"], | ||
), | ||
sa.ForeignKeyConstraint( | ||
["selected_message_id"], | ||
["message.id"], | ||
), | ||
sa.ForeignKeyConstraint( | ||
["user_id"], | ||
["user.id"], | ||
), | ||
sa.PrimaryKeyConstraint("id"), | ||
) | ||
op.create_index(op.f("ix_message_evaluation_chat_id"), "message_evaluation", ["chat_id"], unique=False) | ||
op.create_index(op.f("ix_message_evaluation_user_id"), "message_evaluation", ["user_id"], unique=False) | ||
op.add_column("chat", sa.Column("active_thread_tail_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True)) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_column("chat", "active_thread_tail_message_id") | ||
op.drop_index(op.f("ix_message_evaluation_user_id"), table_name="message_evaluation") | ||
op.drop_index(op.f("ix_message_evaluation_chat_id"), table_name="message_evaluation") | ||
op.drop_table("message_evaluation") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import { Box, Flex } from "@chakra-ui/react"; | ||
import { ChevronLeft, ChevronRight, RotateCcw } from "lucide-react"; | ||
import { useCallback, useEffect, useState } from "react"; | ||
import { InferenceMessage } from "src/types/Chat"; | ||
|
||
import { BaseMessageEmojiButton } from "../Messages/MessageEmojiButton"; | ||
import { MessageInlineEmojiRow } from "../Messages/MessageInlineEmojiRow"; | ||
import { ChatAssistantDraftViewer } from "./ChatAssistantDraftViewer"; | ||
|
||
export type DraftPickedParams = { chatId: string; regenIndex: number; messageIndex: number }; | ||
|
||
type OnDraftPickedFn = (params: DraftPickedParams) => void; | ||
type OnRetryFn = (params: { parentId: string; chatId: string }) => void; | ||
|
||
type ChatAssistantDraftPagerProps = { | ||
chatId: string; | ||
streamedDrafts: string[]; | ||
draftMessageRegens: InferenceMessage[][]; | ||
onDraftPicked: OnDraftPickedFn; | ||
onRetry: OnRetryFn; | ||
}; | ||
|
||
export const ChatAssistantDraftPager = ({ | ||
chatId, | ||
streamedDrafts, | ||
draftMessageRegens, | ||
onDraftPicked, | ||
onRetry, | ||
}: ChatAssistantDraftPagerProps) => { | ||
const [isComplete, setIsComplete] = useState(false); | ||
const [regenIndex, setRegenIndex] = useState<number>(0); | ||
|
||
useEffect(() => { | ||
const allMessagesComplete = | ||
regenIndex < draftMessageRegens.length && | ||
draftMessageRegens[regenIndex]?.length !== 0 && | ||
draftMessageRegens[regenIndex].every((message) => | ||
["complete", "aborted_by_worker", "cancelled", "timeout"].includes(message.state) | ||
); | ||
setIsComplete(allMessagesComplete); | ||
}, [regenIndex, draftMessageRegens]); | ||
|
||
const handlePrevious = useCallback(() => { | ||
setRegenIndex(regenIndex > 0 ? regenIndex - 1 : regenIndex); | ||
}, [setRegenIndex, regenIndex]); | ||
|
||
const handleNext = useCallback(() => { | ||
setRegenIndex(regenIndex < draftMessageRegens.length - 1 ? regenIndex + 1 : regenIndex); | ||
}, [setRegenIndex, regenIndex, draftMessageRegens]); | ||
|
||
const handleRetry = useCallback(() => { | ||
if (onRetry && regenIndex < draftMessageRegens.length && draftMessageRegens[regenIndex]?.length !== 0) { | ||
onRetry({ | ||
parentId: draftMessageRegens[regenIndex][0].parent_id, | ||
chatId: draftMessageRegens[regenIndex][0].chat_id, | ||
}); | ||
setRegenIndex(draftMessageRegens.length); | ||
} | ||
}, [onRetry, regenIndex, draftMessageRegens, setRegenIndex]); | ||
|
||
const handleDraftPicked = useCallback( | ||
(messageIndex: number) => { | ||
onDraftPicked({ chatId, regenIndex, messageIndex }); | ||
}, | ||
[chatId, regenIndex, onDraftPicked] | ||
); | ||
|
||
return ( | ||
<ChatAssistantDraftViewer | ||
streamedDrafts={streamedDrafts} | ||
isComplete={isComplete} | ||
draftMessages={draftMessageRegens[Math.min(draftMessageRegens.length - 1, regenIndex)]} | ||
onDraftPicked={handleDraftPicked} | ||
pager={ | ||
isComplete ? ( | ||
<Flex justifyContent={"space-between"}> | ||
<> | ||
<MessageInlineEmojiRow gap="0.5"> | ||
<BaseMessageEmojiButton emoji={ChevronLeft} onClick={handlePrevious} isDisabled={regenIndex === 0} /> | ||
<Box fontSize="xs">{`${regenIndex + 1}/${draftMessageRegens.length}`}</Box> | ||
<BaseMessageEmojiButton | ||
emoji={ChevronRight} | ||
onClick={handleNext} | ||
isDisabled={regenIndex === draftMessageRegens.length - 1} | ||
/> | ||
</MessageInlineEmojiRow> | ||
</> | ||
<BaseMessageEmojiButton emoji={RotateCcw} onClick={handleRetry} /> | ||
</Flex> | ||
) : undefined | ||
} | ||
/> | ||
); | ||
}; |
Oops, something went wrong.