Added configuration to adjust how contexts are displayed #620

Merged 7 commits on Oct 22, 2024
Changes from 3 commits
45 changes: 28 additions & 17 deletions paperqa/docs.py
@@ -695,30 +695,41 @@ async def aquery(  # noqa: PLR0912
         answer.add_tokens(pre)
         pre_str = pre.text

         # sort by first score, then name
         filtered_contexts = sorted(
             contexts,
-            key=lambda x: x.score,
-            reverse=True,
+            key=lambda x: (-x.score, x.text.name),
         )[: answer_config.answer_max_sources]
         # remove any contexts with a score of 0
         filtered_contexts = [c for c in filtered_contexts if c.score > 0]

-        context_str = "\n\n".join(
-            [
-                f"{c.text.name}: {c.context}"
-                + "".join([f"\n{k}: {v}" for k, v in (c.model_extra or {}).items()])
-                + (
-                    f"\nFrom {c.text.doc.citation}"
-                    if answer_config.evidence_detailed_citations
-                    else ""
-                )
-                for c in filtered_contexts
-            ]
-            + ([f"Extra background information: {pre_str}"] if pre_str else [])
-        )
+        # shim deprecated flag
+        # TODO: remove in v6
+        context_inner_prompt = prompt_config.context_inner
+        if (
+            not answer_config.evidence_detailed_citations
+            and "\nFrom {citation}" in context_inner_prompt
+        ):
+            context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")
+
+        inner_context_strs = [
+            context_inner_prompt.format(
+                name=c.text.name,
+                text=c.context,
+                citation=c.text.doc.citation,
+                **(c.model_extra or {}),
+            )
+            for c in filtered_contexts
+        ]
+        if pre_str:
+            inner_context_strs += (
+                [f"Extra background information: {pre_str}"] if pre_str else []
+            )

-        valid_names = [c.text.name for c in filtered_contexts]
-        context_str += "\n\nValid keys: " + ", ".join(valid_names)
+        context_str = prompt_config.context_outer.format(
+            context_str="\n\n".join(inner_context_strs),
+            valid_keys=", ".join([c.text.name for c in filtered_contexts]),
+        )

         bib = {}
         if len(context_str) < 10:  # noqa: PLR2004
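Net effect of this hunk: each context is rendered through the configurable context_inner template, the rendered pieces are joined, and context_outer wraps them together with the list of valid citation keys. A minimal standalone sketch of that composition using the new default templates (the dict-based contexts below are hypothetical stand-ins for paperqa's Context objects):

CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"
CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"

# Hypothetical stand-ins for paperqa's Context objects
contexts = [
    {
        "name": "Example2012Example pages 3-4",
        "text": "Bates is from England.",
        "citation": "Example et al., 2012",
    },
]

# format each context with the inner template, then wrap with the outer one
inner_strs = [CONTEXT_INNER_PROMPT.format(**c) for c in contexts]
context_str = CONTEXT_OUTER_PROMPT.format(
    context_str="\n\n".join(inner_strs),
    valid_keys=", ".join(c["name"] for c in contexts),
)
print(context_str)
# Example2012Example pages 3-4: Bates is from England.
# From Example et al., 2012
#
# Valid Keys: Example2012Example pages 3-4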
3 changes: 3 additions & 0 deletions paperqa/prompts.py
@@ -85,3 +85,6 @@
     "\n\n{qa_answer}"
     "\n\nSingle Letter Answer:"
 )
+
+CONTEXT_OUTER_PROMPT = "{context_str}\n\nValid Keys: {valid_keys}"
+CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"
43 changes: 40 additions & 3 deletions paperqa/settings.py
@@ -20,6 +20,7 @@
     model_validator,
 )
 from pydantic_settings import BaseSettings, CliSettingsSource, SettingsConfigDict
+from typing_extensions import deprecated

 try:
     from ldp.agent import (
@@ -42,6 +43,8 @@

 from paperqa.llms import EmbeddingModel, LiteLLMModel, embedding_model_factory
 from paperqa.prompts import (
+    CONTEXT_INNER_PROMPT,
+    CONTEXT_OUTER_PROMPT,
     citation_prompt,
     default_system_prompt,
     qa_prompt,
@@ -65,7 +68,11 @@ class AnswerSettings(BaseModel):
         default=10, description="Number of evidence pieces to retrieve"
     )
     evidence_detailed_citations: bool = Field(
-        default=True, description="Whether to include detailed citations in summaries"
+        default=True,
+        description="Whether to include detailed citations in summaries",
+        deprecated=deprecated(
+            "Set the context_inner prompt directly to have a citation key. This flag will be removed in v6"
+        ),
     )
     evidence_retrieval: bool = Field(
         default=True,
@@ -210,7 +217,9 @@ def get_formatted_variables(s: str) -> set[str]:


 class PromptSettings(BaseModel):
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="forbid", validate_assignment=True)
+
+    EXAMPLE_CITATION: ClassVar[str] = "(Example2012Example pages 3-4)"

     summary: str = summary_prompt
     qa: str = qa_prompt
@@ -233,7 +242,15 @@ class PromptSettings(BaseModel):
     # to get JSON
     summary_json: str = summary_json_prompt
     summary_json_system: str = summary_json_system_prompt
-    EXAMPLE_CITATION: ClassVar[str] = "(Example2012Example pages 3-4)"
+    context_outer: str = Field(
+        default=CONTEXT_OUTER_PROMPT,
+        description="Prompt for how to format all contexts in generate answer.",
+    )
+    context_inner: str = Field(
+        default=CONTEXT_INNER_PROMPT,
+        description="Prompt for how to format a single context in generate answer. "
+        "This should at least contain name and text.",
+    )

     @field_validator("summary")
     @classmethod
@@ -288,6 +305,26 @@ def check_post(cls, v: str | None) -> str | None:
             raise ValueError(f"Post prompt must have input variables: {attrs}")
         return v

+    @field_validator("context_outer")
+    @classmethod
+    def check_context_outer(cls, v: str) -> str:
+        if not get_formatted_variables(v).issubset(
+            get_formatted_variables(CONTEXT_OUTER_PROMPT)
+        ):
+            raise ValueError(
+                "Context outer prompt can only have variables:"
+                f" {get_formatted_variables(CONTEXT_OUTER_PROMPT)}"
+            )
+        return v
+
+    @field_validator("context_inner")
+    @classmethod
+    def check_context_inner(cls, v: str) -> str:
+        fvars = get_formatted_variables(v)
+        if "name" not in fvars or "text" not in fvars:
+            raise ValueError("Context inner prompt must have name and text")
+        return v


 class IndexSettings(BaseModel):
     model_config = ConfigDict(extra="forbid")
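For users, the upshot is that context formatting is now configurable through PromptSettings, and because model_config sets validate_assignment=True, the validators above also run when a prompt is assigned after construction. A hedged usage sketch (assuming Settings is importable from the top-level paperqa package, as in the tests below):

from paperqa import Settings

settings = Settings()

# Drop the citation line from each context (replaces the deprecated
# evidence_detailed_citations=False flag)
settings.prompts.context_inner = "{name}: {text}"

# Drop the "Valid Keys" footer from the assembled context block
settings.prompts.context_outer = "{context_str}"

# Malformed templates are rejected on assignment thanks to validate_assignment
try:
    settings.prompts.context_inner = "A:"  # missing {name} and {text}
except ValueError as err:
    print(err)  # Context inner prompt must have name and text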
46 changes: 46 additions & 0 deletions tests/test_paperqa.py
@@ -1083,3 +1083,49 @@ def test_external_doc_index(stub_data_dir: Path) -> None:
     docs2 = Docs(texts_index=docs.texts_index)
     assert not docs2.docs
     assert docs2.get_evidence("What is the date of flag day?").contexts
+
+
+def test_context_inner_prompt(stub_data_dir: Path) -> None:
+
+    prompt_settings = Settings()
+
+    # try bogus prompt
+    with pytest.raises(ValueError, match="Context inner prompt must have"):
+        prompt_settings.prompts.context_inner = "A:"
+
+    # make sure prompt gets used
+    settings = Settings.from_name("fast")
+    settings.prompts.context_inner = "{name} @@@@@ {text}\nFrom: {citation}"
+    docs = Docs()
+    docs.add(stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now")
+    response = docs.query("What country is Bates from?", settings=settings)
+    assert "@@@@@" in response.context
+    assert "WikiMedia Foundation, 2023" in response.context
+
+
+def test_evidence_detailed_citations_shim(stub_data_dir: Path) -> None:
+
+    # TODO: delete in v6
+    settings = Settings.from_name("fast")
+    settings.answer.evidence_detailed_citations = False
+    docs = Docs()
+    docs.add(stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now")
+    response = docs.query("What country is Bates from?", settings=settings)
+    assert "WikiMedia Foundation, 2023, Accessed now" not in response.context
+
+
+def test_context_outer_prompt(stub_data_dir: Path) -> None:
+
+    prompt_settings = Settings()
+
+    # try bogus prompt
+    with pytest.raises(ValueError, match="Context outer prompt can only"):
+        prompt_settings.prompts.context_outer = "{foo}"
+
+    # make sure can delete keys
+    settings = Settings.from_name("fast")
+    settings.prompts.context_outer = "{context_str}"
+    docs = Docs()
+    docs.add(stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now")
+    response = docs.query("What country is Bates from?", settings=settings)
+    assert "Valid Keys" not in response.context
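The deprecation shim exercised by test_evidence_detailed_citations_shim is just a string replacement on the inner template before formatting. A minimal standalone sketch of that behavior, using the default template:

CONTEXT_INNER_PROMPT = "{name}: {text}\nFrom {citation}"

evidence_detailed_citations = False  # deprecated flag, as set in the test

context_inner_prompt = CONTEXT_INNER_PROMPT
if not evidence_detailed_citations and "\nFrom {citation}" in context_inner_prompt:
    context_inner_prompt = context_inner_prompt.replace("\nFrom {citation}", "")

print(context_inner_prompt)  # {name}: {text} -- the citation line is gone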