Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Jun 26, 2024
1 parent faf141b commit 96b3b21
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def parse(self, parser):
args.append(nodes.Const(None))

if parser.environment.is_async:
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(
lineno
)
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(lineno)
return nodes.Output([self.call_method("_generate", args)]).set_lineno(lineno)

def _generate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -64,9 +62,7 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(
ModelResponse, completion(model=model_name, messages=messages)
)
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
return self._get_content(response)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -79,9 +75,7 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(
ModelResponse, await acompletion(model=model_name, messages=messages)
)
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
return self._get_content(response)

def _get_content(self, response: ModelResponse) -> str:
Expand Down

0 comments on commit 96b3b21

Please sign in to comment.