Commit: Fix llama in auto
dnakov committed Oct 22, 2024
1 parent 94820d1 commit f9dd4b0
Showing 2 changed files with 33 additions and 17 deletions.
44 changes: 28 additions & 16 deletions r2ai/auto.py
@@ -8,7 +8,7 @@
 from transformers import AutoTokenizer
 from . import index
 from .pipe import have_rlang, r2lang, get_r2_inst
-from litellm import _should_retry, acompletion, utils
+from litellm import _should_retry, acompletion, utils, ModelResponse
 import asyncio
 from r2ai.pipe import get_r2_inst
 from .tools import r2cmd, run_python
@@ -41,7 +41,7 @@
 """

 class ChatAuto:
-    def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', cb=None ):
+    def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
         self.functions = {}
         self.tools = []
         self.model = model
@@ -60,6 +60,7 @@ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='a
             self.tools.append({ "type": "function", "function": f })
             self.functions[f['name']] = tool
         self.tool_choice = tool_choice
+        self.llama_instance = llama_instance

         #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
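Note: the snippet below is not part of the commit; it is a minimal sketch of how the new llama_instance keyword could be used directly, assuming llama-cpp-python is installed and a local model.gguf is available. The model name, message, and file path are placeholders.

from llama_cpp import Llama
from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO
from r2ai.tools import r2cmd, run_python

# Load a local model once and hand it to ChatAuto; when llama_instance is set,
# completions go through create_chat_completion() instead of litellm.
llm = Llama(model_path="model.gguf", n_ctx=4096)
chat = ChatAuto(
    "local/llama",                    # placeholder name; not used on the llama path
    system=SYSTEM_PROMPT_AUTO,
    tools=[r2cmd, run_python],
    messages=[{"role": "user", "content": "decompile main"}],
    llama_instance=llm,               # new keyword introduced by this commit
)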
@@ -130,24 +131,35 @@ async def process_streaming_response(self, resp):
         self.messages.append({"role": "assistant", "content": response_message})
         return response_message

+    async def attempt_completion(self):
+        args = {
+            "temperature": 0,
+            "tools": self.tools,
+            "tool_choice": self.tool_choice,
+            "stream": True
+        }
+        if self.llama_instance:
+            return self.llama_instance.create_chat_completion(self.messages, **args)
+
+        return await acompletion(
+            model=self.model,
+            messages=self.messages,
+            **args
+        )
+
     async def get_completion(self):
+        if self.llama_instance:
+            response = await self.attempt_completion()
+            async def async_generator(response):
+                for item in response:
+                    yield ModelResponse(stream=True, **item)
+            return await self.process_streaming_response(async_generator(response))
         max_retries = 5
         base_delay = 2

-        async def attempt_completion():
-            return await acompletion(
-                model=self.model,
-                messages=self.messages,
-                # max_tokens=4096,
-                temperature=0,
-                tools=self.tools,
-                tool_choice=self.tool_choice,
-                stream=True
-            )
-
         for retry_count in range(max_retries):
             try:
-                response = await attempt_completion()
+                response = await self.attempt_completion()
                 return await self.process_streaming_response(response)
             except Exception as e:
                 print(e)
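Note: the llama branch of get_completion() relies on litellm's ModelResponse being constructible from the OpenAI-style chunk dicts that llama-cpp-python yields when streaming. The standalone sketch below illustrates that wrapping with a hand-written fake stream standing in for create_chat_completion(..., stream=True); the exact chunk layout is an assumption, not taken from this diff.

import asyncio
from litellm import ModelResponse

def fake_llama_stream():
    # Stand-in for llama_instance.create_chat_completion(messages, stream=True, ...):
    # yields OpenAI-compatible chat.completion.chunk dicts.
    yield {"id": "chatcmpl-0", "object": "chat.completion.chunk", "created": 0,
           "model": "local-llama",
           "choices": [{"index": 0, "delta": {"content": "hello"}, "finish_reason": None}]}
    yield {"id": "chatcmpl-0", "object": "chat.completion.chunk", "created": 0,
           "model": "local-llama",
           "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}

async def as_model_responses(stream):
    # Same adaptation as async_generator() in the patch: wrap each raw dict
    # in a streaming ModelResponse so downstream code sees litellm chunks.
    for item in stream:
        yield ModelResponse(stream=True, **item)

async def main():
    async for chunk in as_model_responses(fake_llama_stream()):
        print(chunk.choices[0].delta)

asyncio.run(main())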
@@ -184,7 +196,7 @@ def cb(type, data):
 def signal_handler(signum, frame):
     raise KeyboardInterrupt

-def chat(interpreter):
+def chat(interpreter, llama_instance=None):
     model = interpreter.model.replace(":", "/")
     tools = [r2cmd, run_python]
     messages = interpreter.messages
@@ -199,7 +211,7 @@ def chat(interpreter):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)

-    chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
+    chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=llama_instance, cb=cb)

     original_handler = signal.getsignal(signal.SIGINT)
6 changes: 5 additions & 1 deletion r2ai/interpreter.py
@@ -857,7 +857,11 @@ def respond(self):
         # builtins.print(prompt)
         response = None
         if self.auto_run:
-            response = auto.chat(self)
+            if(is_litellm_model(self.model)):
+                response = auto.chat(self)
+            else:
+                self.llama_instance = new_get_hf_llm(self, self.model, False, int(self.env["llm.window"]))
+                response = auto.chat(self, llama_instance=self.llama_instance)
             return

         elif self.model.startswith("kobaldcpp"):
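Note: is_litellm_model() is referenced above but not defined in this diff. The sketch below is a hypothetical stand-in, only meant to illustrate the routing idea (provider-prefixed names go through litellm, anything else gets a local llama instance); the provider list and model names are made up for the example.

def pick_backend(model_name: str) -> str:
    # Hypothetical stand-in for is_litellm_model(): route provider-prefixed
    # names to litellm and treat everything else as a local model that will
    # be loaded and passed to auto.chat() as llama_instance.
    litellm_providers = ("openai", "anthropic", "gemini", "groq", "openrouter")
    if model_name.split("/")[0].lower() in litellm_providers:
        return "litellm"
    return "local-llama"

print(pick_backend("anthropic/claude-3-5-sonnet"))        # -> litellm
print(pick_backend("TheBloke/Mistral-7B-Instruct-GGUF"))  # -> local-llama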
