Skip to content

Commit

Permalink
Add option to define new global prompt constants
Browse files Browse the repository at this point in the history
  • Loading branch information
s-jse committed Jul 25, 2024
1 parent 8f4d18e commit 9da1693
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 21 deletions.
2 changes: 2 additions & 0 deletions chainlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
write_prompt_logs_to_file,
string_to_json,
)
from .load_prompt import register_prompt_constants
from .utils import get_logger


Expand All @@ -24,4 +25,5 @@
"chain",
"get_all_configured_engines",
"string_to_json",
"register_prompt_constants",
]
57 changes: 39 additions & 18 deletions chainlite/load_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from zoneinfo import ZoneInfo # Python 3.9 and later

from jinja2 import Environment, FileSystemLoader
from langchain_core.prompts import (AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate)
from langchain_core.prompts import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)

jinja2_comment_pattern = re.compile(r"{#.*?#}", re.DOTALL)

Expand Down Expand Up @@ -61,24 +63,41 @@ def load_template_file(template_file: str, keep_indentation: bool) -> str:
return raw_template


def add_template_constants(
# Module-level registry of user-supplied constants, merged into every prompt
# template by add_constants_to_template().
added_template_constants = {}


def register_prompt_constants(constant_name_to_value_map: dict) -> None:
    """
    Make constant values available to all prompt templates.

    By default, ``current_year``, ``today`` and ``location`` are set, and you
    can overwrite them or add new constants using this method.

    Args:
        constant_name_to_value_map (dict): A dictionary where keys are constant
            names and values are the corresponding constant values.

    Returns:
        None
    """
    # dict.update is the idiomatic (and C-speed) equivalent of copying key by key.
    added_template_constants.update(constant_name_to_value_map)


def add_constants_to_template(
    chat_prompt_template: ChatPromptTemplate,
) -> ChatPromptTemplate:
    """
    Bind the default and user-registered constants into a prompt template.

    The date-dependent values are recomputed on every call, since the date can
    change during a long-term server deployment.

    Args:
        chat_prompt_template (ChatPromptTemplate): The template to fill.

    Returns:
        ChatPromptTemplate: A template with the constants bound via ``partial``.
    """
    # Dates are anchored to Pacific time — presumably the deployment's
    # reference timezone; confirm if deployed elsewhere.
    pacific_zone = ZoneInfo("America/Los_Angeles")
    today = datetime.now(pacific_zone).date()

    template_constants = {
        "current_year": today.year,
        "today": today.strftime("%B %d, %Y"),  # e.g. May 30, 2024
        "location": "the U.S.",
    }
    # Constants registered via register_prompt_constants() take precedence
    # over the defaults above.
    template_constants.update(added_template_constants)

    return chat_prompt_template.partial(**template_constants)

Expand Down Expand Up @@ -166,11 +185,13 @@ def _prompt_blocks_to_chat_messages(
# only keep the distillation_instruction and the last input
assert distillation_instruction is not None
message_prompt_templates = [
SystemMessagePromptTemplate.from_template(distillation_instruction, template_format="jinja2"),
SystemMessagePromptTemplate.from_template(
distillation_instruction, template_format="jinja2"
),
message_prompt_templates[-1],
]
chat_prompt_template = ChatPromptTemplate.from_messages(message_prompt_templates)
chat_prompt_template = add_template_constants(chat_prompt_template)
chat_prompt_template = add_constants_to_template(chat_prompt_template)
if distillation_instruction is None:
# if distillation instruction is not provided, will default to instruction
block_type, distillation_instruction = tuple(
Expand All @@ -180,7 +201,7 @@ def _prompt_blocks_to_chat_messages(

distillation_instruction = (
(
add_template_constants(
add_constants_to_template(
ChatPromptTemplate.from_template(
distillation_instruction, template_format="jinja2"
)
Expand Down
6 changes: 6 additions & 0 deletions tests/constants.prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# instruction
Today's date is {{ today }}.
The current year is {{ current_year }}.

# input
{{ question }}
34 changes: 31 additions & 3 deletions tests/test_llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime
from zoneinfo import ZoneInfo
import pytest
from langchain_core.runnables import RunnableLambda

Expand All @@ -6,7 +8,8 @@
llm_generation_chain,
load_config_from_file,
write_prompt_logs_to_file,
get_all_configured_engines
get_all_configured_engines,
register_prompt_constants,
)
from chainlite.llm_config import GlobalVars

Expand All @@ -22,11 +25,13 @@
{"topic": "Rabbits"},
]

test_engine="gpt-4o-mini"
test_engine = "gpt-4o-mini"


@pytest.mark.asyncio(scope="session")
async def test_llm_generate():
print(get_all_configured_engines())
logger.info("All registered engines: %s", str(get_all_configured_engines()))

# Check that the config file has been loaded properly
assert GlobalVars.all_llm_endpoints
assert GlobalVars.prompt_dirs
Expand Down Expand Up @@ -58,6 +63,29 @@ async def test_readme_example():
).ainvoke({"topic": "Life as a PhD student"})


@pytest.mark.asyncio(scope="session")
async def test_constants():
    """Check that default and user-registered constants reach the prompt."""
    # Mirror the date computation in chainlite.load_prompt so the expected
    # string matches what the template sees, e.g. "May 30, 2024".
    expected_date = (
        datetime.now(ZoneInfo("America/Los_Angeles")).date().strftime("%B %d, %Y")
    )
    chain_kwargs = dict(
        template_file="tests/constants.prompt",
        engine=test_engine,
        max_tokens=10,
        temperature=0,
    )

    response = await llm_generation_chain(**chain_kwargs).ainvoke(
        {"question": "what is today's date?"}
    )
    assert expected_date in response

    # Overriding the built-in "today" constant should change the model's answer.
    register_prompt_constants({"today": "Thursday"})
    response = await llm_generation_chain(**chain_kwargs).ainvoke(
        {"question": "what day of the week is today?"}
    )
    assert "thursday" in response.lower()


@pytest.mark.asyncio(scope="session")
async def test_batching():
response = await llm_generation_chain(
Expand Down

0 comments on commit 9da1693

Please sign in to comment.