Message drafts (#3044)
Closes #2931 (the goal changed slightly based on advice from the Discord: generate full messages rather than 'x' tokens, since full messages are more useful data).

- [x] Create draft selection UI
- [x] Draft inference
- [x] Option to regenerate drafts and serve 3 new ones
- [x] Remember last viewed sibling message
- [x] Store selected draft training data for RLHF
- ~~[ ] Disable drafts when the queue is too long / the server is under load~~
  (suggested in the Discord to leave this for a follow-up PR)
- [x] Draft markdown rendering
- [x] 'Used plugin' UI for drafts
- [x] Resolve merge conflicts

---------

Co-authored-by: notmd <[email protected]>
Co-authored-by: notmd <[email protected]>
3 people authored May 31, 2023
1 parent 6e593ec commit 70f30a6
Showing 24 changed files with 685 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/deploy-to-node.yaml
@@ -81,6 +81,8 @@ jobs:
BACKEND_CORS_ORIGINS: ${{ vars.BACKEND_CORS_ORIGINS }}
WEB_INFERENCE_SERVER_HOST: ${{ vars.WEB_INFERENCE_SERVER_HOST }}
WEB_ENABLE_CHAT: ${{ vars.WEB_ENABLE_CHAT }}
WEB_ENABLE_DRAFTS_WITH_PLUGINS: ${{ vars.WEB_ENABLE_DRAFTS_WITH_PLUGINS }}
WEB_NUM_GENERATED_DRAFTS: ${{ vars.WEB_NUM_GENERATED_DRAFTS }}
WEB_CURRENT_ANNOUNCEMENT: ${{ vars.WEB_CURRENT_ANNOUNCEMENT }}
WEB_INFERENCE_SERVER_API_KEY: ${{secrets.WEB_INFERENCE_SERVER_API_KEY}}
INFERENCE_POSTGRES_PASSWORD: ${{secrets.INFERENCE_POSTGRES_PASSWORD}}
5 changes: 5 additions & 0 deletions ansible/deploy-to-node.yaml
@@ -284,6 +284,11 @@
INFERENCE_SERVER_API_KEY:
"{{ lookup('ansible.builtin.env', 'WEB_INFERENCE_SERVER_API_KEY') }}"
ENABLE_CHAT: "{{ lookup('ansible.builtin.env', 'WEB_ENABLE_CHAT') }}"
ENABLE_DRAFTS_WITH_PLUGINS:
"{{ lookup('ansible.builtin.env',
'WEB_ENABLE_DRAFTS_WITH_PLUGINS')}}"
NUM_GENERATED_DRAFTS:
"{{ lookup('ansible.builtin.env', 'WEB_NUM_GENERATED_DRAFTS') }}"
CURRENT_ANNOUNCEMENT:
"{{ lookup('ansible.builtin.env', 'WEB_CURRENT_ANNOUNCEMENT') }}"
ports:
2 changes: 2 additions & 0 deletions docker-compose.yaml
@@ -174,6 +174,8 @@ services:
- DEBUG_LOGIN=true
- INFERENCE_SERVER_HOST=http://inference-server:8000
- ENABLE_CHAT=true
- ENABLE_DRAFTS_WITH_PLUGINS=false
- NUM_GENERATED_DRAFTS=3
depends_on:
webdb:
condition: service_healthy
@@ -0,0 +1,55 @@
"""add_active_thread_tail_messsage_id_and_message_eval
Revision ID: 5ed411a331f4
Revises: 5b4211625a9f
Create Date: 2023-05-29 15:51:41.857262
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "5ed411a331f4"
down_revision = "5b4211625a9f"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_evaluation",
sa.Column("inferior_message_ids", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("chat_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("selected_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(
["chat_id"],
["chat.id"],
),
sa.ForeignKeyConstraint(
["selected_message_id"],
["message.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_message_evaluation_chat_id"), "message_evaluation", ["chat_id"], unique=False)
op.create_index(op.f("ix_message_evaluation_user_id"), "message_evaluation", ["user_id"], unique=False)
op.add_column("chat", sa.Column("active_thread_tail_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat", "active_thread_tail_message_id")
op.drop_index(op.f("ix_message_evaluation_user_id"), table_name="message_evaluation")
op.drop_index(op.f("ix_message_evaluation_chat_id"), table_name="message_evaluation")
op.drop_table("message_evaluation")
# ### end Alembic commands ###
3 changes: 2 additions & 1 deletion inference/server/oasst_inference_server/models/__init__.py
@@ -1,10 +1,11 @@
from .chat import DbChat, DbMessage, DbReport
from .chat import DbChat, DbMessage, DbMessageEval, DbReport
from .user import DbRefreshToken, DbUser
from .worker import DbWorker, DbWorkerComplianceCheck, DbWorkerEvent, WorkerEventType

__all__ = [
"DbChat",
"DbMessage",
"DbMessageEval",
"DbReport",
"DbRefreshToken",
"DbUser",
12 changes: 12 additions & 0 deletions inference/server/oasst_inference_server/models/chat.py
@@ -85,6 +85,7 @@ class DbChat(SQLModel, table=True):
title: str | None = Field(None)

messages: list[DbMessage] = Relationship(back_populates="chat")
active_thread_tail_message_id: str | None = Field(None)

hidden: bool = Field(False, sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))

@@ -109,6 +110,7 @@ def to_read(self) -> chat_schema.ChatRead:
messages=[m.to_read() for m in self.messages],
hidden=self.hidden,
allow_data_use=self.allow_data_use,
active_thread_tail_message_id=self.active_thread_tail_message_id,
)

def get_msg_dict(self) -> dict[str, DbMessage]:
@@ -126,3 +128,13 @@ class DbReport(SQLModel, table=True):

def to_read(self) -> inference.Report:
return inference.Report(id=self.id, report_type=self.report_type, reason=self.reason)


class DbMessageEval(SQLModel, table=True):
__tablename__ = "message_evaluation"

id: str = Field(default_factory=uuid7str, primary_key=True)
chat_id: str = Field(..., foreign_key="chat.id", index=True)
user_id: str = Field(..., foreign_key="user.id", index=True)
selected_message_id: str = Field(..., foreign_key="message.id")
inferior_message_ids: list[str] = Field(default_factory=list, sa_column=sa.Column(pg.JSONB))
17 changes: 17 additions & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -303,6 +303,22 @@ async def handle_create_vote(
return fastapi.Response(status_code=500)


@router.post("/{chat_id}/messages/{message_id}/message_evals")
async def handle_create_message_eval(
message_id: str,
inferior_message_request: chat_schema.MessageEvalRequest,
ucr: deps.UserChatRepository = fastapi.Depends(deps.create_user_chat_repository),
) -> fastapi.Response:
try:
await ucr.add_message_eval(
message_id=message_id, inferior_message_ids=inferior_message_request.inferior_message_ids
)
return fastapi.Response(status_code=200)
except Exception:
logger.exception("Error setting messages as inferior")
return fastapi.Response(status_code=500)


@router.post("/{chat_id}/messages/{message_id}/reports")
async def handle_create_report(
message_id: str,
@@ -333,6 +349,7 @@ async def handle_update_chat(
title=request.title,
hidden=request.hidden,
allow_data_use=request.allow_data_use,
active_thread_tail_message_id=request.active_thread_tail_message_id,
)
except Exception:
logger.exception("Error when updating chat")
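When the user picks one of the generated drafts, the website can report the rejected siblings through the `message_evals` route added above. A minimal sketch of that call, assuming the chat router is mounted under a `/chats` prefix and bearer-token auth (both assumptions); the path segments and the `inferior_message_ids` field come from the route and the `MessageEvalRequest` schema in this diff:

```typescript
// Sketch only: base URL, the /chats prefix, and auth handling are assumptions.
async function submitMessageEval(
  baseUrl: string,
  accessToken: string,
  chatId: string,
  selectedMessageId: string,
  inferiorMessageIds: string[]
): Promise<void> {
  const res = await fetch(
    `${baseUrl}/chats/${chatId}/messages/${selectedMessageId}/message_evals`,
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${accessToken}`,
      },
      // Field name matches MessageEvalRequest.inferior_message_ids
      body: JSON.stringify({ inferior_message_ids: inferiorMessageIds }),
    }
  );
  if (!res.ok) {
    throw new Error(`message_evals request failed with status ${res.status}`);
  }
}
```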
6 changes: 6 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -71,6 +71,10 @@ class VoteRequest(pydantic.BaseModel):
score: int


class MessageEvalRequest(pydantic.BaseModel):
inferior_message_ids: list[str]


class ReportRequest(pydantic.BaseModel):
report_type: inference.ReportType
reason: str
@@ -87,6 +91,7 @@ class ChatListRead(pydantic.BaseModel):
title: str | None
hidden: bool = False
allow_data_use: bool = True
active_thread_tail_message_id: str | None


class ChatRead(ChatListRead):
@@ -115,3 +120,4 @@ class ChatUpdateRequest(pydantic.BaseModel):
title: pydantic.constr(max_length=100) | None = None
hidden: bool | None = None
allow_data_use: bool | None = None
active_thread_tail_message_id: str | None = None
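The new `active_thread_tail_message_id` field on `ChatUpdateRequest` is what lets the website remember the last viewed sibling: after the user pages to a different draft, the chat is updated with the id of the message at the tail of the active thread. A minimal sketch of the payload shape (the exact update route is not shown in this diff; field names follow `ChatUpdateRequest`):

```typescript
// Payload shape mirroring ChatUpdateRequest; all fields are optional server-side.
type ChatUpdatePayload = {
  title?: string; // constrained to 100 characters on the server
  hidden?: boolean;
  allow_data_use?: boolean;
  active_thread_tail_message_id?: string;
};

// Hypothetical example: persist which sibling message the user last had in view.
const rememberActiveThread: ChatUpdatePayload = {
  active_thread_tail_message_id: "some-message-id",
};
```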
35 changes: 27 additions & 8 deletions inference/server/oasst_inference_server/user_chat_repository.py
@@ -139,13 +139,7 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
if msg_dict[parent_id].state != inference.MessageState.complete:
raise fastapi.HTTPException(status_code=400, detail="Parent message is not complete")

message = models.DbMessage(
role="prompter",
chat_id=chat_id,
chat=chat,
parent_id=parent_id,
content=content,
)
message = models.DbMessage(role="prompter", chat_id=chat_id, chat=chat, parent_id=parent_id, content=content)
self.session.add(message)
chat.modified_at = message.created_at

@@ -176,6 +170,7 @@ async def initiate_assistant_message(
.where(
models.DbMessage.role == "assistant",
models.DbMessage.state == inference.MessageState.pending,
models.DbMessage.parent_id != parent_id, # Prevent draft messages from cancelling each other
)
.join(models.DbChat)
.where(
@@ -257,6 +252,25 @@ async def update_score(self, message_id: str, score: int) -> models.DbMessage:
await self.session.commit()
return message

async def add_message_eval(self, message_id: str, inferior_message_ids: list[str]):
logger.info(f"Adding message evaluation to {message_id=}: {inferior_message_ids=}")
query = (
sqlmodel.select(models.DbMessage)
.options(sqlalchemy.orm.selectinload(models.DbMessage.chat))
.where(models.DbMessage.id == message_id)
)
message: models.DbMessage = (await self.session.exec(query)).one()
if message.chat.user_id != self.user_id:
raise fastapi.HTTPException(status_code=400, detail="Message not found")
message_eval = models.DbMessageEval(
chat_id=message.chat_id,
user_id=message.chat.user_id,
selected_message_id=message.id,
inferior_message_ids=inferior_message_ids,
)
self.session.add(message_eval)
await self.session.commit()

async def add_report(self, message_id: str, reason: str, report_type: inference.ReportType) -> models.DbReport:
logger.info(f"Adding report to {message_id=}: {reason=}")
query = (
@@ -282,8 +296,9 @@ async def update_chat(
title: str | None = None,
hidden: bool | None = None,
allow_data_use: bool | None = None,
active_thread_tail_message_id: str | None = None,
) -> None:
logger.info(f"Updating chat {chat_id=}: {title=} {hidden=}")
logger.info(f"Updating chat {chat_id=}: {title=} {hidden=} {active_thread_tail_message_id=}")
chat = await self.get_chat_by_id(chat_id=chat_id, include_messages=False)

if title is not None:
@@ -298,4 +313,8 @@ logger.info(f"Updating allow_data_use of chat {chat_id=}: {allow_data_use=}")
logger.info(f"Updating allow_data_use of chat {chat_id=}: {allow_data_use=}")
chat.allow_data_use = allow_data_use

if active_thread_tail_message_id is not None:
logger.info(f"Updating active_thread_tail_message_id of chat {chat_id=}: {active_thread_tail_message_id=}")
chat.active_thread_tail_message_id = active_thread_tail_message_id

await self.session.commit()
2 changes: 2 additions & 0 deletions website/.env
@@ -25,3 +25,5 @@ ENABLE_EMAIL_SIGNIN=true
INFERENCE_SERVER_HOST="http://localhost:8000"
INFERENCE_SERVER_API_KEY="6969"
ENABLE_CHAT=true
ENABLE_DRAFTS_FOR_PLUGINS=false
NUM_GENERATED_DRAFTS=3
3 changes: 3 additions & 0 deletions website/public/locales/en/chat.json
@@ -6,6 +6,8 @@
"delete_chat": "Delete chat",
"delete_confirmation": "Are you sure you want to delete this chat?",
"delete_confirmation_detail": "If you delete this chat, it won't be part of our data, and we won't be able to use it to improve our models. Please take the time to upvote and downvote responses in other chats to help us make Open Assistant better!",
"draft": "Draft",
"drafts_generating_notify": "Draft messages are still generating. Please wait.",
"edit_plugin": "Edit Plugin",
"empty": "Untitled",
"input_placeholder": "Ask the assistant anything",
@@ -36,6 +38,7 @@
"queue_info": "Your message is queued, you are at position {{ queuePosition, number, integer }} in the queue.",
"remove_plugin": "Remove Plugin",
"repetition_penalty": "Repetition penalty",
"select_chat_notify": "Please select a draft to continue.",
"sponsored_by": "Sponsored By",
"temperature": "Temperature",
"top_k": "Top K",
2 changes: 2 additions & 0 deletions website/public/locales/en/common.json
@@ -43,6 +43,8 @@
"review": "Review",
"save": "Save",
"send": "Send",
"show_less": "Show Less",
"show_more": "Show More",
"sign_in": "Sign In",
"sign_out": "Sign Out",
"skip": "Skip",
94 changes: 94 additions & 0 deletions website/src/components/Chat/ChatAssistantDraftPager.tsx
@@ -0,0 +1,94 @@
import { Box, Flex } from "@chakra-ui/react";
import { ChevronLeft, ChevronRight, RotateCcw } from "lucide-react";
import { useCallback, useEffect, useState } from "react";
import { InferenceMessage } from "src/types/Chat";

import { BaseMessageEmojiButton } from "../Messages/MessageEmojiButton";
import { MessageInlineEmojiRow } from "../Messages/MessageInlineEmojiRow";
import { ChatAssistantDraftViewer } from "./ChatAssistantDraftViewer";

export type DraftPickedParams = { chatId: string; regenIndex: number; messageIndex: number };

type OnDraftPickedFn = (params: DraftPickedParams) => void;
type OnRetryFn = (params: { parentId: string; chatId: string }) => void;

type ChatAssistantDraftPagerProps = {
chatId: string;
streamedDrafts: string[];
draftMessageRegens: InferenceMessage[][];
onDraftPicked: OnDraftPickedFn;
onRetry: OnRetryFn;
};

export const ChatAssistantDraftPager = ({
chatId,
streamedDrafts,
draftMessageRegens,
onDraftPicked,
onRetry,
}: ChatAssistantDraftPagerProps) => {
const [isComplete, setIsComplete] = useState(false);
const [regenIndex, setRegenIndex] = useState<number>(0);

useEffect(() => {
const allMessagesComplete =
regenIndex < draftMessageRegens.length &&
draftMessageRegens[regenIndex]?.length !== 0 &&
draftMessageRegens[regenIndex].every((message) =>
["complete", "aborted_by_worker", "cancelled", "timeout"].includes(message.state)
);
setIsComplete(allMessagesComplete);
}, [regenIndex, draftMessageRegens]);

const handlePrevious = useCallback(() => {
setRegenIndex(regenIndex > 0 ? regenIndex - 1 : regenIndex);
}, [setRegenIndex, regenIndex]);

const handleNext = useCallback(() => {
setRegenIndex(regenIndex < draftMessageRegens.length - 1 ? regenIndex + 1 : regenIndex);
}, [setRegenIndex, regenIndex, draftMessageRegens]);

const handleRetry = useCallback(() => {
if (onRetry && regenIndex < draftMessageRegens.length && draftMessageRegens[regenIndex]?.length !== 0) {
onRetry({
parentId: draftMessageRegens[regenIndex][0].parent_id,
chatId: draftMessageRegens[regenIndex][0].chat_id,
});
setRegenIndex(draftMessageRegens.length);
}
}, [onRetry, regenIndex, draftMessageRegens, setRegenIndex]);

const handleDraftPicked = useCallback(
(messageIndex: number) => {
onDraftPicked({ chatId, regenIndex, messageIndex });
},
[chatId, regenIndex, onDraftPicked]
);

return (
<ChatAssistantDraftViewer
streamedDrafts={streamedDrafts}
isComplete={isComplete}
draftMessages={draftMessageRegens[Math.min(draftMessageRegens.length - 1, regenIndex)]}
onDraftPicked={handleDraftPicked}
pager={
isComplete ? (
<Flex justifyContent={"space-between"}>
<>
<MessageInlineEmojiRow gap="0.5">
<BaseMessageEmojiButton emoji={ChevronLeft} onClick={handlePrevious} isDisabled={regenIndex === 0} />
<Box fontSize="xs">{`${regenIndex + 1}/${draftMessageRegens.length}`}</Box>
<BaseMessageEmojiButton
emoji={ChevronRight}
onClick={handleNext}
isDisabled={regenIndex === draftMessageRegens.length - 1}
/>
</MessageInlineEmojiRow>
</>
<BaseMessageEmojiButton emoji={RotateCcw} onClick={handleRetry} />
</Flex>
) : undefined
}
/>
);
};
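Below is a minimal sketch of how a parent chat component might render the pager. The prop names and types come from `ChatAssistantDraftPagerProps` above; the wrapper component and the handler bodies are assumptions:

```tsx
import { InferenceMessage } from "src/types/Chat";

import { ChatAssistantDraftPager } from "./ChatAssistantDraftPager";

type DraftSectionProps = {
  chatId: string;
  streamedDrafts: string[];
  draftMessageRegens: InferenceMessage[][];
};

// Hypothetical wrapper: wires the pager to the chat's draft state.
export const DraftSection = ({ chatId, streamedDrafts, draftMessageRegens }: DraftSectionProps) => (
  <ChatAssistantDraftPager
    chatId={chatId}
    streamedDrafts={streamedDrafts}
    draftMessageRegens={draftMessageRegens}
    onDraftPicked={({ regenIndex, messageIndex }) => {
      // e.g. promote the chosen draft to the active thread and submit the
      // sibling drafts as inferior via the message_evals endpoint
      console.debug("draft picked", { regenIndex, messageIndex });
    }}
    onRetry={({ parentId }) => {
      // e.g. request a fresh set of drafts for the same parent message
      console.debug("retrying drafts for", parentId);
    }}
  />
);
```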