diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 65845c7..37c99d8 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -62,10 +62,7 @@ def _generate(self, text, model_name=DEFAULT_MODEL): {"role": "user", "content": text}, ] response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages)) - content: str = response["choices"][0]["message"]["content"] - if SYSTEM_PROMPT.canary_leaked(content): - msg = "The system prompt has leaked into the response, possible prompt injection!" - raise CanaryWordError(msg) + return self._get_content(response) async def _agenerate(self, text, model_name=DEFAULT_MODEL): """ @@ -78,4 +75,11 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL): {"role": "user", "content": text}, ] response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) - return response["choices"][0]["message"]["content"] + return self._get_content(response) + + def _get_content(self, response: ModelResponse) -> str: + content = response["choices"][0]["message"]["content"] + if SYSTEM_PROMPT.canary_leaked(content): + msg = "The system prompt has leaked into the response, possible prompt injection!" + raise CanaryWordError(msg) + return content diff --git a/src/banks/filters/lemmatize.py b/src/banks/filters/lemmatize.py index 93f74cf..f603560 100644 --- a/src/banks/filters/lemmatize.py +++ b/src/banks/filters/lemmatize.py @@ -4,7 +4,7 @@ from banks.errors import MissingDependencyError try: - from simplemma import text_lemmatizer + from simplemma.simplemma import text_lemmatizer simplemma_avail = True except ImportError: diff --git a/src/banks/templates/generate_tweet.jinja b/src/banks/templates/generate_tweet.jinja index 81e0ceb..056146a 100644 --- a/src/banks/templates/generate_tweet.jinja +++ b/src/banks/templates/generate_tweet.jinja @@ -6,6 +6,6 @@ Generate a tweet about the topic {{ topic }} with a positive sentiment. #} Examples: {% for number in range(3) %} -- {% generate "write a tweet with positive sentiment" "gpt-3.5-turbo" %} +- {% generate "write a tweet with positive sentiment", "gpt-3.5-turbo" %} {% endfor %}