
Commit

Fix llama in auto
dnakov authored and trufae committed Oct 22, 2024
1 parent 94820d1 commit c6987e5
Showing 10 changed files with 33 additions and 2,034 deletions.
153 changes: 0 additions & 153 deletions r2ai/anthropic.py

This file was deleted.

44 changes: 28 additions & 16 deletions r2ai/auto.py
@@ -8,7 +8,7 @@
from transformers import AutoTokenizer
from . import index
from .pipe import have_rlang, r2lang, get_r2_inst
from litellm import _should_retry, acompletion, utils
from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from r2ai.pipe import get_r2_inst
from .tools import r2cmd, run_python
@@ -41,7 +41,7 @@
"""

class ChatAuto:
def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', cb=None ):
def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
self.functions = {}
self.tools = []
self.model = model
@@ -60,6 +60,7 @@ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='a
self.tools.append({ "type": "function", "function": f })
self.functions[f['name']] = tool
self.tool_choice = tool_choice
self.llama_instance = llama_instance

#self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'

@@ -130,24 +131,35 @@ async def process_streaming_response(self, resp):
        self.messages.append({"role": "assistant", "content": response_message})
        return response_message

    async def attempt_completion(self):
        args = {
            "temperature": 0,
            "tools": self.tools,
            "tool_choice": self.tool_choice,
            "stream": True
        }
        if self.llama_instance:
            return self.llama_instance.create_chat_completion(self.messages, **args)

        return await acompletion(
            model=self.model,
            messages=self.messages,
            **args
        )

    async def get_completion(self):
        if self.llama_instance:
            response = await self.attempt_completion()
            async def async_generator(response):
                for item in response:
                    yield ModelResponse(stream=True, **item)
            return await self.process_streaming_response(async_generator(response))
        max_retries = 5
        base_delay = 2

        async def attempt_completion():
            return await acompletion(
                model=self.model,
                messages=self.messages,
                # max_tokens=4096,
                temperature=0,
                tools=self.tools,
                tool_choice=self.tool_choice,
                stream=True
            )

        for retry_count in range(max_retries):
            try:
                response = await attempt_completion()
                response = await self.attempt_completion()
                return await self.process_streaming_response(response)
            except Exception as e:
                print(e)
@@ -184,7 +196,7 @@ def cb(type, data):
def signal_handler(signum, frame):
raise KeyboardInterrupt

def chat(interpreter):
def chat(interpreter, llama_instance=None):
model = interpreter.model.replace(":", "/")
tools = [r2cmd, run_python]
messages = interpreter.messages
@@ -199,7 +211,7 @@ def chat(interpreter):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=llama_instance, cb=cb)

original_handler = signal.getsignal(signal.SIGINT)

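For context (not part of the diff): a minimal sketch of what the new llama.cpp code path does, assuming `llm` is a `llama_cpp.Llama` instance and that litellm's `ModelResponse` accepts the streamed chunk dictionaries as keyword arguments, which is what the added code relies on. The names `llm`, `stream_with_llama`, and `as_async` are illustrative, not from the commit.

from litellm import ModelResponse

async def stream_with_llama(llm, messages, tools):
    # llama-cpp-python's create_chat_completion(..., stream=True) yields
    # plain dict chunks rather than litellm response objects
    chunks = llm.create_chat_completion(
        messages,
        temperature=0,
        tools=tools,
        tool_choice="auto",
        stream=True,
    )

    async def as_async(chunks):
        # wrap each chunk so the existing streaming handler can consume it
        # as if it came from litellm's acompletion()
        for chunk in chunks:
            yield ModelResponse(stream=True, **chunk)

    return as_async(chunks)

A caller running a local model would then pass the instance through as chat(interpreter, llama_instance=llm), while remote models keep going through litellm's acompletion retry loop.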
Empty file removed r2ai/functionary/__init__.py
Empty file.
103 changes: 0 additions & 103 deletions r2ai/functionary/openai_types.py

This file was deleted.

44 changes: 0 additions & 44 deletions r2ai/functionary/prompt_template/__init__.py

This file was deleted.

