Message drafts #3044

Merged: 29 commits from message-drafts into main, May 31, 2023
Changes shown below are from 21 of the 29 commits.

Commits (29)
2950e15
Created non-functioning draft UI component
someone13574 May 3, 2023
e77f26e
Draft message inference
someone13574 May 4, 2023
fed6939
Draft message selection
someone13574 May 4, 2023
77e5023
Draft regeneration and prompt editing
someone13574 May 6, 2023
ca9ea2a
Save last viewed message thread
someone13574 May 12, 2023
06f877e
Save inferior draft message data to inference-db
someone13574 May 12, 2023
4ae4df7
Markdown and plugin rendering for drafts
someone13574 May 12, 2023
611d75b
Use i18 translations for draft messages
someone13574 May 13, 2023
1927ded
Merge branch 'main' into message-drafts
someone13574 May 13, 2023
81a76d5
Combine alembic revisions
someone13574 May 13, 2023
3ab1354
Add ENABLE_DRAFTS_FOR_PLUGINS environment variable
someone13574 May 13, 2023
1221f02
Move last viewed thread from message to chat data
someone13574 May 14, 2023
1f64640
Move draft evaluation data to seperate table
someone13574 May 14, 2023
08cf226
Merge branch 'main' into message-drafts
someone13574 May 27, 2023
8b13b31
Add setting for number of drafts to generate
someone13574 May 27, 2023
4cbfc62
Add draft message environment variables to .env
someone13574 May 27, 2023
886ab0d
Remove useless key in ChatAssistantDraftViewer.tsx
someone13574 May 27, 2023
b8d0f5c
Move paging logic to ChatAssistantDraftPager.tsx
someone13574 May 28, 2023
113206b
Use put to update active message instead of SWR
someone13574 May 28, 2023
df45c6b
Fix flashing queue info for message drafts
someone13574 May 28, 2023
6f2fc3b
Change key in draft viewer from message-$index to draft-message-${index}
someone13574 May 28, 2023
93a9a3e
fix overflow
notmd May 28, 2023
bd2649c
Add NUM_GENERATED_DRAFTS <= 1 shutoff everywhere
someone13574 May 29, 2023
70a8e6a
Rename ENABLE_DRAFTS_FOR_PLUGINS and add comment
someone13574 May 29, 2023
66977b2
Add missing translations for draft toast messages
someone13574 May 29, 2023
c97f5ea
Merge branch 'LAION-AI:main' into message-drafts
someone13574 May 29, 2023
37f65aa
Add null onPluginIntermediateResponse for drafts
someone13574 May 29, 2023
7619642
active_message_id -> active_thread_tail_message_id
someone13574 May 29, 2023
f6413f6
Only show drafts on message regeneration
someone13574 May 29, 2023
2 changes: 2 additions & 0 deletions .github/workflows/deploy-to-node.yaml
@@ -81,6 +81,8 @@ jobs:
BACKEND_CORS_ORIGINS: ${{ vars.BACKEND_CORS_ORIGINS }}
WEB_INFERENCE_SERVER_HOST: ${{ vars.WEB_INFERENCE_SERVER_HOST }}
WEB_ENABLE_CHAT: ${{ vars.WEB_ENABLE_CHAT }}
WEB_ENABLE_DRAFTS_FOR_PLUGINS: ${{ vars.WEB_ENABLE_DRAFTS_FOR_PLUGINS }}
WEB_NUM_GENERATED_DRAFTS: ${{ vars.WEB_NUM_GENERATED_DRAFTS }}
WEB_CURRENT_ANNOUNCEMENT: ${{ vars.WEB_CURRENT_ANNOUNCEMENT }}
WEB_INFERENCE_SERVER_API_KEY: ${{secrets.WEB_INFERENCE_SERVER_API_KEY}}
INFERENCE_POSTGRES_PASSWORD: ${{secrets.INFERENCE_POSTGRES_PASSWORD}}
4 changes: 4 additions & 0 deletions ansible/deploy-to-node.yaml
@@ -284,6 +284,10 @@
INFERENCE_SERVER_API_KEY:
"{{ lookup('ansible.builtin.env', 'WEB_INFERENCE_SERVER_API_KEY') }}"
ENABLE_CHAT: "{{ lookup('ansible.builtin.env', 'WEB_ENABLE_CHAT') }}"
ENABLE_DRAFTS_FOR_PLUGINS:
"{{ lookup('ansible.builtin.env', 'WEB_ENABLE_DRAFTS_FOR_PLUGINS')}}"
NUM_GENERATED_DRAFTS:
"{{ lookup('ansible.builtin.env', 'WEB_NUM_GENERATED_DRAFTS') }}"
CURRENT_ANNOUNCEMENT:
"{{ lookup('ansible.builtin.env', 'WEB_CURRENT_ANNOUNCEMENT') }}"
ports:
2 changes: 2 additions & 0 deletions docker-compose.yaml
@@ -174,6 +174,8 @@ services:
- DEBUG_LOGIN=true
- INFERENCE_SERVER_HOST=http://inference-server:8000
- ENABLE_CHAT=true
- ENABLE_DRAFTS_FOR_PLUGINS=false
- NUM_GENERATED_DRAFTS=3
depends_on:
webdb:
condition: service_healthy
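
The two new compose variables gate the draft feature: ENABLE_DRAFTS_FOR_PLUGINS controls whether drafts are also generated when a plugin is active, and NUM_GENERATED_DRAFTS sets how many drafts are produced per regeneration (the commit list adds a shutoff for values <= 1). The website reads these through its own TypeScript config, which is not part of this diff; purely as an illustration of the intended semantics, a minimal Python sketch of parsing the same flags might look like this:

import os

def read_draft_settings() -> tuple[bool, bool, int]:
    """Illustrative only: interpret the draft-related flags added in this PR."""
    num_generated_drafts = int(os.getenv("NUM_GENERATED_DRAFTS", "3"))
    # With one draft (or fewer) there is nothing to choose between, so the
    # feature is effectively off; this mirrors the "NUM_GENERATED_DRAFTS <= 1
    # shutoff" commit above.
    drafts_enabled = num_generated_drafts > 1
    drafts_for_plugins = os.getenv("ENABLE_DRAFTS_FOR_PLUGINS", "false").lower() == "true"
    return drafts_enabled, drafts_for_plugins, num_generated_drafts
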
@@ -0,0 +1,55 @@
"""add_active_messsage_id_and_message_eval

Revision ID: 2d67fbdc5b46
Revises: 5b4211625a9f
Create Date: 2023-05-18 17:48:54.902294

"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "2d67fbdc5b46"
down_revision = "5b4211625a9f"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_evaluation",
sa.Column("inferior_message_ids", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("chat_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("selected_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(
["chat_id"],
["chat.id"],
),
sa.ForeignKeyConstraint(
["selected_message_id"],
["message.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_message_evaluation_chat_id"), "message_evaluation", ["chat_id"], unique=False)
op.create_index(op.f("ix_message_evaluation_user_id"), "message_evaluation", ["user_id"], unique=False)
op.add_column("chat", sa.Column("active_message_id", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat", "active_message_id")
op.drop_index(op.f("ix_message_evaluation_user_id"), table_name="message_evaluation")
op.drop_index(op.f("ix_message_evaluation_chat_id"), table_name="message_evaluation")
op.drop_table("message_evaluation")
# ### end Alembic commands ###
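
The migration above creates the message_evaluation table and adds the JSONB active_message_id column to chat. As a sketch only (assuming a standard Alembic setup with an alembic.ini for the inference server, which this diff does not show), the revision can be applied or rolled back programmatically:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config path; adjust for the actual deployment
command.upgrade(cfg, "2d67fbdc5b46")    # apply this revision
command.downgrade(cfg, "5b4211625a9f")  # roll back to the previous revision if needed
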
@@ -1,10 +1,11 @@
from .chat import DbChat, DbMessage, DbReport
from .chat import DbChat, DbMessage, DbMessageEval, DbReport
from .user import DbRefreshToken, DbUser
from .worker import DbWorker, DbWorkerComplianceCheck, DbWorkerEvent, WorkerEventType

__all__ = [
"DbChat",
"DbMessage",
"DbMessageEval",
"DbReport",
"DbRefreshToken",
"DbUser",
12 changes: 12 additions & 0 deletions inference/server/oasst_inference_server/models/chat.py
@@ -85,6 +85,7 @@ class DbChat(SQLModel, table=True):
title: str | None = Field(None)

messages: list[DbMessage] = Relationship(back_populates="chat")
active_message_id: str | None = Field(None, sa_column=sa.Column(pg.JSONB))

hidden: bool = Field(False, sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))

@@ -109,6 +110,7 @@ def to_read(self) -> chat_schema.ChatRead:
messages=[m.to_read() for m in self.messages],
hidden=self.hidden,
allow_data_use=self.allow_data_use,
active_message_id=self.active_message_id,
)

def get_msg_dict(self) -> dict[str, DbMessage]:
@@ -126,3 +128,13 @@ class DbReport(SQLModel, table=True):

def to_read(self) -> inference.Report:
return inference.Report(id=self.id, report_type=self.report_type, reason=self.reason)


class DbMessageEval(SQLModel, table=True):
__tablename__ = "message_evaluation"

id: str = Field(default_factory=uuid7str, primary_key=True)
chat_id: str = Field(..., foreign_key="chat.id", index=True)
user_id: str = Field(..., foreign_key="user.id", index=True)
selected_message_id: str = Field(..., foreign_key="message.id")
inferior_message_ids: list[str] = Field(default_factory=list, sa_column=sa.Column(pg.JSONB))
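
DbMessageEval records which draft the user kept (selected_message_id) and which sibling drafts were passed over (inferior_message_ids, stored as JSONB). A minimal construction sketch using only the fields defined above; the IDs are placeholders:

from oasst_inference_server import models

message_eval = models.DbMessageEval(
    chat_id="chat-123",                           # placeholder IDs for illustration
    user_id="user-456",
    selected_message_id="draft-a",                # the draft the user picked
    inferior_message_ids=["draft-b", "draft-c"],  # the drafts that were not picked
)
# On the server this object is created and persisted by
# UserChatRepository.add_message_eval(), which also checks that the chat
# belongs to the requesting user before committing.
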
17 changes: 17 additions & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -292,6 +292,22 @@ async def handle_create_vote(
return fastapi.Response(status_code=500)


@router.post("/{chat_id}/messages/{message_id}/message_evals")
async def handle_create_message_eval(
message_id: str,
inferior_message_request: chat_schema.MessageEvalRequest,
ucr: deps.UserChatRepository = fastapi.Depends(deps.create_user_chat_repository),
) -> fastapi.Response:
try:
await ucr.add_message_eval(
message_id=message_id, inferior_message_ids=inferior_message_request.inferior_message_ids
)
return fastapi.Response(status_code=200)
except Exception:
logger.exception("Error setting messages as inferior")
return fastapi.Response(status_code=500)


@router.post("/{chat_id}/messages/{message_id}/reports")
async def handle_create_report(
message_id: str,
@@ -322,6 +338,7 @@ async def handle_update_chat(
title=request.title,
hidden=request.hidden,
allow_data_use=request.allow_data_use,
active_message_id=request.active_message_id,
)
except Exception:
logger.exception("Error when updating chat")
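
The new message_evals route accepts the MessageEvalRequest defined in the schema change below and records which sibling drafts lost out to the selected message. A hedged client-side sketch; the host, route prefix, auth header, and IDs are assumptions, since the real website calls this endpoint through its own API layer:

import httpx

chat_id, message_id = "chat-123", "draft-a"  # placeholder IDs
response = httpx.post(
    f"http://localhost:8000/chats/{chat_id}/messages/{message_id}/message_evals",
    json={"inferior_message_ids": ["draft-b", "draft-c"]},
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed, not shown in this diff
)
response.raise_for_status()  # the handler returns 200 on success and 500 on error
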
6 changes: 6 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -56,6 +56,10 @@ class VoteRequest(pydantic.BaseModel):
score: int


class MessageEvalRequest(pydantic.BaseModel):
inferior_message_ids: list[str]


class ReportRequest(pydantic.BaseModel):
report_type: inference.ReportType
reason: str
@@ -72,6 +76,7 @@ class ChatListRead(pydantic.BaseModel):
title: str | None
hidden: bool = False
allow_data_use: bool = True
active_message_id: str | None


class ChatRead(ChatListRead):
@@ -100,3 +105,4 @@ class ChatUpdateRequest(pydantic.BaseModel):
title: pydantic.constr(max_length=100) | None = None
hidden: bool | None = None
allow_data_use: bool | None = None
active_message_id: str | None = None
35 changes: 27 additions & 8 deletions inference/server/oasst_inference_server/user_chat_repository.py
@@ -139,13 +139,7 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
if msg_dict[parent_id].state != inference.MessageState.complete:
raise fastapi.HTTPException(status_code=400, detail="Parent message is not complete")

message = models.DbMessage(
role="prompter",
chat_id=chat_id,
chat=chat,
parent_id=parent_id,
content=content,
)
message = models.DbMessage(role="prompter", chat_id=chat_id, chat=chat, parent_id=parent_id, content=content)
self.session.add(message)
chat.modified_at = message.created_at

@@ -176,6 +170,7 @@ async def initiate_assistant_message(
.where(
models.DbMessage.role == "assistant",
models.DbMessage.state == inference.MessageState.pending,
models.DbMessage.parent_id != parent_id, # Prevent draft messages from cancelling each other
)
.join(models.DbChat)
.where(
@@ -257,6 +252,25 @@ async def update_score(self, message_id: str, score: int) -> models.DbMessage:
await self.session.commit()
return message

async def add_message_eval(self, message_id: str, inferior_message_ids: list[str]):
logger.info(f"Adding message evaluation to {message_id=}: {inferior_message_ids=}")
query = (
sqlmodel.select(models.DbMessage)
.options(sqlalchemy.orm.selectinload(models.DbMessage.chat))
.where(models.DbMessage.id == message_id)
)
message: models.DbMessage = (await self.session.exec(query)).one()
if message.chat.user_id != self.user_id:
raise fastapi.HTTPException(status_code=400, detail="Message not found")
message_eval = models.DbMessageEval(
chat_id=message.chat_id,
user_id=message.chat.user_id,
selected_message_id=message.id,
inferior_message_ids=inferior_message_ids,
)
self.session.add(message_eval)
await self.session.commit()

async def add_report(self, message_id: str, reason: str, report_type: inference.ReportType) -> models.DbReport:
logger.info(f"Adding report to {message_id=}: {reason=}")
query = (
@@ -282,8 +296,9 @@ async def update_chat(
title: str | None = None,
hidden: bool | None = None,
allow_data_use: bool | None = None,
active_message_id: str | None = None,
) -> None:
logger.info(f"Updating chat {chat_id=}: {title=} {hidden=}")
logger.info(f"Updating chat {chat_id=}: {title=} {hidden=} {active_message_id=}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)

if title is not None:
@@ -298,4 +313,8 @@
logger.info(f"Updating allow_data_use of chat {chat_id=}: {allow_data_use=}")
chat.allow_data_use = allow_data_use

if active_message_id is not None:
logger.info(f"Updating active_message_id of chat {chat_id=}: {active_message_id=}")
chat.active_message_id = active_message_id

await self.session.commit()
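
Taken together, the two repository changes cover the draft flow: add_message_eval stores the user's pick along with the rejected drafts, and the new active_message_id branch in update_chat remembers which thread tail the chat was left on so it can be restored on the next visit. A rough usage sketch under assumed setup, with ucr being a UserChatRepository for the signed-in user and hypothetical IDs:

# When the user picks one of several regenerated drafts:
await ucr.add_message_eval(
    message_id="draft-a",                         # the draft kept as the assistant reply
    inferior_message_ids=["draft-b", "draft-c"],  # its rejected siblings
)

# Persist which message thread the chat should reopen on.
await ucr.update_chat(chat_id="chat-123", active_message_id="draft-a")
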
2 changes: 2 additions & 0 deletions website/.env
@@ -25,3 +25,5 @@ ENABLE_EMAIL_SIGNIN=true
INFERENCE_SERVER_HOST="http://localhost:8000"
INFERENCE_SERVER_API_KEY="6969"
ENABLE_CHAT=true
ENABLE_DRAFTS_FOR_PLUGINS=false
NUM_GENERATED_DRAFTS=3
1 change: 1 addition & 0 deletions website/public/locales/en/chat.json
@@ -6,6 +6,7 @@
"delete_chat": "Delete chat",
"delete_confirmation": "Are you sure you want to delete this chat?",
"delete_confirmation_detail": "If you delete this chat, it won't be part of our data, and we won't be able to use it to improve our models. Please take the time to upvote and downvote responses in other chats to help us make Open Assistant better!",
"draft": "Draft",
"edit_plugin": "Edit Plugin",
"empty": "Untitled",
"input_placeholder": "Ask the assistant anything",
2 changes: 2 additions & 0 deletions website/public/locales/en/common.json
@@ -43,6 +43,8 @@
"review": "Review",
"save": "Save",
"send": "Send",
"show_less": "Show Less",
"show_more": "Show More",
"sign_in": "Sign In",
"sign_out": "Sign Out",
"skip": "Skip",
94 changes: 94 additions & 0 deletions website/src/components/Chat/ChatAssistantDraftPager.tsx
@@ -0,0 +1,94 @@
import { Box, Flex } from "@chakra-ui/react";
import { ChevronLeft, ChevronRight, RotateCcw } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { InferenceMessage } from "src/types/Chat";

import { BaseMessageEmojiButton } from "../Messages/MessageEmojiButton";
import { MessageInlineEmojiRow } from "../Messages/MessageInlineEmojiRow";
import { ChatAssistantDraftViewer } from "./ChatAssistantDraftViewer";

export type DraftPickedParams = { chatId: string; regenIndex: number; messageIndex: number };

type OnDraftPickedFn = (params: DraftPickedParams) => void;
type OnRetryFn = (params: { parentId: string; chatId: string }) => void;

type ChatAssistantDraftPagerProps = {
chatId: string;
streamedDrafts: string[];
draftMessageRegens: InferenceMessage[][];
onDraftPicked: OnDraftPickedFn;
onRetry: OnRetryFn;
};

export const ChatAssistantDraftPager = ({
chatId,
streamedDrafts,
draftMessageRegens,
onDraftPicked,
onRetry,
}: ChatAssistantDraftPagerProps) => {
const [isComplete, setIsComplete] = useState(false);
const [regenIndex, setRegenIndex] = useState<number>(0);

useEffect(() => {
const allMessagesComplete =
regenIndex < draftMessageRegens.length &&
draftMessageRegens[regenIndex]?.length !== 0 &&
draftMessageRegens[regenIndex].every((message) =>
["complete", "aborted_by_worker", "cancelled", "timeout"].includes(message.state)
);
setIsComplete(allMessagesComplete);
}, [regenIndex, draftMessageRegens]);

const handlePrevious = useCallback(() => {
setRegenIndex(regenIndex > 0 ? regenIndex - 1 : regenIndex);
}, [setRegenIndex, regenIndex]);

const handleNext = useCallback(() => {
setRegenIndex(regenIndex < draftMessageRegens.length - 1 ? regenIndex + 1 : regenIndex);
}, [setRegenIndex, regenIndex, draftMessageRegens]);

const handleRetry = useCallback(() => {
if (onRetry && regenIndex < draftMessageRegens.length && draftMessageRegens[regenIndex]?.length !== 0) {
onRetry({
parentId: draftMessageRegens[regenIndex][0].parent_id,
chatId: draftMessageRegens[regenIndex][0].chat_id,
});
setRegenIndex(draftMessageRegens.length);
}
}, [onRetry, regenIndex, draftMessageRegens, setRegenIndex]);

const handleDraftPicked = useCallback(
(messageIndex: number) => {
onDraftPicked({ chatId, regenIndex, messageIndex });
},
[chatId, regenIndex, onDraftPicked]
);

return (
<ChatAssistantDraftViewer
streamedDrafts={streamedDrafts}
isComplete={isComplete}
draftMessages={draftMessageRegens[Math.min(draftMessageRegens.length - 1, regenIndex)]}
onDraftPicked={handleDraftPicked}
pager={
isComplete ? (
<Flex justifyContent={"space-between"}>
<>
<MessageInlineEmojiRow gap="0.5">
<BaseMessageEmojiButton emoji={ChevronLeft} onClick={handlePrevious} isDisabled={regenIndex === 0} />
<Box fontSize="xs">{`${regenIndex + 1}/${draftMessageRegens.length}`}</Box>
<BaseMessageEmojiButton
emoji={ChevronRight}
onClick={handleNext}
isDisabled={regenIndex === draftMessageRegens.length - 1}
/>
</MessageInlineEmojiRow>
</>
<BaseMessageEmojiButton emoji={RotateCcw} onClick={handleRetry} />
</Flex>
) : undefined
}
/>
);
};