Skip to content

Commit

Permalink
Add ability to add starter messages
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Mar 3, 2024
1 parent 9051ebf commit a8cc3d5
Show file tree
Hide file tree
Showing 14 changed files with 578 additions and 252 deletions.
31 changes: 31 additions & 0 deletions backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Add starter prompts
Revision ID: 0a2b51deb0b8
Revises: 5f4b8568a221
Create Date: 2024-03-02 23:23:49.960309
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "0a2b51deb0b8"
down_revision = "5f4b8568a221"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"persona",
sa.Column(
"starter_messages",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)


def downgrade() -> None:
op.drop_column("persona", "starter_messages")
1 change: 1 addition & 0 deletions backend/danswer/chat/load_yamls.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def load_personas_from_yaml(
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from danswer.db.models import Prompt
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
Expand Down Expand Up @@ -465,6 +466,7 @@ def upsert_persona(
prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
shared: bool,
db_session: Session,
persona_id: int | None = None,
Expand All @@ -490,6 +492,7 @@ def upsert_persona(
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted

# Do not delete any associations manually added unless
Expand All @@ -516,6 +519,7 @@ def upsert_persona(
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
)
db_session.add(persona)

Expand Down
12 changes: 12 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,15 @@ class Prompt(Base):
)


class StarterMessage(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""

name: str
description: str
message: str


class Persona(Base):
__tablename__ = "persona"

Expand Down Expand Up @@ -744,6 +753,9 @@ class Persona(Base):
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/db/slack_bot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def create_slack_bot_persona(
prompts=None,
document_sets=document_sets,
llm_model_version_override=None,
starter_messages=None,
shared=True,
default_persona=False,
db_session=db_session,
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/server/features/persona/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def create_update_persona(
prompts=prompts,
document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
shared=create_persona_request.shared,
db_session=db_session,
)
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/server/features/persona/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel

from danswer.db.models import Persona
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot
Expand All @@ -17,6 +18,7 @@ class CreatePersonaRequest(BaseModel):
prompt_ids: list[int]
document_set_ids: list[int]
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None


class PersonaSnapshot(BaseModel):
Expand All @@ -30,6 +32,7 @@ class PersonaSnapshot(BaseModel):
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
default_persona: bool
prompts: list[PromptSnapshot]
document_sets: list[DocumentSet]
Expand All @@ -50,6 +53,7 @@ def from_model(cls, persona: Persona) -> "PersonaSnapshot":
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
default_persona=persona.default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
document_sets=[
Expand Down
50 changes: 50 additions & 0 deletions web/src/app/admin/personas/HidableSection.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { useState } from "react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";

export function SectionHeader({
children,
includeMargin = true,
}: {
children: string | JSX.Element;
includeMargin?: boolean;
}) {
return (
<div
className={"font-bold text-xl my-auto" + (includeMargin ? " mb-4" : "")}
>
{children}
</div>
);
}

export function HidableSection({
children,
sectionTitle,
defaultHidden = false,
}: {
children: string | JSX.Element;
sectionTitle: string | JSX.Element;
defaultHidden?: boolean;
}) {
const [isHidden, setIsHidden] = useState(defaultHidden);

return (
<div>
<div
className="flex hover:bg-hover-light rounded cursor-pointer p-2"
onClick={() => setIsHidden(!isHidden)}
>
<SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>
<div className="my-auto ml-auto p-1">
{isHidden ? (
<FiChevronRight size={24} />
) : (
<FiChevronDown size={24} />
)}
</div>
</div>

{!isHidden && <div className="mx-2 mt-2">{children}</div>}
</div>
);
}
Loading

0 comments on commit a8cc3d5

Please sign in to comment.