Skip to content

Commit

Permalink
Adds prompt templates (#53)
Browse files Browse the repository at this point in the history
* Modifies prompts to be less explicit, due to templates being present

* Adds prompt templates by format
  • Loading branch information
JasonWeill authored Apr 12, 2023
1 parent 2f4baeb commit 76e36e8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
5 changes: 2 additions & 3 deletions examples/magics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
],
"source": [
"%%ai chatgpt -f math\n",
"Generate the 2D heat equation in LaTeX surrounded by `$$`. Do not include an explanation."
"Generate the 2D heat equation."
]
},
{
Expand Down Expand Up @@ -487,8 +487,7 @@
],
"source": [
"%%ai j2-jumbo-instruct --format math\n",
"Write the 2d Laplace equation in polar coordinates in pure LaTeX, delimited by `$$`.\n",
"Do not include an explanation."
"Write the 2d Laplace equation in polar coordinates."
]
},
{
Expand Down
14 changes: 14 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
"raw": None
}

MARKDOWN_PROMPT_TEMPLATE = '{prompt}\n\nProduce output in markdown format only.'

PROMPT_TEMPLATES_BY_FORMAT = {
"html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.',
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
"math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.',
"json": '{prompt}\n\nProduce output in JSON format only, with nothing before or after it.',
"raw": '{prompt}' # No customization
}

class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
leaves replacement field unchanged if replacement field is not associated
Expand Down Expand Up @@ -128,6 +139,9 @@ def ai(self, line, cell=None):
else:
prompt = cell

# Apply a prompt template.
prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt = prompt)

# determine provider and local model IDs
provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
Expand Down

0 comments on commit 76e36e8

Please sign in to comment.