
Commit

Merge branch 'danswer-ai:main' into main
colachg authored Feb 19, 2024
2 parents 36e7242 + c1d1651 commit 349f90f
Showing 32 changed files with 832 additions and 244 deletions.
33 changes: 0 additions & 33 deletions backend/alembic/versions/dbaa756c2ccf_embedding_models.py
@@ -7,13 +7,7 @@
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean

from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.db.models import IndexModelStatus

# revision identifiers, used by Alembic.
@@ -40,33 +34,6 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
EmbeddingModel = table(
"embedding_model",
column("id", Integer),
column("model_name", String),
column("model_dim", Integer),
column("normalize", Boolean),
column("query_prefix", String),
column("passage_prefix", String),
column("index_name", String),
column(
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
)
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": DOCUMENT_ENCODER_MODEL,
"model_dim": DOC_EMBEDDING_DIM,
"normalize": NORMALIZE_EMBEDDINGS,
"query_prefix": ASYM_QUERY_PREFIX,
"passage_prefix": ASYM_PASSAGE_PREFIX,
"index_name": "danswer_chunk",
"status": IndexModelStatus.PRESENT,
}
],
)
op.add_column(
"index_attempt",
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
1 change: 1 addition & 0 deletions backend/danswer/background/indexing/run_indexing.py
@@ -185,6 +185,7 @@ def _run_indexing(

db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to the user manually disabling it or a model swap
raise RuntimeError("Index Attempt was canceled")

logger.debug(
17 changes: 16 additions & 1 deletion backend/danswer/background/update.py
@@ -14,6 +14,7 @@
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
@@ -69,8 +70,15 @@ def _should_create_new_indexing(
connector: Connector,
last_index: IndexAttempt | None,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
# The user can still manually create single indexing attempts via the UI for the
# currently in-use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False

# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
@@ -186,7 +194,11 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector, last_attempt, model, db_session
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue

@@ -255,6 +267,9 @@ def cleanup_indexing_jobs(
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If the index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in the last `timeout_hours` hours; if not,
# assume it to be frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
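
The new `DISABLE_INDEX_UPDATE_ON_SWAP` gate in `_should_create_new_indexing` can be condensed as follows. This is a minimal sketch of just the added branch, not the full function; the helper name is hypothetical, and `IndexModelStatus.PAST` is assumed (only `PRESENT` and `FUTURE` appear in this diff).

from enum import Enum


class IndexModelStatus(str, Enum):
    # Mirrors danswer.db.models.IndexModelStatus; PAST is an assumption
    PAST = "PAST"
    PRESENT = "PRESENT"
    FUTURE = "FUTURE"


def should_skip_scheduled_index(
    disable_on_swap: bool,
    model_status: IndexModelStatus,
    secondary_index_building: bool,
) -> bool:
    """While a secondary (FUTURE) index is building and the kill switch is
    set, skip scheduled indexing runs for the live (PRESENT) model."""
    return (
        disable_on_swap
        and model_status is IndexModelStatus.PRESENT
        and secondary_index_building
    )


# During a swap with the flag enabled, the primary index is paused...
assert should_skip_scheduled_index(True, IndexModelStatus.PRESENT, True)
# ...but the secondary (FUTURE) index still gets built.
assert not should_skip_scheduled_index(True, IndexModelStatus.FUTURE, True)
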
56 changes: 35 additions & 21 deletions backend/danswer/chat/process_message.py
@@ -59,6 +59,7 @@
from danswer.search.search_runner import inference_documents_from_ids
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
@@ -153,8 +154,7 @@ def translate_citations(
return citation_to_saved_doc_id_map


@log_generator_function_time()
def stream_chat_message(
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
@@ -164,7 +164,14 @@ def stream_chat_message(
# For flows with search, don't include as many chunks as possible since we need to leave space
# for the chat history; with smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> Iterator[str]:
) -> Iterator[
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| ChatMessageDetail
| DanswerAnswerPiece
| CitationInfo
]:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
@@ -313,10 +320,8 @@ def stream_chat_message(
# only allow the final document to get truncated
# if more than that, then the user message is too long
if final_doc_ind != len(tokens_per_doc) - 1:
yield get_json_line(
StreamingError(
error="LLM context window exceeded. Please de-select some documents or shorten your query."
).dict()
yield StreamingError(
error="LLM context window exceeded. Please de-select some documents or shorten your query."
)
return

@@ -417,8 +422,8 @@ def stream_chat_message(
applied_source_filters=retrieval_request.filters.source_type,
applied_time_cutoff=time_cutoff,
recency_bias_multiplier=recency_bias_multiplier,
).dict()
yield get_json_line(initial_response)
)
yield initial_response

# Get the final ordering of chunks for the LLM call
llm_chunk_selection = cast(list[bool], next(documents_generator))
@@ -430,8 +435,8 @@ def stream_chat_message(
]
if run_llm_chunk_filter
else []
).dict()
yield get_json_line(llm_relevance_filtering_response)
)
yield llm_relevance_filtering_response

# Prep chunks to pass to LLM
num_llm_chunks = (
@@ -497,7 +502,7 @@ def stream_chat_message(
gen_ai_response_message
)

yield get_json_line(msg_detail_response.dict())
yield msg_detail_response

# Stop here after saving message details; the above still needs to be sent so the
# message id is available for the next follow-up message
@@ -530,17 +535,13 @@ def stream_chat_message(
citations.append(packet)
continue

yield get_json_line(packet.dict())
yield packet
except Exception as e:
logger.exception(e)

# The frontend will erase whatever answer was streamed and show this instead
# This will be the issue 99% of the time
error_packet = StreamingError(
error="LLM failed to respond, have you set your API key?"
)

yield get_json_line(error_packet.dict())
yield StreamingError(error="LLM failed to respond, have you set your API key?")
return

# Post-LLM answer processing
@@ -564,11 +565,24 @@ def stream_chat_message(
gen_ai_response_message
)

yield get_json_line(msg_detail_response.dict())
yield msg_detail_response
except Exception as e:
logger.exception(e)

# The frontend will erase whatever answer was streamed and show this instead
error_packet = StreamingError(error="Failed to parse LLM output")
yield StreamingError(error="Failed to parse LLM output")

yield get_json_line(error_packet.dict())

@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
) -> Iterator[str]:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.dict())
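
The refactor above splits the old streaming function in two: `stream_chat_message_objects` yields typed Pydantic objects, and `stream_chat_message` serializes each one as a JSON line. A minimal sketch of what a consumer of the serialized stream might look like follows; `consume_stream` is hypothetical, and the dispatch keys (`error`, `answer_piece`) are assumptions based on the models used above, not a confirmed wire format.

import json
from collections.abc import Iterator


def consume_stream(lines: Iterator[str]) -> None:
    # Each line is an independent JSON object produced by get_json_line
    for line in lines:
        packet = json.loads(line)
        if "error" in packet:  # StreamingError
            print(f"\n[error] {packet['error']}")
        elif "answer_piece" in packet:  # DanswerAnswerPiece
            print(packet["answer_piece"] or "", end="")
        else:  # QADocsResponse, CitationInfo, ChatMessageDetail, etc.
            print(f"\n[metadata] keys={sorted(packet)[:3]}")
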
5 changes: 5 additions & 0 deletions backend/danswer/configs/app_configs.py
@@ -177,6 +177,11 @@
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
"CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]
# When swapping to a new embedding model, a secondary index is created in the background. To conserve
# resources, setting this flag pauses updates on the primary index while the secondary index is built
DISABLE_INDEX_UPDATE_ON_SWAP = (
os.environ.get("DISABLE_INDEX_UPDATE_ON_SWAP", "").lower() == "true"
)
# Controls how many worker processes we spin up to index documents in the
# background. This is useful for speeding up indexing, but does require a
# fairly large amount of memory in order to increase substantially, since
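
Note the parse semantics of the new flag: unlike `CONTINUE_ON_CONNECTOR_FAILURE` just above it, which treats anything other than an empty string or "false" as enabled, `DISABLE_INDEX_UPDATE_ON_SWAP` is strictly opt-in; only the literal string "true" (case-insensitive) enables it. A quick sketch of the behavior:

import os

# Strict opt-in: only "true" (any casing) enables the pause; "1" does not.
for value in ["", "false", "1", "True", "TRUE"]:
    os.environ["DISABLE_INDEX_UPDATE_ON_SWAP"] = value
    flag = os.environ.get("DISABLE_INDEX_UPDATE_ON_SWAP", "").lower() == "true"
    print(f"{value!r:>8} -> {flag}")  # only "True" and "TRUE" yield True here
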
27 changes: 17 additions & 10 deletions backend/danswer/configs/model_configs.py
@@ -10,33 +10,38 @@
# Inference/Indexing speed
# https://huggingface.co/DOCUMENT_ENCODER_MODEL
# The usable models configured below must be SentenceTransformer-compatible
# NOTE: DO NOT CHANGE THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DOCUMENT_ENCODER_MODEL = (
# This is not a good model anymore, but this default needs to be kept to avoid breaking existing
# deployments; it will eventually be retired/swapped for a different default model
os.environ.get("DOCUMENT_ENCODER_MODEL")
or "thenlper/gte-small"
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
# If the below is changed, Vespa deployment must also be changed
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 384)
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
# Models should be chosen with a 512-token context size; ideally, don't change this
DOC_EMBEDDING_CONTEXT_SIZE = 512
NORMALIZE_EMBEDDINGS = (
os.environ.get("NORMALIZE_EMBEDDINGS") or "False"
os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
).lower() == "true"

# Old default model settings, which are needed for an automatic easy upgrade
OLD_DEFAULT_DOCUMENT_ENCODER_MODEL = "thenlper/gte-small"
OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM = 384
OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False

# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
# Currently unused
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc. use a prefix for asymmetric retrieval (queries are generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization; a memory-limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)


# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
@@ -78,7 +83,9 @@
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-0125"
GEN_AI_MODEL_VERSION = (
os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-16k-0613"
)
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful a model as, say, GPT-4, so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = (
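
The new defaults line up with the e5 family's training setup: e5 models expect asymmetric "query: "/"passage: " prefixes and normalized 768-dimensional embeddings. A hedged sketch of direct usage with sentence-transformers (illustrative, not code from this commit):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-base-v2")
query_emb = model.encode(
    "query: how do I rotate my API key?", normalize_embeddings=True
)
passage_emb = model.encode(
    "passage: API keys can be rotated from the admin panel.",
    normalize_embeddings=True,
)
print(query_emb.shape)  # (768,), matching the new DOC_EMBEDDING_DIM
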
62 changes: 62 additions & 0 deletions backend/danswer/db/embedding_model.py
@@ -1,6 +1,16 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
@@ -65,3 +75,55 @@ def update_embedding_model_status(
) -> None:
embedding_model.status = new_status
db_session.commit()


def insert_initial_embedding_models(db_session: Session) -> None:
"""Should be called on startup to ensure that the initial
embedding model is present in the DB."""
existing_embedding_models = db_session.scalars(select(EmbeddingModel)).all()
if existing_embedding_models:
logger.error(
"Called `insert_initial_embedding_models` but models already exist in the DB. Skipping."
)
return

existing_cc_pairs = get_connector_credential_pairs(db_session)

# If the user is overriding `DOCUMENT_ENCODER_MODEL`, let them keep
# using that model and do nothing fancy in the background. If the user
# has no connectors, we can also just use the new model immediately.
can_skip_upgrade = (
DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
or not existing_cc_pairs
)

# if we need to automatically upgrade the user, then create
# an entry which will automatically be replaced by the
# desired model below
if not can_skip_upgrade:
embedding_model_to_upgrade = EmbeddingModel(
model_name=OLD_DEFAULT_DOCUMENT_ENCODER_MODEL,
model_dim=OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM,
normalize=OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS,
query_prefix="",
passage_prefix="",
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)
db_session.add(embedding_model_to_upgrade)

desired_embedding_model = EmbeddingModel(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
status=IndexModelStatus.PRESENT
if can_skip_upgrade
else IndexModelStatus.FUTURE,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
)
db_session.add(desired_embedding_model)

db_session.commit()
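
This startup seeding replaces the bulk_insert removed from the alembic migration above. The new index name relies on `clean_model_name`, which this diff does not show; a plausible implementation, offered purely as an assumption about the real helper, would sanitize the HuggingFace model id into a Vespa-safe identifier:

def clean_model_name(model_str: str) -> str:
    # Hypothetical sketch: Vespa schema names cannot contain "/", "-", or ".",
    # so map them all to underscores (an assumption, not the project's code)
    return model_str.replace("/", "_").replace("-", "_").replace(".", "_")


# e.g. f"danswer_chunk_{clean_model_name('intfloat/e5-base-v2')}"
# -> "danswer_chunk_intfloat_e5_base_v2"
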
13 changes: 12 additions & 1 deletion backend/danswer/db/index_attempt.py
@@ -229,13 +229,24 @@ def expire_index_attempts(
embedding_model_id: int,
db_session: Session,
) -> None:
delete_query = (
delete(IndexAttempt)
.where(IndexAttempt.embedding_model_id == embedding_model_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
)
db_session.execute(delete_query)

update_query = (
update(IndexAttempt)
.where(IndexAttempt.embedding_model_id == embedding_model_id)
.where(IndexAttempt.status != IndexingStatus.SUCCESS)
.values(status=IndexingStatus.FAILED, error_msg="Embedding model swapped")
.values(
status=IndexingStatus.FAILED,
error_msg="Canceled due to embedding model swap",
)
)
db_session.execute(update_query)

db_session.commit()


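
Taken together, the delete-then-update above gives each attempt status a distinct disposition when the embedding model is swapped out. A summary sketch (statuses taken from this diff):

# Disposition of index attempts for the swapped-out embedding model:
DISPOSITIONS = {
    "NOT_STARTED": "deleted; the queued row is removed before it ever runs",
    "IN_PROGRESS": "marked FAILED with 'Canceled due to embedding model swap'",
    "FAILED": "re-marked FAILED with the swap error message",
    "SUCCESS": "left untouched; history of completed runs is preserved",
}
for status, outcome in DISPOSITIONS.items():
    print(f"{status:<12} -> {outcome}")
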