Skip to content

Commit

Permalink
better import path for ModelResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Jun 26, 2024
1 parent 96b3b21 commit 9fa6a34
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from jinja2 import nodes
from jinja2.ext import Extension
from litellm import ModelResponse, acompletion, completion
from litellm import acompletion, completion
from litellm.types.utils import ModelResponse

from banks.errors import CanaryWordError
from banks.prompt import Prompt
Expand Down Expand Up @@ -49,7 +50,9 @@ def parse(self, parser):
args.append(nodes.Const(None))

if parser.environment.is_async:
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(lineno)
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(
lineno
)
return nodes.Output([self.call_method("_generate", args)]).set_lineno(lineno)

def _generate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -62,7 +65,9 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
response: ModelResponse = cast(
ModelResponse, completion(model=model_name, messages=messages)
)
return self._get_content(response)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -75,7 +80,9 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
response: ModelResponse = cast(
ModelResponse, await acompletion(model=model_name, messages=messages)
)
return self._get_content(response)

def _get_content(self, response: ModelResponse) -> str:
Expand Down

0 comments on commit 9fa6a34

Please sign in to comment.