diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 9cfb847..f2a3134 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -49,9 +49,7 @@ def parse(self, parser): args.append(nodes.Const(None)) if parser.environment.is_async: - return nodes.Output([self.call_method("_agenerate", args)]).set_lineno( - lineno - ) + return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(lineno) return nodes.Output([self.call_method("_generate", args)]).set_lineno(lineno) def _generate(self, text, model_name=DEFAULT_MODEL): @@ -64,9 +62,7 @@ def _generate(self, text, model_name=DEFAULT_MODEL): {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] - response: ModelResponse = cast( - ModelResponse, completion(model=model_name, messages=messages) - ) + response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages)) return self._get_content(response) async def _agenerate(self, text, model_name=DEFAULT_MODEL): @@ -79,9 +75,7 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL): {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] - response: ModelResponse = cast( - ModelResponse, await acompletion(model=model_name, messages=messages) - ) + response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) return self._get_content(response) def _get_content(self, response: ModelResponse) -> str: