Merge pull request #4536 from BerriAI/litellm_anthropic_tool_calling_streaming_fix

*real* Anthropic tool calling + streaming support
krrishdholakia authored Jul 4, 2024
2 parents e712277 + 5e47970 commit 17869fc
Showing 9 changed files with 606 additions and 419 deletions.
480 changes: 325 additions & 155 deletions litellm/llms/anthropic.py

Large diffs are not rendered by default.
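The bulk of the PR lives in this file, but its diff is not rendered here. Given the commit title and the removal of the XML tool-prompt path in vertex_ai_anthropic.py below, this file appears to move tool calling onto Anthropic's native `tools` parameter and streamed `tool_use` content blocks. As a purely illustrative sketch (not the PR's code), the snippet below shows the general shape of Anthropic Messages API stream events a handler like this has to accumulate: `content_block_start` announces a `tool_use` block, and subsequent `content_block_delta` events carry the arguments as `input_json_delta.partial_json` fragments.

```python
import json
from typing import Any, Dict, Iterable, List


def collect_tool_calls(events: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Accumulate streamed Anthropic tool calls from raw Messages API events.

    Illustrative only: events are plain dicts matching Anthropic's documented
    stream event types, not LiteLLM's internal chunk objects.
    """
    calls: Dict[int, Dict[str, Any]] = {}
    for event in events:
        if event.get("type") == "content_block_start":
            block = event.get("content_block", {})
            if block.get("type") == "tool_use":
                calls[event["index"]] = {
                    "id": block.get("id"),
                    "name": block.get("name"),
                    "arguments": "",  # filled in from partial_json deltas
                }
        elif event.get("type") == "content_block_delta":
            delta = event.get("delta", {})
            if delta.get("type") == "input_json_delta" and event.get("index") in calls:
                calls[event["index"]]["arguments"] += delta.get("partial_json", "")
    # Parse the accumulated JSON argument strings once the stream is done.
    for call in calls.values():
        call["arguments"] = json.loads(call["arguments"]) if call["arguments"] else {}
    return [calls[i] for i in sorted(calls)]


# Hypothetical event sequence for a single get_weather tool call.
events = [
    {"type": "content_block_start", "index": 1,
     "content_block": {"type": "tool_use", "id": "toolu_01", "name": "get_weather", "input": {}}},
    {"type": "content_block_delta", "index": 1,
     "delta": {"type": "input_json_delta", "partial_json": "{\"location\": "}},
    {"type": "content_block_delta", "index": 1,
     "delta": {"type": "input_json_delta", "partial_json": "\"Boston\"}"}},
]
print(collect_tool_calls(events))
# [{'id': 'toolu_01', 'name': 'get_weather', 'arguments': {'location': 'Boston'}}]
```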

4 changes: 3 additions & 1 deletion litellm/llms/prompt_templates/factory.py
@@ -1283,7 +1283,9 @@ def anthropic_messages_pt(messages: list):
)
else:
raise Exception(
"Invalid first message. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, "
"Invalid first message={}. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, ".format(
new_messages
)
)

if new_messages[-1]["role"] == "assistant":
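The only behavioral change in this hunk is that the exception now interpolates the offending message list, which makes the failure easier to debug from proxy logs. A minimal illustration of the resulting error text, with a hypothetical message list that is not taken from the PR:

```python
# Hypothetical messages list that violates the "first message must be 'user'" rule.
new_messages = [{"role": "assistant", "content": "hi"}]

error = (
    "Invalid first message={}. Should always start with 'role'='user' for Anthropic. "
    "System prompt is sent separately for Anthropic."
).format(new_messages)
print(error)
# Invalid first message=[{'role': 'assistant', 'content': 'hi'}]. Should always start with ...
```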
208 changes: 31 additions & 177 deletions litellm/llms/vertex_ai_anthropic.py
@@ -235,190 +235,44 @@ def completion(
if k not in optional_params:
optional_params[k] = v

## Format Prompt
_is_function_call = False
_is_json_schema = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Checks for 'response_schema' support - if passed in
if "response_format" in optional_params:
response_format_chunk = ResponseFormatChunk(
**optional_params["response_format"] # type: ignore
)
supports_response_schema = litellm.supports_response_schema(
model=model, custom_llm_provider="vertex_ai"
)
if (
supports_response_schema is False
and response_format_chunk["type"] == "json_object"
and "response_schema" in response_format_chunk
):
_is_json_schema = True
user_response_schema_message = response_schema_prompt(
model=model,
response_schema=response_format_chunk["response_schema"],
)
messages.append(
{"role": "user", "content": user_response_schema_message}
)
messages.append({"role": "assistant", "content": "{"})
optional_params.pop("response_format")
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
)
except Exception as e:
raise VertexAIError(status_code=400, message=str(e))

## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")

stream = optional_params.pop("stream", None)

data = {
"model": model,
"messages": messages,
**optional_params,
}
print_verbose(f"_is_function_call: {_is_function_call}")

## Completion Call

print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
## CONSTRUCT API BASE
stream = optional_params.get("stream", False)

api_base = create_vertex_anthropic_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
model=model,
stream=stream,
)

if acompletion == True:
"""
- async streaming
- async completion
"""
if stream is not None and stream == True:
return async_streaming(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
else:
return async_completion(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
if stream is not None and stream == True:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore
return response

## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
if headers is not None:
vertex_headers = headers
else:
vertex_headers = {}

message = vertex_ai_client.messages.create(**data) # type: ignore
vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})

## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=message,
additional_args={"complete_input_dict": data},
optional_params.update(
{"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
)

text_content: str = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
if (
_is_json_schema
): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
json_response = "{" + text_content[: text_content.rfind("}") + 1]
model_response.choices[0].message.content = json_response # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)

## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens

model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
return anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=access_token,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=vertex_headers,
)
setattr(model_response, "usage", usage)
return model_response

except Exception as e:
raise VertexAIError(status_code=500, message=str(e))

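The net effect of this file's diff is that the Vertex path no longer re-implements Anthropic XML tool prompting, JSON-mode post-processing, or usage math; it builds the Vertex endpoint via create_vertex_anthropic_url, attaches the OAuth bearer token plus `anthropic_version`/`is_vertex_request`, and hands everything to the shared anthropic_chat_completions.completion handler. A rough, self-contained sketch of that new flow follows; the `rawPredict`/`streamRawPredict` URL layout is an assumption about what the (unshown) URL helper produces, and the final handler call is stubbed out as a returned dict.

```python
def vertex_anthropic_call_sketch(
    model: str,
    messages: list,
    access_token: str,
    vertex_project: str,
    vertex_location: str = "us-central1",
    stream: bool = False,
    **optional_params,
):
    """Simplified mirror of the refactored completion() tail in this file."""
    # Assumed endpoint layout for Vertex AI partner models (not shown in this hunk).
    verb = "streamRawPredict" if stream else "rawPredict"
    api_base = (
        f"https://{vertex_location}-aiplatform.googleapis.com/v1"
        f"/projects/{vertex_project}/locations/{vertex_location}"
        f"/publishers/anthropic/models/{model}:{verb}"
    )
    # Matches the diff: bearer token in headers, Vertex-specific params folded in.
    headers = {"Authorization": f"Bearer {access_token}"}
    optional_params.update(
        {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
    )
    # In the PR this is anthropic_chat_completions.completion(...); returning the
    # assembled request keeps the sketch runnable without LiteLLM internals.
    return {
        "api_base": api_base,
        "headers": headers,
        "messages": messages,
        "optional_params": optional_params,
    }
```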
37 changes: 23 additions & 14 deletions litellm/main.py
@@ -2026,18 +2026,18 @@ def completion(
acompletion=acompletion,
)

if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="vertex_ai",
logging_obj=logging,
)
return response
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="vertex_ai",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "predibase":
tenant_id = (
@@ -4944,14 +4944,23 @@ def stream_chunk_builder(
else:
completion_output = ""
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0

try:
response["usage"]["prompt_tokens"] = token_counter(
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except: # don't allow this failing to block a complete streaming response from being returned
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = token_counter(
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
model=model,
text=completion_output,
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
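The `stream_chunk_builder` hunk changes how usage is reconstructed for a rebuilt streaming response: token counts reported by the provider on individual chunks are summed and preferred, with the local `token_counter` estimate kept only as a fallback when nothing was reported. A standalone sketch of that aggregation pattern (the chunk dicts are illustrative, not taken from the PR):

```python
from typing import Dict, List, Tuple


def sum_streamed_usage(chunks: List[Dict]) -> Tuple[int, int]:
    """Sum provider-reported usage across streaming chunks.

    Returns (prompt_tokens, completion_tokens); zeros mean "not reported",
    in which case the caller falls back to counting tokens locally.
    """
    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        usage = chunk.get("usage") or {}
        prompt_tokens += usage.get("prompt_tokens", 0) or 0
        completion_tokens += usage.get("completion_tokens", 0) or 0
    return prompt_tokens, completion_tokens


# Example: only the final chunk carries usage, as many providers do.
chunks = [
    {"choices": [{"delta": {"content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo"}}],
     "usage": {"prompt_tokens": 12, "completion_tokens": 2}},
]
assert sum_streamed_usage(chunks) == (12, 2)
```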
3 changes: 1 addition & 2 deletions litellm/proxy/_super_secret_config.yaml
@@ -1,7 +1,7 @@
model_list:
- model_name: claude-3-5-sonnet
litellm_params:
model: anthropic/claude-3-5-sonnet
model: claude-3-haiku-20240307
# - model_name: gemini-1.5-flash-gemini
# litellm_params:
# model: vertex_ai_beta/gemini-1.5-flash
@@ -18,7 +18,6 @@ model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: predibase/llama-3-8b-instruct
api_base: "http://0.0.0.0:8081"
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_new_tokens: 256
9 changes: 6 additions & 3 deletions litellm/tests/test_amazing_vertex_completion.py
@@ -203,7 +203,7 @@ def test_vertex_ai_anthropic():
# )
def test_vertex_ai_anthropic_streaming():
try:
# load_vertex_ai_credentials()
load_vertex_ai_credentials()

# litellm.set_verbose = True

@@ -223,8 +223,9 @@ def test_vertex_ai_anthropic_streaming():
stream=True,
)
# print("\nModel Response", response)
for chunk in response:
for idx, chunk in enumerate(response):
print(f"chunk: {chunk}")
streaming_format_tests(idx=idx, chunk=chunk)

# raise Exception("it worked!")
except litellm.RateLimitError as e:
@@ -294,8 +295,10 @@ async def test_vertex_ai_anthropic_async_streaming():
stream=True,
)

idx = 0
async for chunk in response:
print(f"chunk: {chunk}")
streaming_format_tests(idx=idx, chunk=chunk)
idx += 1
except litellm.RateLimitError as e:
pass
except Exception as e:
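Both streaming tests now run each chunk through `streaming_format_tests` with its index. That helper is defined elsewhere in the test module and is not part of this diff; the sketch below is only a guess at the kind of per-chunk shape checks such a helper typically performs on OpenAI-style stream chunks, and the real assertions may differ.

```python
def streaming_format_tests_sketch(idx: int, chunk) -> None:
    """Hypothetical per-chunk checks; the real streaming_format_tests may differ."""
    assert chunk.id, "every chunk should carry a response id"
    assert chunk.object == "chat.completion.chunk"
    assert len(chunk.choices) >= 1
    if idx == 0:
        # The first chunk conventionally announces the assistant role.
        assert chunk.choices[0].delta.role == "assistant"
```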