diff --git a/r2ai/anthropic.py b/r2ai/anthropic.py
deleted file mode 100644
index caa6970..0000000
--- a/r2ai/anthropic.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import re
-import random
-import string
-
-def get_random_tool_call_id():
- return "call_" + "".join(
- [random.choice(string.ascii_letters + string.digits) for _ in range(24)]
- )
-
-def construct_tool_parameters_prompt(parameters):
- prompt = ""
- props = parameters["properties"]
- for name in props:
- parameter = props[name]
- prompt += (
- "\n"
- f"{name}\n"
- f"{parameter['description']}\n"
- f"{parameter['type']}\n"
- "\n"
- )
- return prompt
-
-def construct_tool_prompt(func):
- tool = func['function']
- prompt = (
- "\n"
- f"{tool['name']}\n"
- "\n"
- f"{tool['description']}\n"
- "\n"
- "\n"
- f"{construct_tool_parameters_prompt(tool['parameters'])}\n"
- "\n"
- ""
- )
- return prompt
-
-def construct_tool_use_system_prompt(tools):
- tool_use_system_prompt = (
- "In this environment you have access to a set of tools "
- "you can use to answer the user's question.\n\n"
- "You may call them like this:\n"
- "\n"
- "\n"
- "$TOOL_NAME\n"
- "\n"
- "<$PARAMETER_NAME>$PARAMETER_VALUE$PARAMETER_NAME>\n"
- "...\n"
- "\n"
- "\n"
- "\n"
- "\n"
- "Here are the tools available:\n"
- "\n"
- + '\n'.join([construct_tool_prompt(tool) for tool in tools]) +
- "\n"
- )
- return tool_use_system_prompt
-
-TAGS = r'<function_calls>|</function_calls>|<invoke>|</invoke>|<tool_name>|</tool_name>|<parameters>|</parameters>'
-
-def parse_tags(invoke_string):
- tool_name = re.findall(r'<tool_name>.*?</tool_name>', invoke_string, re.DOTALL)
- if not tool_name:
- raise Exception("Missing <tool_name></tool_name> tags inside of <invoke></invoke> tags.")
- if len(tool_name) > 1:
- raise Exception("More than one tool_name specified inside single set of <invoke></invoke> tags.")
-
- parameters = re.findall(r'<parameters>.*?</parameters>', invoke_string, re.DOTALL)
- if not parameters:
- raise Exception("Missing <parameters></parameters> tags inside of <invoke></invoke> tags.")
- if len(parameters) > 1:
- raise Exception("More than one set of <parameters></parameters> tags specified inside single set of <invoke></invoke> tags.")
- # Check for balanced tags inside parameters
- # TODO: This will fail if the parameter value contains <> pattern
- # TODO: or if there is a parameter called parameters. Fix that issue.
- tags = re.findall(r'<.*?>', parameters[0].replace('<parameters>', '').replace('</parameters>', ''), re.DOTALL)
- if len(tags) % 2 != 0:
- raise Exception("Imbalanced tags inside <parameters></parameters> tags.")
- return tool_name, parameters, tags
-
-def _function_calls_valid_format_and_invoke_extraction(last_completion):
- """Check if the function call follows a valid format and extract the
- attempted function calls if so. Does not check if the tools actually
- exist or if they are called with the requisite params."""
- # Check if there are any of the relevant XML tags present that would
- # indicate an attempted function call.
- function_call_tags = re.findall(TAGS, last_completion, re.DOTALL)
- if not function_call_tags:
- # TODO: Should we return something in the text to claude indicating
- # that it did not do anything to indicate an attempted function call
- # (in case it was in fact trying to and we missed it)?
- return {"status": True, "invokes": []}
- # Extract content between <function_calls> tags. If there are multiple we
- # will only parse the first and ignore the rest, regardless of their correctness.
- match = re.search(r'<function_calls>(.*)', last_completion, re.DOTALL)
- if not match:
- return {"status": False, "reason": "No valid <function_calls></function_calls> tags present in your query."}
- func_calls = match.group(1)
-
- prefix_match = re.search(r'^(.*?)<function_calls>', last_completion, re.DOTALL)
- if prefix_match:
- func_call_prefix_content = prefix_match.group(1)
- # Check for invoke tags
- # TODO: Is this faster or slower than bundling with the next check?
- invoke_regex = r'<invoke>.*?</invoke>'
- if not re.search(invoke_regex, func_calls, re.DOTALL):
- return {"status": False, "reason": "Missing <invoke></invoke> tags inside of <function_calls></function_calls> tags."}
- # Check each invoke contains tool name and parameters
- invoke_strings = re.findall(invoke_regex, func_calls, re.DOTALL)
- invokes = []
- for invoke_string in invoke_strings:
- try:
- tool_name, parameters, tags = parse_tags(invoke_string)
- except Exception as e:
- return {"status": False, "reason": e}
-
- # Loop through the tags and check if each even-indexed tag matches the
- # tag in the position after it (with the / of course). If valid store
- # their content for later use.
- # TODO: Add a check to make sure there aren't duplicates provided of a given parameter.
- arguments = {}
- for i in range(0, len(tags), 2):
- opening_tag = tags[i]
- closing_tag = tags[i+1]
- closing_tag_without_second_char = closing_tag[:1] + closing_tag[2:]
- if closing_tag[1] != '/' or opening_tag != closing_tag_without_second_char:
- return {"status": False, "reason": "Non-matching opening and closing tags inside tags."}
- arguments[opening_tag[1:-1]] = re.search(rf'{opening_tag}(.*?){closing_tag}', parameters[0], re.DOTALL).group(1)
- # Parse out the full function call
- invokes.append({
- "function": {
- "name": tool_name[0].replace('', '').replace('', ''),
- "arguments": arguments,
- },
- "id": get_random_tool_call_id()
- })
- return {"status": True, "invokes": invokes, "prefix_content": func_call_prefix_content}
-
-def extract_claude_tool_calls(interpreter, stream):
- msg = ''
- res = None
- for event in stream:
- if event.type == "content_block_delta":
- delta = event.delta
- msg += delta.text
- res = _function_calls_valid_format_and_invoke_extraction(msg)
- if res["status"] is True and "invokes" in res and len(res["invokes"]) > 0:
- interpreter.messages.append({ "role": "assistant", "content": msg})
- return res["invokes"], res["prefix_content"]
- interpreter.messages.append({ "role": "assistant", "content": msg})
- return [], re.sub(r'<function_calls>.*', '', msg)
diff --git a/r2ai/auto.py b/r2ai/auto.py
index ce212e8..8188a90 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -8,7 +8,7 @@
from transformers import AutoTokenizer
from . import index
from .pipe import have_rlang, r2lang, get_r2_inst
-from litellm import _should_retry, acompletion, utils
+from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from r2ai.pipe import get_r2_inst
from .tools import r2cmd, run_python
@@ -41,7 +41,7 @@
"""
class ChatAuto:
- def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', cb=None ):
+ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
self.functions = {}
self.tools = []
self.model = model
@@ -60,6 +60,7 @@ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='a
self.tools.append({ "type": "function", "function": f })
self.functions[f['name']] = tool
self.tool_choice = tool_choice
+ self.llama_instance = llama_instance
#self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
@@ -130,24 +131,35 @@ async def process_streaming_response(self, resp):
self.messages.append({"role": "assistant", "content": response_message})
return response_message
+ async def attempt_completion(self):
+ args = {
+ "temperature": 0,
+ "tools": self.tools,
+ "tool_choice": self.tool_choice,
+ "stream": True
+ }
+ if self.llama_instance:
+ return self.llama_instance.create_chat_completion(self.messages, **args)
+
+ return await acompletion(
+ model=self.model,
+ messages=self.messages,
+ **args
+ )
+
async def get_completion(self):
+ if self.llama_instance:
+ response = await self.attempt_completion()
+ async def async_generator(response):
+ for item in response:
+ yield ModelResponse(stream=True, **item)
+ return await self.process_streaming_response(async_generator(response))
max_retries = 5
base_delay = 2
- async def attempt_completion():
- return await acompletion(
- model=self.model,
- messages=self.messages,
- # max_tokens=4096,
- temperature=0,
- tools=self.tools,
- tool_choice=self.tool_choice,
- stream=True
- )
-
for retry_count in range(max_retries):
try:
- response = await attempt_completion()
+ response = await self.attempt_completion()
return await self.process_streaming_response(response)
except Exception as e:
print(e)
@@ -184,7 +196,7 @@ def cb(type, data):
def signal_handler(signum, frame):
raise KeyboardInterrupt
-def chat(interpreter):
+def chat(interpreter, llama_instance=None):
model = interpreter.model.replace(":", "/")
tools = [r2cmd, run_python]
messages = interpreter.messages
@@ -199,7 +211,7 @@ def chat(interpreter):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
- chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
+ chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=llama_instance, cb=cb)
original_handler = signal.getsignal(signal.SIGINT)
diff --git a/r2ai/functionary/__init__.py b/r2ai/functionary/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/r2ai/functionary/openai_types.py b/r2ai/functionary/openai_types.py
deleted file mode 100644
index a3abb22..0000000
--- a/r2ai/functionary/openai_types.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import time
-from typing import List, Literal, Optional
-
-from pydantic import BaseModel, Field
-
-
-class FunctionCall(BaseModel):
- name: Optional[str] = None
- arguments: str
-
-
-class ToolCall(BaseModel):
- index: Optional[int] = None
- id: Optional[str] = None
- function: FunctionCall
- type: Optional[str] = "function"
-
-
-class Function(BaseModel):
- name: str
- description: Optional[str] = Field(default="")
- parameters: Optional[dict] = None
-
-
-class Tool(BaseModel):
- type: Literal["function", "code_interpreter"] = "function"
- function: Optional[Function] = None
-
-
-class ChatMessage(BaseModel):
- role: Optional[str] = None
- tool_call_id: Optional[str] = None
- content: Optional[str] = None
- name: Optional[str] = None
- function_call: Optional[FunctionCall] = None
- tool_calls: Optional[List[ToolCall]] = None
-
- def __str__(self) -> str:
- if self.role == "system":
- return f"system:\n{self.content}\n"
-
- elif self.role == "function":
- return f"function name={self.name}:\n{self.content}\n"
-
- elif self.role == "user":
- if self.content is None:
- return "user:\n"
- else:
- return f"user:\n{self.content}\n"
-
- elif self.role == "assistant":
- if self.content is not None and self.function_call is not None:
- return f"assistant:\n{self.content}\nassistant to={self.function_call.name}:\n{self.function_call.arguments}"
-
- elif self.function_call is not None:
- return f"assistant to={self.function_call.name}:\n{self.function_call.arguments}"
-
- elif self.content is None:
- return "assistant"
-
- else:
- return f"assistant:\n{self.content}\n"
-
- else:
- raise ValueError(f"Unsupported role: {self.role}")
-
-
-class ChatInput(BaseModel):
- messages: List[ChatMessage]
- functions: Optional[List[Function]] = None
- tools: Optional[List[Tool]] = None
- temperature: float = 0.9
- stream: bool = False
-
-
-class Choice(BaseModel):
- message: ChatMessage
- finish_reason: str = "stop"
- index: int = 0
-
- @classmethod
- def from_message(cls, message: ChatMessage, finish_reason: str):
- return cls(message=message, finish_reason=finish_reason)
-
-
-class ChatCompletion(BaseModel):
- id: str
- object: str = "chat.completion"
- created: float = Field(default_factory=time.time)
- choices: List[Choice]
-
-
-class StreamChoice(BaseModel):
- delta: ChatMessage
- finish_reason: Optional[str] = "stop"
- index: int = 0
-
-
-class ChatCompletionChunk(BaseModel):
- id: str
- object: str = "chat.completion.chunk"
- created: float = Field(default_factory=time.time)
- choices: List[StreamChoice]
diff --git a/r2ai/functionary/prompt_template/__init__.py b/r2ai/functionary/prompt_template/__init__.py
deleted file mode 100644
index 338d8b2..0000000
--- a/r2ai/functionary/prompt_template/__init__.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import Any
-
-from r2ai.functionary.prompt_template.base_template import (
- SYSTEM_MESSAGE,
- PredefinedFuncTypes,
- PromptTemplate,
-)
-from r2ai.functionary.prompt_template.prompt_template_v1 import PromptTemplateV1
-from r2ai.functionary.prompt_template.prompt_template_v2 import PromptTemplateV2
-
-
-def get_default_prompt_template() -> PromptTemplate:
- """Return default prompt template to be used
-
- Returns:
- _type_: _description_
- """
- return PromptTemplateV2.get_prompt_template()
-
-
-def get_prompt_template_by_version(version: str) -> PromptTemplate:
- if version == "v1":
- return PromptTemplateV1.get_prompt_template()
- return PromptTemplateV2.get_prompt_template()
-
-
-def get_prompt_template_from_tokenizer(tokenizer: Any) -> PromptTemplate:
- """This function will determine the prompt template based on tokenizer.
- Under the hood, this function will check if tokenizer contains some special tokens from template or not
-
- Args:
- tokenizer (Any): Tokenizer
-
- Returns:
- _type_: _description_
- """
- p1 = PromptTemplateV1.get_prompt_template()
- p2 = PromptTemplateV2.get_prompt_template()
- token_ids = tokenizer.encode(p1.start_function, add_special_tokens=False)
- if token_ids[0] in [29871, 28705]:
- token_ids = token_ids[1:]
- if len(token_ids) == 1:
- return p1
- return p2
diff --git a/r2ai/functionary/prompt_template/base_template.py b/r2ai/functionary/prompt_template/base_template.py
deleted file mode 100644
index 7a8a00b..0000000
--- a/r2ai/functionary/prompt_template/base_template.py
+++ /dev/null
@@ -1,583 +0,0 @@
-from __future__ import annotations
-
-import json
-import re
-from abc import abstractmethod
-from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-
-from r2ai.functionary.schema import generate_schema_from_functions
-
-SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
-PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."
-
-
-class PredefinedFuncTypes(str, Enum):
- no_tool_call = "no-tool-call"
- code_interpreter = "code-interpreter"
-
-
-class PromptTemplate:
- _instance = None
-
- @abstractmethod
- def get_start_of_function_call_token(self) -> str:
- """returns a token that indicates the start of a function call in the prompt template
- Returns:
- str: a string token
- """
- raise NotImplementedError
-
- @abstractmethod
- def get_stop_token_for_function_parameter(
- self, stage: Literal["function", "parameter"]
- ) -> str:
- """returns a str token which stops function/parameter name generation
- e.g.: `"get_current_weather` with v1 prompt template -> returns id = 28747 (':' token)
- so the generation gets forced towards `"get_current_weather:\n{...`
- Args:
- stage (str): Whether to get function name or parameter name stopping token
- Returns:
- str: str token
- """
- raise NotImplementedError
-
- def get_predefined_function_names(self, function_types: Any) -> List[str]:
- """returns a list of predefined function names. Some prompt template versions may
- require a default/predefined function name to indicate for example, no function called.
- E.g.: in v2, 'all' is generated to indicate normal model response. In this case, the v2
- subclass will overwrite this base method.
- Args:
- function_types (Any): Either "all" or one of the function type in PredefinedFuncTypes enum class
- Returns:
- List[str]: list of predefined function names (default to [])
- """
- return []
-
- @abstractmethod
- def initialize_grammar_sampling_gen_state(self, tool_choice: Optional[Any]) -> Dict:
- """initializes and returns a new generation state. Each template version may be initialized
- at different starting stage
- Args:
- tool_choice (Optional[Any]): the tool_choice provided by the user, if any
- Returns:
- dict: the gen_state. It contains the following:
- - stage: one of the following:
- - pre-function: the generation prior to function name generation
- - function: when the model is generating a function name
- - pre-parameter: when the model is generating the part between function name and parameter
- - parameter-name: when the model is generating a parameter name
- - parameter-value: when the model is generating a parameter value
- - no-tool-call: when the model is generating content
- - curr_tokens: all the tokens for the current stage being generated
- - curr_text: curr_tokens but in string text form
- - func_name: the function name, if any
- - param_names: the parameters names, if any
- """
- raise NotImplementedError
-
- def update_grammar_sampling_gen_state(
- self,
- gen_state: Dict,
- new_token_id: int,
- options: Optional[List],
- tokenizer: Any,
- ) -> Dict:
- """Receives a generation state, updates and returns it. This is only used when
- grammar sampling is enabled in inference. This functions parses the generated
- tokens and identifies the stage of generation (pre-function, function, parameter-name,
- etc.)
- Args:
- gen_state (Dict): The current generation state. It contains the following:
- - stage: one of the following:
- - pre-function: the generation prior to function name generation
- - function: when the model is generating a function name
- - pre-parameter: when the model is generating the part between function name and parameter
- - parameter-name: when the model is generating a parameter name
- - parameter-value: when the model is generating a parameter value
- - no-tool-call: when the model is generating content
- - code-interpreter: when the model is generating code
- - curr_tokens: all the tokens for the current stage being generated
- - curr_text: curr_tokens but in string text form
- - func_name: the function name, if any
- - param_names: the parameters names, if any
- new_token_id (int): The token id of the newly sampled token
- options (List): All available function/param names depending on the stage of gen_state
- tokenizer (Any): The tokenizer class passed in from Transformers or vLLM
- Returns:
- dict: The updated gen_state
- """
- # Update curr_tokens and curr_text
- gen_state["curr_tokens"].append(new_token_id)
- gen_state["curr_text"] = tokenizer.decode(gen_state["curr_tokens"])
-
- # v1: "assistant:\n{content}\n{self.start_function}{function}:\n{arguments}\n"
- # v2: "{func_name}\n{param_names}\n<|from|> assistant\n<|recipient|>"
- if gen_state["stage"] == "pre-function":
- # Check if the new state is in "function" stage
- if gen_state["curr_text"].endswith(self.get_start_of_function_call_token()):
- gen_state = {
- "stage": "function",
- "curr_tokens": [],
- "curr_text": "",
- "func_name": "",
- "param_names": [],
- "add_predefined_fns": gen_state["add_predefined_fns"],
- }
- gen_state["stage"] = "function"
- elif gen_state["stage"] == "function":
- # Remove all unnecessary suffixes by checking whether stop token is in curr_text
- if (
- self.get_stop_token_for_function_parameter(stage="function")
- in gen_state["curr_text"]
- ):
- curr_text = gen_state["curr_text"].rstrip()
- while True:
- if any([curr_text == option for option in options]):
- break
- curr_text = curr_text[:-1]
- gen_state["func_name"] = curr_text
- else:
- gen_state["func_name"] = gen_state["curr_text"].rstrip()
-
- # Check if the new state is in "pre-parameter" stage
- if (
- sum([gen_state["func_name"] == option for option in options]) == 1
- and sum(
- [option.startswith(gen_state["func_name"]) for option in options]
- )
- == 1
- ):
- gen_state["stage"] = "pre-parameter"
-
- # Update curr_text and curr_tokens
- if (
- self.get_stop_token_for_function_parameter(stage="function")
- in gen_state["curr_text"]
- ):
- gen_state["curr_text"] = tokenizer.decode([new_token_id])
- gen_state["curr_tokens"] = [new_token_id]
- else:
- gen_state["curr_text"], gen_state["curr_tokens"] = "", []
- elif gen_state["stage"] == "pre-parameter":
- # Check if the new state is in "parameter" or "no-tool-call" or "code-interpreter" stage
- if self.fn_param_sep_token.rstrip("{").rstrip() in gen_state["curr_text"]:
- if gen_state["func_name"] in self.get_predefined_function_names(
- function_types=PredefinedFuncTypes.no_tool_call
- ):
- gen_state["stage"] = "no-tool-call"
- elif gen_state["func_name"] in self.get_predefined_function_names(
- function_types=PredefinedFuncTypes.code_interpreter
- ):
- gen_state["stage"] = "code-interpreter"
- # Either '{' or '{"' or '{}'
- elif self.fn_param_sep_token in gen_state["curr_text"]:
- # Check if no arguments are called and go straight to "pre-function"
- if "}" in gen_state["curr_text"]:
- gen_state["stage"] = "pre-function"
- elif '"' in gen_state["curr_text"]:
- gen_state["stage"] = "parameter-name"
- if gen_state["curr_text"].endswith('"'):
- gen_state["curr_text"], gen_state["curr_tokens"] = "", []
- else:
- gen_state["curr_tokens"] = [new_token_id]
- gen_state["curr_text"] = tokenizer.decode([new_token_id])
- elif gen_state["stage"] == "parameter-name":
- # Get the latest param
- latest_param_str = gen_state["curr_text"]
-
- # Remove unneccesary prefixes before the parameter-name part
- if len(gen_state["curr_tokens"]) > 0 and '"' in tokenizer.decode(
- [gen_state["curr_tokens"][0]]
- ):
- latest_param_str = latest_param_str[latest_param_str.find('"') + 1 :]
-
- # Check if the new state is in "parameter-value" stage
- stop_token = self.get_stop_token_for_function_parameter(stage="parameter")
- if stop_token in latest_param_str:
- pattern = stop_token + r".*$"
- match_res = re.search(pattern, latest_param_str, re.DOTALL)
- if bool(match_res):
- gen_state["param_names"].append(
- gen_state["curr_text"].removesuffix(match_res.group(0))
- )
- gen_state["stage"] = "parameter-value"
- gen_state["curr_text"] = match_res.group(0)
- new_tokens = []
- for token in gen_state["curr_tokens"][::-1]:
- new_tokens = [token] + new_tokens
- next_text = tokenizer.decode(new_tokens)
- if next_text.endswith(match_res.group(0)):
- gen_state["curr_tokens"] = new_tokens
- break
- elif gen_state["stage"] == "parameter-value":
- latest_param_val = gen_state["curr_text"]
- stop_token = self.get_stop_token_for_function_parameter(stage="parameter")
-
- # Remove unnecessary prefixes in latest_param_val
- if not latest_param_val.startswith(stop_token):
- latest_param_val = latest_param_val[latest_param_val.find(stop_token) :]
-
- # Check if the new state is in "pre-function" stage
- try:
- _ = json.loads('{"' + gen_state["param_names"][-1] + latest_param_val)
- gen_state["stage"] = "pre-function"
- except Exception:
- pass
-
- # Check if the current state can be converted to json, it means the
- # new state is back to "parameter-name" stage
- pattern = r',[\s]*"'
- match_res = re.findall(pattern, latest_param_val, re.DOTALL)
- if '"' in tokenizer.decode(new_token_id) and len(match_res) > 0:
- latest_match = match_res[-1]
- try:
- _ = json.loads(
- '{"'
- + gen_state["param_names"][-1]
- + latest_param_val[: latest_param_val.rfind(latest_match)]
- + "}"
- )
- gen_state["stage"] = "parameter-name"
- if latest_param_val.endswith('"'):
- gen_state["curr_text"], gen_state["curr_tokens"] = "", []
- else:
- gen_state["curr_tokens"] = [new_token_id]
- gen_state["curr_text"] = tokenizer.decode([new_token_id])
- except Exception:
- pass
- elif gen_state["stage"] in ["no-tool-call", "code-interpreter"]:
- # probability of stop token is not 100% at the end of no-tool-call
- # We still need to check if the stage will go to "function" by checking
- # for the presence of the start_of_function_call token
- if gen_state["curr_text"].endswith(self.get_start_of_function_call_token()):
- gen_state = {
- "stage": "function",
- "curr_tokens": [],
- "curr_text": "",
- "func_name": "",
- "param_names": [],
- "add_predefined_fns": gen_state["add_predefined_fns"],
- }
-
- return gen_state
-
- def grammar_sample(
- self,
- gen_state: Dict,
- tools_or_functions: List,
- delta_token_ids: List,
- model_sampled_token_id: int,
- tokenizer: Any,
- ) -> Tuple[int, str]:
- """Applies grammar-sampling to the token generation and returns a
- newly sampled token.
-
- This function checks whether the model-sampled token helps towards
- forming one of the function names or parameter names. It loops through
- a list of token ids sorted in descending order by the log probabilities.
- It replaces the output token if the grammar-sampled token is different
- from the model-sampled token
- Args:
- gen_state (Dict): The current generation state
- options (List): The list of available function/parameter names depending on gen_state["stage"]
- delta_token_ids (List): The list of delta token ids sorted in descending order by log probabilities
- model_sampled_token_id (int): The token id of the token sampled by model
- tokenizer (Any): The tokenizer object passed in from Transformers, vLLM, etc.
- Returns:
- Tuple[int, str]: Tuple of grammar-sampled token id and grammar-sampled token in str format
- """
- grammar_sampled_token_id, grammar_sampled_token = None, None
-
- # Form the functions/parameters options
- options = []
- if gen_state["stage"] in ["pre-function", "function"]:
- options = [tool_or_func["name"] for tool_or_func in tools_or_functions]
- elif gen_state["stage"] == "pre-parameter":
- options = [self.fn_param_sep_token]
- else:
- func_name = gen_state["func_name"]
- for tool_or_func in tools_or_functions:
- if tool_or_func["name"] == func_name:
- options = list(tool_or_func["parameters"]["properties"].keys())
- break
- # Assume prompt template versions > 1 have "all" in function options
- # Subjected to changes in future versions
- # Concatenate the list of predefined function names in the respective prompt
- # template version. For e.g., v2 returns ["all"]
- if gen_state["stage"] == "function" and gen_state["add_predefined_fns"] is True:
- options += self.get_predefined_function_names(function_types="all")
-
- # No grammar sampling needed if gen_state not in "function" or "pre-parameter"
- # or "parameter-name" stages. Just return the model_sampled_token_id
- if gen_state["stage"] not in ["function", "pre-parameter", "parameter-name"]:
- grammar_sampled_token_id = model_sampled_token_id
- grammar_sampled_token = tokenizer.decode([model_sampled_token_id])
-
- # Loop through the list of token ids sorted in descending order. Form a mask made
- # up of booleans where the index of the mask == index of function/parameter name
- # in function/parameter options. The element is True if the sampled_token
- # helps in forming the function/parameter name. Else, False.
- if grammar_sampled_token_id is None:
- for i, sampled_token_ind in enumerate(delta_token_ids):
- sampled_token = tokenizer.decode(
- [sampled_token_ind], add_special_tokens=False
- )
- # Form the function name with the current sampled token id
- new_curr_tokens_id = gen_state["curr_tokens"] + [sampled_token_ind]
- new_curr_tokens = tokenizer.decode(new_curr_tokens_id)
-
- if gen_state["stage"] == "function":
- options_mask = [
- (
- True
- if option.startswith(new_curr_tokens.lstrip(" "))
- or new_curr_tokens.lstrip(" ").startswith(option)
- else False
- )
- for option in options
- ]
-
- # - In case of two fns having common prefixes (e.g.: get_weather and
- # get_weather_and_time), we need to iterate until parts of the
- # fn_param_sep_token is present in new_curr_tokens to know if the
- # shorter or longer function name is preferred by the model.
- # - Reject the whitespace (" ") and empty ("") tokens
- if any(options_mask) and sampled_token.strip(" ") != "":
- grammar_sampled_token_id = sampled_token_ind
- grammar_sampled_token = sampled_token
- break
- elif gen_state["stage"] == "pre-parameter":
- # Get the suffix after fn_param_sep_token and check if crit_char is in it
- if self.fn_param_sep_token in new_curr_tokens:
- suffix = new_curr_tokens[
- new_curr_tokens.index(self.fn_param_sep_token)
- + len(self.fn_param_sep_token) :
- ]
- else:
- suffix = new_curr_tokens
- crit_bool = any([crit_char in suffix for crit_char in ['"', "}"]])
-
- options_mask = []
- for option in options:
- if option.startswith(new_curr_tokens.lstrip(" ")) or crit_bool:
- options_mask.append(True)
- else:
- options_mask.append(False)
-
- # We just need to check if the option (fn_param_sep_token) is True
- # or fn_param_sep_token + one of ['}', '"'] is present
- if any(options_mask) and sampled_token.strip(" ") != "":
- grammar_sampled_token_id = sampled_token_ind
- grammar_sampled_token = sampled_token
- break
- else:
- # Mask away those wellformed parameter names while creating options_mask
- wellformed_params = gen_state["param_names"]
-
- # Remove unneccesary prefixes before the parameter-name part
- if len(gen_state["curr_tokens"]) > 0 and '"' in tokenizer.decode(
- [gen_state["curr_tokens"][0]]
- ):
- new_curr_tokens = new_curr_tokens[
- new_curr_tokens.find('"') + 1 :
- ]
-
- options_mask = []
- for option in options:
- if option not in wellformed_params and option.startswith(
- new_curr_tokens
- ):
- options_mask.append(True)
- else:
- options_mask.append(False)
-
- # Same logic as function name, except that we check whether the token
- # is a stopping token for parameter name generation.
- if (
- (
- self.get_stop_token_for_function_parameter(
- stage="parameter"
- )
- in new_curr_tokens
- )
- or any(options_mask)
- and sampled_token.strip(" ") != ""
- ):
- grammar_sampled_token_id = sampled_token_ind
- grammar_sampled_token = sampled_token
- break
-
- # Update gen_state
- return (
- grammar_sampled_token_id,
- grammar_sampled_token,
- self.update_grammar_sampling_gen_state(
- gen_state=gen_state,
- new_token_id=grammar_sampled_token_id,
- options=options,
- tokenizer=tokenizer,
- ),
- )
-
- @abstractmethod
- def get_additional_tokens(self) -> List[str]:
- """return list of added tokens if using this template
- Returns:
- List[str]: list of tokens, each token is a string
- """
- raise NotImplementedError
-
- @abstractmethod
- def convert_message_to_prompt(self, message: Dict) -> str:
- """Return the prompt of this message
-
- Args:
- message (Dict): Dictionary of openAI format
-
- Returns:
- str: prompt of this message
- """
- raise NotImplementedError
-
- @abstractmethod
- def get_stop_tokens_for_generation(self) -> List[str]:
- """Function to get list of stop tokens in generation
-
- Returns:
- List[str]: list of stop tokens
- """
- raise NotImplementedError
-
- @abstractmethod
- def get_assistant_prefixes(self) -> List[str]:
- """Return the assistant prefixs in the final prompt, this is used for masking the labels
- in unmasking labels, the system will unmask chunks that start with assistant prefixs and end with stop tokens.
- For example, assistant_prefixes might be: "<|from|>assistant\n<|recipient|>"
- In this case unmasked chunks in labels would be tokens in ... of: <|from|>assistant\n<|recipient|> ... <|stop|>
- Returns:
- List[str]: list of possible assistant prefixs
- """
- raise NotImplementedError
-
- def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
- """This function is used if we need to process messages before doing inference.
- This is used when the messages in training and inference are different.
- For example, in training we have no: tool_call_id, but in inference, we have tool_call_id to know the order of function calls.
- This function woule be called to convert inference messages to the format of training messages.
- Args:
- messages (List[Dict]): list of input messages
-
- Returns:
- List[Dict]: list of output messages
- """
- return messages
-
- def get_prompt_from_messages(
- self,
- messages: List[Dict],
- tools_or_functions: Optional[List[Dict]] = None,
- ) -> str:
- """This function is used to get the complete prompt for list of messages
-
- Args:
- messages (List[Dict]): List of messages
- tools_or_functions (Optional[List[Dict]], optional): List of tools or functions. Defaults to None.
-
- Returns:
- str: the prompt for inference/training
- """
- messages_clone = messages.copy() # To avoid modifying the original list
-
- functions = []
- is_code_interpreter = False
- if tools_or_functions is not None:
- for item in tools_or_functions:
- if (
- "function" in item and item["function"] is not None
- ): # new data format: tools: [{"type": xx, "function": xxx}]
- functions.append(item["function"])
- elif "type" in item and item["type"] == "code_interpreter":
- is_code_interpreter = True
- else:
- functions.append(item) # old format
-
- messages_clone.insert(
- 0, {"role": "system", "content": generate_schema_from_functions(functions)}
- )
- if is_code_interpreter:
- messages_clone.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})
- else:
- messages_clone.insert(1, {"role": "system", "content": SYSTEM_MESSAGE})
-
- full_text = ""
- for message in messages_clone:
- full_text += self.convert_message_to_prompt(message)
- return full_text.strip()
-
- def get_end_token_to_token_id(self, tokenizer: Any) -> Dict[str, int]:
- """return a dictionary mapping from end_token --> token_id
- Args:
- tokenizer (Any): tokenizer in transformers
-
- Returns:
- Dict[int, EndToken]: the mapping from token_id --> end_token
- """
- result = {}
- for item in self.get_stop_tokens_for_generation():
- tok_ids = tokenizer.encode(item, add_special_tokens=False)
- assert len(tok_ids) <= 2, ""
- if len(tok_ids) == 2:
- assert tok_ids[0] in [
- 29871,
- 28705,
- ] # Llama tokenizer adds this token intentionally
- result[item] = tok_ids[-1]
- return result
-
- @abstractmethod
- def parse_assistant_response(
- self, llm_output: str, tool_choice: Optional[Any]
- ) -> Dict:
- """This function is used to parse llm_output to the Message of OpenAI ({"role": xxx, "content": xxx, ...})
- this is used in inference.
- Args:
- llm_output (str): The generated content from Model
- tool_choice (Optional[Any]): Any choice of tool provided by the user
-
- Returns:
- Dict: Dictionary of OpenAI message format
- """
- raise NotImplementedError
-
- @abstractmethod
- def update_response_state_from_delta_text(
- self,
- *,
- current_state: Dict[str, Any],
- delta_text: str,
- finish_reason: Optional[str],
- ) -> Tuple[Dict[str, Any], Union[None, Dict, List[Dict]]]:
- """This function is used for streaming
-
- Args:
- current_state (Dict[str, Any]): a dictionary containing the state of the streaming: such as current function_name,
- delta_text: new token generated
- finish_reason: if finished or not
-
- Returns:
- Tuple[Dict[str, Any], Optional[Dict]]: updated state, response: can be None, a dictionary: {} or a list of dictionary: [{}, ..., {}]
- """
- raise NotImplementedError
-
- @abstractmethod
- def get_chat_template_jinja(self):
- """Return chat_template in jinja format"""
- raise NotImplementedError
-
- @classmethod
- def get_prompt_template(cls):
- if cls._instance is None:
- cls._instance = cls()
- return cls._instance
- return cls._instance
diff --git a/r2ai/functionary/prompt_template/prompt_template_v1.py b/r2ai/functionary/prompt_template/prompt_template_v1.py
deleted file mode 100644
index 24e2092..0000000
--- a/r2ai/functionary/prompt_template/prompt_template_v1.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import json
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-
-from r2ai.functionary.prompt_template.base_template import PromptTemplate
-
-
-class PromptTemplateV1(PromptTemplate):
- start_function = "<|START_OF_FUNCTION_CALL|>"
- end_system = "<|END_OF_SYSTEM|>"
- end_user = "<|END_OF_USER|>"
- end_assistant = "<|END_OF_ASSISTANT|>"
- end_function = "<|END_OF_FUNCTION_RESULT|>"
- end_function_call = "<|END_OF_FUNCTION_CALL|>"
- version = "v1"
- # This token splits between function name and parameters
- fn_param_sep_token = ":\n{"
-
- def get_end_token_from_message(self, message: Dict) -> str:
- """this function is used for getting the end token for each message.
- For example, if message["role"] == "user" --> return EndToken.user
- if message["role"] == "assistant" and "function_call" in message --> EndTOken.function_call
-
- Args:
- message (Dict): A dictionary containing: role, content, function_call(optional)
-
- Returns:
- EndToken: End Token for this message, this will be appended to the end of the prompt for this message
- """
- role = message["role"]
- if role == "user":
- return self.end_user
- elif role == "system":
- return self.end_system
- elif role == "function":
- return self.end_function
- else: # role = assistant
- if message.get("function_call", None) is not None:
- # if "function_call" in message and message["function_call"] is not None:
- return self.end_function_call
- else:
- return self.end_assistant
-
- def get_start_of_function_call_token(self) -> str:
- return self.start_function
-
- def get_stop_token_for_function_parameter(
- self, stage: Literal["function", "parameter"]
- ) -> int:
- if stage == "function":
- return ":" # 28747
- else:
- return '":' # 1264
-
- def initialize_grammar_sampling_gen_state(self) -> Dict:
- return {
- "stage": "pre-function",
- "curr_tokens": [],
- "curr_text": "",
- "func_name": "",
- "param_names": [],
- }
-
- def get_additional_tokens(self) -> List[str]:
- return [
- self.start_function,
- self.end_system,
- self.end_user,
- self.end_assistant,
- self.end_function,
- self.end_function_call,
- ]
-
- def convert_message_to_prompt(self, message: Dict) -> str:
- """convert a message to a string to be included in the prompt
- Args:
- message (Dict): A dictionary in OpenAI format (containing: role, content, function_call (optional))
-
- Returns:
- str: the string used in the final prompt of this message
- """
- end_token = self.get_end_token_from_message(message)
- content = message.get("content", None)
-
- if message["role"] == "system":
- text = f"system:\n{content}{end_token}\n"
-
- elif message["role"] in ["function", "tool"]:
- func_name = message.get("name", "")
- text = f"function name={func_name}:\n{content}{end_token}\n"
-
- elif message["role"] == "user" and content is None:
- text = "user:\n"
-
- elif message["role"] == "user":
- text = f"user:\n{content}{end_token}\n"
-
- elif message["role"] == "assistant":
- if (
- message.get("function_call", None) is not None
- ): # format of openai: {"role": assistant, "function_call": {"name": xxx, "arguments": xxx}}
- function = message["function_call"]["name"]
- arguments = message["function_call"]["arguments"] + end_token
- if content is not None:
- text = f"assistant:\n{content}\n{self.start_function}{function}:\n{arguments}\n"
- else:
- text = (
- f"assistant:\n{self.start_function}{function}:\n{arguments}\n"
- )
- elif content is not None: # this is text content
- text = f"assistant:\n{content}{end_token}\n"
- else: # if no function call and content is None --> this is used at inference
- text = "assistant:"
-
- return text
-
- def get_stop_tokens_for_generation(self) -> List[str]:
- return [self.end_assistant, self.end_function_call]
-
- def get_assistant_prefixes(self) -> List[str]:
- result = []
- for item in [self.end_user, self.end_function]:
- prefix = f"{item}\nassistant:"
- result.append(prefix)
- return result
-
- def parse_assistant_response(
- self, llm_output: str, tool_choice: Optional[Any] = None
- ) -> Dict:
- generated_content = llm_output.strip()
-
- for endtoken in self.get_stop_tokens_for_generation():
- if generated_content.endswith(endtoken):
- generated_content = generated_content[: -len(endtoken)].strip()
-
- # First we need to check if llm_output contains start_token or not
- start_function_index = generated_content.find(self.start_function)
- text_content = generated_content
- result = {"role": "assistant", "content": None}
-
- if start_function_index >= 0:
- func_info = generated_content[
- start_function_index + len(self.start_function) :
- ].strip()
- index = func_info.find(":")
- func_name = func_info[:index].strip()
- arguments = func_info[index + 1 :].strip()
-
- text_content = generated_content[:start_function_index].strip()
- result["function_call"] = {
- "name": func_name,
- "arguments": arguments,
- } # FunctionCall(name=func_name, arguments=arguments)
- if len(text_content) > 0:
- result["content"] = text_content
- return result
-
- def update_response_state_from_delta_text(
- self,
- *,
- current_state: Dict[str, Any],
- delta_text: str,
- finish_reason: Optional[str],
- ) -> Tuple[Dict[str, Any], Optional[Dict]]:
- if len(current_state) == 0:
- current_state = {
- "response_type": None, # the type of current response text (text_response)/function (function_call)
- "func_name": None, # if response_type=function, this is the function_name
- "current_text": "", # the concatenation of generated tokens so far
- }
- current_state["current_text"] += delta_text
- cur_text = current_state["current_text"]
-
- response: Optional[Dict[str, Any]] = None
- if current_state["response_type"] is None:
- if cur_text.strip().startswith(self.start_function): # if function_call
- if cur_text.endswith(":"):
- f_index = cur_text.find(self.start_function)
- func_name = cur_text[
- f_index + len(self.start_function) : -1
- ].strip()
- response = {
- "delta": {
- "role": "assistant",
- "content": None,
- "function_call": {"arguments": "", "name": func_name},
- },
- "finish_reason": None,
- }
- current_state["response_type"] = "function"
- else: # if text_response
- current_state["response_type"] = "text"
- response = {
- "delta": {"content": "", "role": "assistant"},
- "finish_reason": None,
- "index": 0,
- }
-
- elif current_state["response_type"] == "function":
- if finish_reason is None:
- response = {
- "delta": {
- "role": "assistant",
- "function_call": {"arguments": delta_text},
- }, # format of openAI at the second return, don't need to add function_name
- "finish_reason": None,
- "index": 0,
- }
- else:
- response = {
- "delta": {},
- "finish_reason": "function_call",
- "index": 0,
- } # format of openAI at the end, delta must be empty
-
- elif current_state["response_type"] == "text":
- if finish_reason is None:
- # need to check if call a function or not
- if cur_text.endswith(self.start_function): # if call another function
- print("call another function in the mean time")
- cur_text = self.start_function
- current_state["current_text"] = self.start_function
- current_state["response_type"] = None
- else:
- response = {
- "delta": {"content": delta_text, "role": "assistant"},
- "finish_reason": None,
- "index": 0,
- }
- else: # finish generating
- response = {
- "delta": {},
- "finish_reason": finish_reason,
- "index": 0,
- } # format of openAI at the end, delta must be empty
- return current_state, response
-
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user' %}
- {{ message['role'] + ':\n' + message['content'] + '<|END_OF_USER|>' + '\n' }}
- {% elif message['role'] == 'system' %}
- {{ message['role'] + ':\n' + message['content'] + '<|END_OF_SYSTEM|>' + '\n' }}
- {% elif message['role'] == 'function' %}
- {{ 'function name=' + message['name'] + ':\n' + message['content']+ '<|END_OF_FUNCTION_RESULT|>\n' }}
- {% elif message['role'] == 'assistant' %}
- {% if 'function_call' in message and message['function_call'] is not none %}
- {% if message['content'] is not none %}
- {{ 'assistant:\n' + message['content'] + '\n<|START_OF_FUNCTION_CALL|>' + message['function_call']['name'] + ':\n' + message['function_call']['arguments'] + '<|END_OF_FUNCTION_CALL|>\n' }}
- {% else %}
- {{ 'assistant:\n<|START_OF_FUNCTION_CALL|>' + message['function_call']['name'] + ':\n' + message['function_call']['arguments'] + '<|END_OF_FUNCTION_CALL|>\n' }}
- {% endif %}
- {% else %}
- {{ 'assistant:\n' + message['content'] + '<|END_OF_ASSISTANT|>' + '\n' }}
- {% endif %}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ 'assistant:' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
diff --git a/r2ai/functionary/prompt_template/prompt_template_v2.py b/r2ai/functionary/prompt_template/prompt_template_v2.py
deleted file mode 100644
index 32043f4..0000000
--- a/r2ai/functionary/prompt_template/prompt_template_v2.py
+++ /dev/null
@@ -1,413 +0,0 @@
-import json
-import random
-import string
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-
-from r2ai.functionary.openai_types import Tool
-from r2ai.functionary.prompt_template.base_template import (
- PredefinedFuncTypes,
- PromptTemplate,
-)
-
-
-class PromptTemplateV2(PromptTemplate):
- from_token = "<|from|>"
- recipient_token = "<|recipient|>"
- content_token = "<|content|>"
- stop_token = "<|stop|>"
- version = "v2"
- # This token splits between function name and parameters
- fn_param_sep_token = "\n<|content|> {"
- # This maps the predefined function type to its str name
- predefined_func_names = {
- PredefinedFuncTypes.no_tool_call: "all",
- PredefinedFuncTypes.code_interpreter: "python",
- }
-
- def get_start_of_function_call_token(self) -> str:
- return self.recipient_token
-
- def get_stop_token_for_function_parameter(
- self, stage: Literal["function", "parameter"]
- ) -> int:
- if stage == "function":
- return "\n" # 13
- else:
- return '":' # 1264
-
- def get_predefined_function_names(self, function_types: Any) -> List[str]:
- if function_types == "all":
- return [func_name for func_name in self.predefined_func_names.values()]
-
- if not isinstance(function_types, list):
- function_types = [function_types]
-
- predefined_function_names = []
- for function_type in function_types:
- predefined_function_names.append(self.predefined_func_names[function_type])
-
- return predefined_function_names
-
- def initialize_grammar_sampling_gen_state(
- self, tool_choice: str, curr_text: str, curr_tokens: List[int]
- ) -> Dict:
- if tool_choice != "":
- add_predefined_fns = False
- stage = "pre-parameter"
- else:
- add_predefined_fns = True
- stage = "function"
-
- return {
- "stage": stage,
- "curr_tokens": curr_tokens,
- "curr_text": curr_text,
- "func_name": tool_choice,
- "param_names": [],
- "add_predefined_fns": add_predefined_fns,
- }
-
- def get_additional_tokens(self) -> List[str]:
- return [
- self.from_token,
- self.recipient_token,
- self.content_token,
- self.stop_token,
- ]
-
- def convert_message_to_prompt(self, message: Dict) -> str:
- role = message["role"]
- content = message.get("content", None)
-
- if role in [
- "system",
- "user",
- ]: # <|from|>system\n<|recipient|>all\n<|content|>xxx
- return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- if role == "tool": # <|from|>tool_name\n<|recipient|>all\n<|content|>xxx
- tool_name = message["name"]
- return f"{self.from_token}{tool_name}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- assert role == "assistant"
- tool_calls = message.get("tool_calls", [])
- if tool_calls is None:
- tool_calls = []
- if (
- len(tool_calls) == 0 and content is None
- ): # for inference: <|from|> assistant\n<|recipient|>
- return f"{self.from_token}{role}\n{self.recipient_token}"
-
- if len(tool_calls) == 0: # <|from|>assistant\n<|recipient|>all\n<|content|>xxx
- return f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}{self.stop_token}\n"
-
- result = ""
- if content is not None: # both text-response and function_call
- result += f"{self.from_token}{role}\n{self.recipient_token}all\n{self.content_token}{content}\n"
-
- for tool in tool_calls:
- func_name = tool["function"]["name"]
- arguments = tool["function"]["arguments"]
- # <|from|>assistant\n<|recipient|>func_name\n<|content|>xxxx
- result += f"{self.from_token}{role}\n{self.recipient_token}{func_name}\n{self.content_token}{arguments}\n"
-
- result = result.strip() + f"{self.stop_token}\n"
- return result
-
- def get_stop_tokens_for_generation(self) -> List[str]:
- return [self.stop_token]
-
- def get_assistant_prefixes(self) -> List[str]:
- return [f"{self.from_token}assistant\n{self.recipient_token}"]
-
- def parse_assistant_response(
- self, llm_output: str, tool_choice: Optional[Any] = None
- ) -> Dict:
- for stop in self.get_stop_tokens_for_generation():
- if llm_output.endswith(stop):
- llm_output = llm_output[: -len(stop)]
-
- recipient_to_fill = ""
- if tool_choice is not None:
- if tool_choice == "none":
- recipient_to_fill = self.get_predefined_function_names(
- function_types=PredefinedFuncTypes.no_tool_call
- )[0] + self.get_stop_token_for_function_parameter(stage="function")
- elif isinstance(tool_choice, Tool):
- recipient_to_fill = (
- tool_choice.function.name
- + self.get_stop_token_for_function_parameter(stage="function")
- )
-
- llm_output = (
- f"{self.from_token}assistant\n{self.recipient_token}"
- + recipient_to_fill
- + llm_output
- )
- responses = llm_output.split(self.from_token)
- responses = [response.strip() for response in responses]
-
- tool_calls = []
- text_response = None
- for response in responses:
- if len(response) == 0:
- continue
- # response = assistant<|recipient|>xxxx\n<|content|>yyy
- recipient_index = response.find(self.recipient_token)
- content_index = response.find(self.content_token)
- recipient = response[
- recipient_index + len(self.recipient_token) : content_index
- ].strip()
- content = response[content_index + len(self.content_token) :].strip()
- # print(f"recipient: {recipient}, content={content}")
- if recipient == "all":
- text_response = content
- else:
- tool_calls.append(
- {
- "function": {"name": recipient, "arguments": content},
- "id": get_random_tool_call_id(),
- "type": "function",
- }
- )
-
- return {"role": "assistant", "content": text_response, "tool_calls": tool_calls}
-
- def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
- """re-order the messages where role = tool to match the order in tool_calls by tool_call_id
- Args:
- messages (List[Dict]): list of messages containing: tool_call_id
-
- Returns:
- List[Dict]: _description_
- """
- result = []
- index = 0
- while index < len(messages):
- message = messages[index]
- tool_calls = message.get("tool_calls", None)
-
- result.append(message)
- if message["role"] == "assistant" and tool_calls:
- num_calls = len(tool_calls)
- if (
- tool_calls[0].get("id", None) is not None
- ): # if tool_call contains "id" for mapping
- tool_call_ids = [item["id"] for item in tool_calls]
-
- tool_messages = [messages[index + 1 + j] for j in range(num_calls)]
- id_2_tool_messages = {
- item["tool_call_id"]: item for item in tool_messages
- }
- new_messages = [id_2_tool_messages[cid] for cid in tool_call_ids]
-
- result.extend(new_messages)
- index += num_calls + 1
- else:
- index += 1
- else:
- index += 1
- return result
-
- def get_function_delta_response(
- self,
- current_state: Dict,
- delta_text: str,
- first_call: bool,
- return_role: bool,
- finish_reason: Optional[str],
- ) -> Dict:
- """Return delta for tool_call in streaming
-
- Args:
- current_state (Dict): _description_
- delta_text (str): _description_
- first_call (bool): _description_
- return_role (bool): _description_
- finish_reason (Optional[str]): _description_
-
- Returns:
- Dict: _description_
- """
- return {
- "delta": {
- "content": None,
- "function_call": None,
- "role": None if not return_role else "assistant",
- "tool_calls": [
- {
- "index": current_state["func_index"],
- "id": current_state["call_id"]
- if first_call
- else None, # only return call_id at the first time
- "function": {
- "arguments": delta_text,
- "name": current_state["func_name"] if first_call else None,
- },
- "type": "function" if first_call else None,
- }
- ],
- },
- "finish_reason": finish_reason,
- "index": 0,
- }
-
- def get_text_delta_response(
- self, delta_text: Optional[str], return_role: bool, finish_reason: Optional[str]
- ) -> Dict:
- """Return delta for text_response in streaming
-
- Args:
- delta_text (Optional[str]): _description_
- return_role (bool): _description_
- finish_reason (Optional[str]): _description_
-
- Returns:
- Dict: _description_
- """
- return {
- "delta": {
- "content": delta_text,
- "function_call": None,
- "role": None if not return_role else "assistant",
- "tool_calls": None,
- },
- "finish_reason": finish_reason,
- "index": 0,
- }
-
- def get_recipient(self, current_text: str) -> str:
- """Get recipient from the llm_output
-
- Args:
- current_text (str): _description_
-
- Returns:
- str: _description_
- """
- recipient_index = current_text.find(self.recipient_token)
- start_index = 0
- if recipient_index >= 0:
- start_index = recipient_index + len(self.recipient_token)
-
- end_index = current_text.find(f"\n{self.content_token}")
- return current_text[start_index:end_index].strip()
-
- def get_chat_template_jinja(self) -> str:
- chat_template = """{% for message in messages %}
- {% if message['role'] == 'user' or message['role'] == 'system' %}
- {{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}
- {% elif message['role'] == 'tool' %}
- {{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}
- {% else %}
- {% set contain_content='no'%}
- {% if message['content'] is not none %}
- {{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}
- {% set contain_content='yes'%}
- {% endif %}
- {% if 'tool_calls' in message and message['tool_calls'] is not none %}
- {% for tool_call in message['tool_calls'] %}
- {% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}
- {% if loop.index == 1 and contain_content == "no" %}
- {{ prompt }}
- {% else %}
- {{ '\n' + prompt}}
- {% endif %}
- {% endfor %}
- {% endif %}
- {{ '<|stop|>\n' }}
- {% endif %}
- {% endfor %}
- {% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}
- """
- chat_template = chat_template.replace(" ", "")
- chat_template = chat_template.replace("
\n", "")
- chat_template = chat_template.strip()
- return chat_template
-
- def update_response_state_from_delta_text(
- self,
- *,
- current_state: Dict[str, Any],
- delta_text: str,
- finish_reason: Optional[str],
- ) -> Tuple[Dict[str, Any], Union[None, Dict, List[Dict]]]:
- if len(current_state) == 0: # empty dict, at the first_time
- current_state = {
- "current_text": "", # the concatenation of all tokens so far
- "func_name": None, # function_name of the current tool, if the response requires to use tool
- "response_type": None, # response_type=text(text response)/function (using tool)
- "func_index": -1, # index of the tool in tool_calls
- "call_id": None, # call_id of the current tool
- # skip_until_reach we skip new tokens until we reach certain token. This is used when we hit special tokens
- "skip_until_reach": self.content_token, # at first we will skip until reach <|content|>
- "first_time": True, # if first_time we return an tempty delta with role=assistant
- }
- current_state["current_text"] += delta_text
-
- if finish_reason is not None:
- if current_state["response_type"] == "function":
- finish_reason = "tool_calls"
- return current_state, self.get_text_delta_response(
- None, False, finish_reason
- )
-
- skip_until_reach = current_state.get("skip_until_reach", "")
- if skip_until_reach:
- if delta_text != skip_until_reach:
- return current_state, None
- else:
- current_state["skip_until_reach"] = "" # once hit, no need to skip
- recipient = self.get_recipient(current_state["current_text"])
- first_time = current_state["first_time"]
- current_state["first_time"] = False
-
- if recipient == "all":
- current_state["response_type"] = "text"
- return current_state, self.get_text_delta_response(
- "", True, finish_reason
- )
- else:
- current_state["response_type"] = "function"
- current_state["func_name"] = recipient
- current_state["call_id"] = get_random_tool_call_id()
- current_state["func_index"] += 1
-
- responses = []
- if (
- first_time
- ): # first chunk of function_call is a message where all fields are None, except role
- responses.append(
- self.get_text_delta_response(None, True, finish_reason)
- )
- responses.append(
- self.get_function_delta_response(
- current_state, "", True, False, finish_reason
- )
- )
- return current_state, responses
- else:
- assert current_state["response_type"] is not None
- if (
- delta_text == self.from_token
- ): # skip until reach to check type of response
- current_state["current_text"] = ""
- current_state["skip_until_reach"] = self.content_token
- current_state["response_type"] = None
- return current_state, None
-
- else:
- if current_state["response_type"] == "function":
- return current_state, self.get_function_delta_response(
- current_state, delta_text, False, False, finish_reason
- )
- else: # response_type=text
- return current_state, self.get_text_delta_response(
- delta_text, True, finish_reason
- )
-
-
-def get_random_tool_call_id():
- return "call_" + "".join(
- [random.choice(string.ascii_letters + string.digits) for _ in range(24)]
- )
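
Note: the removed update_response_state_from_delta_text above is a small state machine over streamed tokens: it skips everything until <|content|>, reads the recipient ("all" means a plain text reply, anything else names a tool), then routes later deltas either as text chunks or as tool-call argument chunks, and resets whenever <|from|> starts a new turn. The snippet below is only a rough, self-contained sketch of that idea for reference, not the deleted class; the token strings, the stream_deltas helper and the r2cmd/pdf example values are assumptions made for illustration.

    from typing import Iterable, Iterator, Tuple

    FROM_TOKEN = "<|from|>"        # assumed special token, mirrors self.from_token
    CONTENT_TOKEN = "<|content|>"  # assumed special token, mirrors self.content_token

    def stream_deltas(tokens: Iterable[str]) -> Iterator[Tuple]:
        """Yield ("text", chunk) or ("tool", name, chunk) from a functionary-style token stream."""
        text, skip_until, kind, func = "", CONTENT_TOKEN, None, None
        for tok in tokens:
            text += tok
            if skip_until:
                if tok != skip_until:
                    continue                     # still skipping until <|content|>
                skip_until = ""
                # whatever followed the last <|recipient|> names the recipient
                recipient = text.split("<|recipient|>")[-1].replace(CONTENT_TOKEN, "").strip()
                if recipient == "all":
                    kind = "text"                # plain text response
                else:
                    kind, func = "tool", recipient
                    yield ("tool", func, "")     # first tool chunk: announce the call
                continue
            if tok == FROM_TOKEN:                # a new turn starts: reset and skip again
                text, skip_until, kind, func = "", CONTENT_TOKEN, None, None
                continue
            if kind == "tool":
                yield ("tool", func, tok)        # fragment of the JSON arguments
            else:
                yield ("text", tok)              # fragment of the text answer

    # Example: one text turn followed by one tool-call turn.
    toks = ["all", CONTENT_TOKEN, "Hello", "!",
            FROM_TOKEN, "assistant\n<|recipient|>", "r2cmd", CONTENT_TOKEN,
            '{"command": ', '"pdf"}']
    for delta in stream_deltas(toks):
        print(delta)
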
diff --git a/r2ai/functionary/schema.py b/r2ai/functionary/schema.py
deleted file mode 100644
index 4ab08a6..0000000
--- a/r2ai/functionary/schema.py
+++ /dev/null
@@ -1,459 +0,0 @@
-import pdb
-from copy import deepcopy
-from typing import Any, Dict, List, Optional
-
-import jsonref
-import requests
-import yaml
-
-from r2ai.functionary.openai_types import Function
-
-
-def convert_data_type(param_type: str) -> str:
- """convert data_type to typescript data type
-
- Args:
- param_type (str): param_type
-
- Returns:
- str: param type in typescript
- """
- if param_type == "integer" or param_type == "float":
- return "number"
- return param_type
-
-
-def get_param_type(param: Dict) -> str:
- """get param_type of parameter
-
- Args:
- param (Dict): param dict in properties
-
- Returns:
- str: TypeScript type of the parameter
- """
- param_type = "any"
- if "type" in param:
- raw_param_type = param["type"]
- if type(raw_param_type) is list:
- param_type = " | ".join(raw_param_type)
- else:
- param_type = raw_param_type
-
- else: # in many cases, the json schema contains: oneOf instead of "type"
- if "oneOf" in param:
- one_of_types = []
- for item in param["oneOf"]:
- if "type" in item:
- one_of_types.append(convert_data_type(item["type"]))
- one_of_types = list(set(one_of_types))
- param_type = " | ".join(one_of_types)
- return convert_data_type(param_type)
-
-
-def get_format_param(param: Dict) -> Optional[str]:
- """Get "format" from param. There are cases where format is not directly in param but in oneOf
-
- Args:
- param (Dict): param dict in properties
-
- Returns:
- Optional[str]: the format string(s), or None if not present
- """
- if "format" in param:
- return param["format"]
- if "oneOf" in param:
- formats = []
- for item in param["oneOf"]:
- if "format" in item:
- formats.append(item["format"])
- if len(formats) > 0:
- return " or ".join(formats)
- return None
-
-
-def get_param_info(param: Dict) -> Optional[str]:
- """get additional information about parameter such as: format, default value, min, max, ...
-
- Args:
- param (Dict): _description_
-
- Returns:
- Optional[str]: _description_
- """
- param_type = param.get("type", "any")
- info_list = []
- if "description" in param:
- desc = param["description"]
- if not desc.endswith("."):
- desc += "."
- info_list.append(desc)
-
- if "default" in param:
- default_value = param["default"]
- if param_type == "string":
- default_value = f'"{default_value}"' # if string --> add ""
- info_list.append(f"Default={default_value}.")
-
- format_param = get_format_param(param)
- if format_param is not None:
- info_list.append("Format=" + format_param)
-
- for field, field_name in [
- ("maximum", "Maximum"),
- ("minimum", "Minimum"),
- ("maxLength", "Maximum length"),
- ("minLength", "Minimum length"),
- ]:
- if field in param:
- info_list.append(f"{field_name}=" + str(param[field]))
-
- if len(info_list) > 0:
- result = "// " + " ".join(info_list)
- result = result.replace("\n", " ")
- return result
- return None
-
-
-def append_new_param_info(
- info_list: List[str],
- param_declaration: str,
- comment_info: Optional[str],
- depth: int,
-):
- """Append a new parameter with comment to the info_list
-
- Args:
- info_list (List[str]): current info_list
- param_declaration (str): param: type
- comment_info (Optional[str]): information of comment
- depth (int): level of nested param
- """
- offset = ""
- if depth >= 1:
- offset = "".join([" " for _ in range(depth)])
- if comment_info is not None:
- # if depth == 0: # format: //comment\nparam: type
- info_list.append(f"{offset}{comment_info}")
- info_list.append(f"{offset}{param_declaration}")
- # else: # format: param: type // comment
- # info_list.append(f"{offset}{param_declaration} {comment_info}")
- else:
- info_list.append(f"{offset}{param_declaration}")
-
-
-def get_enum_option_str(enum_options: List) -> str:
- """get enum option separated by: "|"
-
- Args:
- enum_options (List): list of options
-
- Returns:
- _type_: concatenation of options separated by "|"
- """
- # if each option is string --> add quote
- return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])
-
-
-def get_array_typescript(
- param_name: Optional[str], param_dic: dict, depth: int = 0
-) -> str:
- """recursive implementation for generating type script of array
-
- Args:
- param_name (Optional[str]): name of param, optional
- param_dic (dict): param_dic
- depth (int, optional): nested level. Defaults to 0.
-
- Returns:
- _type_: typescript of array
- """
- offset = ""
- if depth >= 1:
- offset = "".join([" " for _ in range(depth)])
- items_info = param_dic.get("items", {})
-
- if len(items_info) == 0:
- if param_name is not None:
- return f"{offset}{param_name}: []"
- else:
- return "[]"
- array_type = get_param_type(items_info)
- if array_type == "object":
- info_lines = []
- child_lines = get_parameter_typescript(
- items_info.get("properties", {}), items_info.get("required", []), depth + 1
- )
- # if comment_info is not None:
- # info_lines.append(f"{offset}{comment_info}")
- if param_name is not None:
- info_lines.append(f"{offset}{param_name}" + ": {")
- else:
- info_lines.append(f"{offset}" + "{")
- info_lines.extend(child_lines)
- info_lines.append(f"{offset}" + "}[]")
- return "\n".join(info_lines)
-
- elif array_type == "array":
- item_info = get_array_typescript(None, items_info, depth + 1)
- if param_name is None:
- return f"{item_info}[]"
- return f"{offset}{param_name}: {item_info.strip()}[]"
-
- else:
- if "enum" in items_info:
- item_type = get_enum_option_str(items_info["enum"])
- if param_name is None:
- return f"({item_type})[]"
- else:
- return f"{offset}{param_name}: ({item_type})[]"
- else:
- if param_name is None:
- return f"{array_type}[]"
- else:
- return f"{offset}{param_name}: {array_type}[],"
-
-
-def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
- """Recursion, returning the information about parameters including data type, description and other information
- These kinds of information will be put into the prompt
-
- Args:
- properties (_type_): properties in parameters
- required_params (_type_): List of required parameters
- depth (int, optional): the depth of params (nested level). Defaults to 0.
-
- Returns:
- _type_: list of lines containing information about all parameters
- """
- tp_lines = []
- for param_name, param in properties.items():
- # Sometimes properties have a "required" field as a list of strings,
- # even though it is not supposed to be under properties, so we skip it
- if not isinstance(param, dict):
- continue
- # Param Description
- comment_info = get_param_info(param)
- # Param Name declaration
- param_declaration = f"{param_name}"
- if isinstance(required_params, list):
- if param_name not in required_params:
- param_declaration += "?"
- param_type = get_param_type(param)
-
- offset = ""
- if depth >= 1:
- offset = "".join([" " for _ in range(depth)])
-
- if param_type == "object": # param_type is object
- child_lines = get_parameter_typescript(
- param.get("properties", {}), param.get("required", []), depth + 1
- )
- if comment_info is not None:
- tp_lines.append(f"{offset}{comment_info}")
-
- param_declaration += ": {"
- tp_lines.append(f"{offset}{param_declaration}")
- tp_lines.extend(child_lines)
- tp_lines.append(f"{offset}" + "},")
-
- elif param_type == "array": # param_type is an array
- item_info = param.get("items", {})
- if "type" not in item_info: # don't know type of array
- param_declaration += ": [],"
- append_new_param_info(tp_lines, param_declaration, comment_info, depth)
- else:
- array_declaration = get_array_typescript(
- param_declaration, param, depth
- )
- if not array_declaration.endswith(","):
- array_declaration += ","
- if comment_info is not None:
- tp_lines.append(f"{offset}{comment_info}")
- tp_lines.append(array_declaration)
- else:
- if "enum" in param:
- param_type = get_enum_option_str(param["enum"])
- # param_type = " | ".join([f'"{v}"' for v in param["enum"]])
- param_declaration += f": {param_type},"
- append_new_param_info(tp_lines, param_declaration, comment_info, depth)
-
- return tp_lines
-
-
-def generate_schema_from_functions(
- functions: List[Function], namespace="functions"
-) -> str:
- """
- Convert functions schema to a schema that language models can understand.
- """
-
- schema = "// Supported function definitions that should be called when necessary.\n"
- schema += f"namespace {namespace} {{\n\n"
-
- for function in functions:
- # Convert a Function object to dict, if necessary
- if not isinstance(function, dict):
- function = function.model_dump()
- function_name = function.get("name", None)
- if function_name is None:
- continue
-
- description = function.get("description", "")
- schema += f"// {description}\n"
- schema += f"type {function_name}"
-
- parameters = function.get("parameters", None)
- if parameters is not None and parameters.get("properties") is not None:
- parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
- schema += " = (_: {\n"
- required_params = parameters.get("required", [])
- tp_lines = get_parameter_typescript(
- parameters.get("properties"), required_params, 0
- )
- schema += "\n".join(tp_lines)
- schema += "\n}) => any;\n\n"
- else:
- # Doesn't have any parameters
- schema += " = () => any;\n\n"
-
- schema += f"}} // namespace {namespace}"
-
- return schema
-
-
-def generate_schema_from_openapi(
- specification: Dict[str, Any], description: str, namespace: str
-) -> str:
- """
- Convert OpenAPI specification object to a schema that language models can understand.
-
- Input:
- specification: can be obtained by json.loads of any OpenAPI JSON spec, or yaml.safe_load for YAML OpenAPI specs
-
- Example output:
-
- // General Description
- namespace functions {
-
- // Simple GET endpoint
- type getEndpoint = (_: {
- // This is a string parameter
- param_string: string,
- param_integer: number,
- param_boolean?: boolean,
- param_enum: "value1" | "value2" | "value3",
- }) => any;
-
- } // namespace functions
- """
-
- description_clean = description.replace("\n", "")
-
- schema = f"// {description_clean}\n"
- schema += f"namespace {namespace} {{\n\n"
-
- for path_name, paths in specification.get("paths", {}).items():
- for method_name, method_info in paths.items():
- operationId = method_info.get("operationId", None)
- if operationId is None:
- continue
- description = method_info.get("description", method_info.get("summary", ""))
- schema += f"// {description}\n"
- schema += f"type {operationId}"
-
- if ("requestBody" in method_info) or (
- method_info.get("parameters") is not None
- ):
- schema += f" = (_: {{\n"
- # Body
- if "requestBody" in method_info:
- try:
- body_schema = (
- method_info.get("requestBody", {})
- .get("content", {})
- .get("application/json", {})
- .get("schema", {})
- )
- except AttributeError:
- body_schema = {}
- for param_name, param in body_schema.get("properties", {}).items():
- # Param Description
- description = param.get("description")
- if description is not None:
- schema += f"// {description}\n"
-
- # Param Name
- schema += f"{param_name}"
- if (
- (not param.get("required", False))
- or (param.get("nullable", False))
- or (param_name in body_schema.get("required", []))
- ):
- schema += "?"
-
- # Param Type
- param_type = param.get("type", "any")
- if param_type == "integer":
- param_type = "number"
- if "enum" in param:
- param_type = " | ".join([f'"{v}"' for v in param["enum"]])
- schema += f": {param_type},\n"
-
- # URL
- for param in method_info.get("parameters", []):
- # Param Description
- if description := param.get("description"):
- schema += f"// {description}\n"
-
- # Param Name
- schema += f"{param['name']}"
- if (not param.get("required", False)) or (
- param.get("nullable", False)
- ):
- schema += "?"
- if param.get("schema") is None:
- continue
- # Param Type
- param_type = param["schema"].get("type", "any")
- if param_type == "integer":
- param_type = "number"
- if "enum" in param["schema"]:
- param_type = " | ".join(
- [f'"{v}"' for v in param["schema"]["enum"]]
- )
- schema += f": {param_type},\n"
-
- schema += f"}}) => any;\n\n"
- else:
- # Doesn't have any parameters
- schema += f" = () => any;\n\n"
-
- schema += f"}} // namespace {namespace}"
-
- return schema
-
-
-def generate_specification_from_openapi_url(
- openapi_url: str, proxies: dict = None
-) -> str:
- # Make Request
- headers = {"Accept": "application/x-yaml, text/yaml, text/x-yaml, application/json"}
- response = requests.get(
- openapi_url, verify=False, headers=headers, timeout=60, proxies=proxies
- )
-
- if response.status_code == 200:
- # Trust content-type first
- if response.headers.get("Content-Type") is not None:
- if "application/json" in response.headers.get("Content-Type"):
- specification = response.json()
- else:
- specification = yaml.safe_load(response.text)
- elif response.url.endswith(".json"):
- specification = response.json()
- else:
- specification = yaml.safe_load(response.text)
- # Resolve references
- specification = deepcopy(jsonref.JsonRef.replace_refs(specification))
- return specification
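
Note: the removed generate_schema_from_functions turned OpenAI-style function definitions into a TypeScript-like namespace that functionary models consume as part of the system prompt. A minimal usage sketch follows (the pre-removal module path is assumed; the r2cmd definition is an illustrative example, not taken from this diff):

    # Sketch only: assumes the module still exists at r2ai.functionary.schema
    from r2ai.functionary.schema import generate_schema_from_functions

    functions = [{
        "name": "r2cmd",
        "description": "Run a radare2 command",
        "parameters": {
            "type": "object",
            "properties": {
                "command": {"type": "string", "description": "radare2 command to run"},
            },
            "required": ["command"],
        },
    }]

    print(generate_schema_from_functions(functions))
    # Expected output (roughly):
    # // Supported function definitions that should be called when necessary.
    # namespace functions {
    #
    # // Run a radare2 command
    # type r2cmd = (_: {
    # // radare2 command to run.
    # command: string,
    # }) => any;
    #
    # } // namespace functions
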
diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py
index 8fa9f96..44d283e 100644
--- a/r2ai/interpreter.py
+++ b/r2ai/interpreter.py
@@ -857,7 +857,11 @@ def respond(self):
# builtins.print(prompt)
response = None
if self.auto_run:
- response = auto.chat(self)
+ if is_litellm_model(self.model):
+ response = auto.chat(self)
+ else:
+ self.llama_instance = new_get_hf_llm(self, self.model, False, int(self.env["llm.window"]))
+ response = auto.chat(self, llama_instance=self.llama_instance)
return
elif self.model.startswith("kobaldcpp"):