Skip to content

Commit

Permalink
Merge branch 'main' of github.com:brainlid/langchain
Browse files Browse the repository at this point in the history
* 'main' of github.com:brainlid/langchain:
  Add AWS Bedrock support to ChatAnthropic (#154)
  Handle functions with no parameters for Google AI (#183)
  Handle missing token usage fields for Google AI (#184)
  Handle empty text parts from GoogleAI responses (#181)
  Support system instructions for Google AI (#182)
  feat: add OpenAI's new structured output API (#180)
  Support strict mode for tools (#173)
  Do not duplicate tool call parameters if they are identical (#174)
  🐛 cast tool_calls arguments correctly inside message_deltas (#175)
  • Loading branch information
brainlid committed Oct 28, 2024
2 parents 6e9f1b2 + 1981fbd commit fa17272
Show file tree
Hide file tree
Showing 26 changed files with 1,488 additions and 345 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/elixir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ env:
OPENAI_API_KEY: invalid
ANTHROPIC_API_KEY: invalid
GOOGLE_API_KEY: invalid
AWS_ACCESS_KEY_ID: invalid
AWS_SECRET_ACCESS_KEY: invalid

permissions:
contents: read
Expand Down
111 changes: 79 additions & 32 deletions lib/chat_models/chat_anthropic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do
alias LangChain.FunctionParam
alias LangChain.Utils
alias LangChain.Callbacks
alias LangChain.Utils.BedrockStreamDecoder
alias LangChain.Utils.BedrockConfig

@behaviour ChatModel

Expand All @@ -67,6 +69,9 @@ defmodule LangChain.ChatModels.ChatAnthropic do
# API endpoint to use. Defaults to Anthropic's API
field :endpoint, :string, default: "https://api.anthropic.com/v1/messages"

# Configuration for AWS Bedrock. Configure this instead of endpoint & api_key if you want to use Bedrock.
embeds_one :bedrock, BedrockConfig

# API key for Anthropic. If not set, will use global api key. Allows for usage
# of a different API key per-call if desired. For instance, allowing a
# customer to provide their own.
Expand Down Expand Up @@ -131,19 +136,14 @@ defmodule LangChain.ChatModels.ChatAnthropic do
]
@required_fields [:endpoint, :model]

@spec get_api_key(t()) :: String.t()
defp get_api_key(%ChatAnthropic{api_key: api_key}) do
# if no API key is set default to `""` which will raise an error
api_key || Config.resolve(:anthropic_key, "")
end

@doc """
Setup a ChatAnthropic client configuration.
"""
@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%ChatAnthropic{}
|> cast(attrs, @create_fields)
|> cast_embed(:bedrock)
|> common_validation()
|> apply_action(:insert)
end
Expand Down Expand Up @@ -175,7 +175,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do
@spec for_api(t, message :: [map()], ChatModel.tools()) :: %{atom() => any()}
def for_api(%ChatAnthropic{} = anthropic, messages, tools) do
# separate the system message from the rest. Handled separately.
{system, messages} = split_system_message(messages)
{system, messages} =
Utils.split_system_message(messages, "Anthropic only supports a single System message")

system_text =
case system do
Expand Down Expand Up @@ -203,6 +204,15 @@ defmodule LangChain.ChatModels.ChatAnthropic do
|> Utils.conditionally_add_to_map(:max_tokens, anthropic.max_tokens)
|> Utils.conditionally_add_to_map(:top_p, anthropic.top_p)
|> Utils.conditionally_add_to_map(:top_k, anthropic.top_k)
|> maybe_transform_for_bedrock(anthropic.bedrock)
end

# Direct Anthropic API calls need no body rewriting.
defp maybe_transform_for_bedrock(body, nil), do: body

# AWS Bedrock encodes the model and streaming mode in the invoke URL instead
# of the request body, and requires an explicit `anthropic_version` field.
defp maybe_transform_for_bedrock(body, %BedrockConfig{} = bedrock) do
  body
  |> Map.drop([:model, :stream])
  |> Map.put(:anthropic_version, bedrock.anthropic_version)
end

defp get_tools_for_api(nil), do: []
Expand All @@ -214,21 +224,6 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end)
end

# Unlike OpenAI, Anthropic only supports one system message.
@doc false
@spec split_system_message([Message.t()]) :: {nil | Message.t(), [Message.t()]} | no_return()
def split_system_message(messages) do
# split the messages into "system" and "other". Error if more than 1 system
# message. Return the other messages as a separate list.
{system, other} = Enum.split_with(messages, &(&1.role == :system))

if length(system) > 1 do
raise LangChainError, "Anthropic only supports a single System message"
end

{List.first(system), other}
end

@doc """
Calls the Anthropic API passing the ChatAnthropic struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
Expand Down Expand Up @@ -301,13 +296,14 @@ defmodule LangChain.ChatModels.ChatAnthropic do
) do
req =
Req.new(
url: anthropic.endpoint,
url: url(anthropic),
json: for_api(anthropic, messages, tools),
headers: headers(get_api_key(anthropic), anthropic.api_version),
headers: headers(anthropic),
receive_timeout: anthropic.receive_timeout,
retry: :transient,
max_retries: 3,
retry_delay: fn attempt -> 300 * attempt end
retry_delay: fn attempt -> 300 * attempt end,
aws_sigv4: aws_sigv4_opts(anthropic.bedrock)
)

req
Expand Down Expand Up @@ -355,14 +351,19 @@ defmodule LangChain.ChatModels.ChatAnthropic do
retry_count
) do
Req.new(
url: anthropic.endpoint,
url: url(anthropic),
json: for_api(anthropic, messages, tools),
headers: headers(get_api_key(anthropic), anthropic.api_version),
receive_timeout: anthropic.receive_timeout
headers: headers(anthropic),
receive_timeout: anthropic.receive_timeout,
aws_sigv4: aws_sigv4_opts(anthropic.bedrock)
)
|> Req.post(
into:
Utils.handle_stream_fn(anthropic, &decode_stream/1, &do_process_response(anthropic, &1))
Utils.handle_stream_fn(
anthropic,
&decode_stream(anthropic, &1),
&do_process_response(anthropic, &1)
)
)
|> case do
{:ok, %Req.Response{body: data} = response} ->
Expand Down Expand Up @@ -393,16 +394,40 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end
end

defp headers(api_key, api_version) do
# No Bedrock config means a direct Anthropic API call: skip AWS SigV4 signing.
defp aws_sigv4_opts(nil), do: nil
# With Bedrock, delegate to BedrockConfig for Req's `:aws_sigv4` signing options.
defp aws_sigv4_opts(%BedrockConfig{} = bedrock), do: BedrockConfig.aws_sigv4_opts(bedrock)

# Resolve the effective Anthropic API key: the per-struct override when set,
# otherwise the globally configured :anthropic_key. Falls back to `""` so a
# missing key surfaces as an API auth error rather than a nil crash.
@spec get_api_key(binary() | nil) :: String.t()
defp get_api_key(nil), do: Config.resolve(:anthropic_key, "")
defp get_api_key(api_key), do: api_key

defp headers(%ChatAnthropic{bedrock: nil, api_key: api_key, api_version: api_version}) do
%{
"x-api-key" => api_key,
"x-api-key" => get_api_key(api_key),
"content-type" => "application/json",
"anthropic-version" => api_version,
# https://docs.anthropic.com/claude/docs/tool-use - requires this header during beta
"anthropic-beta" => "tools-2024-04-04"
}
end

# When calling through AWS Bedrock there is no Anthropic API key or version
# header; authentication is done via SigV4 request signing instead (see
# aws_sigv4_opts/1).
defp headers(%ChatAnthropic{bedrock: %BedrockConfig{}}) do
  %{
    "content-type" => "application/json",
    "accept" => "application/json"
  }
end

# Without Bedrock configured, call the plain Anthropic endpoint as-is.
defp url(%ChatAnthropic{bedrock: nil} = anthropic) do
  anthropic.endpoint
end

# Bedrock invoke URLs embed both the model id and whether the call streams;
# BedrockConfig builds the appropriate URL from those.
defp url(%ChatAnthropic{bedrock: %BedrockConfig{} = bedrock, stream: stream} = anthropic) do
  BedrockConfig.url(bedrock, model: anthropic.model, stream: stream)
end

# Parse a new message response
@doc false
@spec do_process_response(t(), data :: %{String.t() => any()} | {:error, any()}) ::
Expand Down Expand Up @@ -527,6 +552,16 @@ defmodule LangChain.ChatModels.ChatAnthropic do
{:error, error_message}
end

# Bedrock reports API-level errors as a bare %{"message" => ...} body rather
# than Anthropic's %{"error" => ...} envelope; surface it as an error tuple.
def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{"message" => message}) do
  {:error, "Received error from API: #{message}"}
end

# Exceptions detected while decoding the Bedrock event stream (tagged by
# BedrockStreamDecoder with :bedrock_exception) are converted to error tuples.
def do_process_response(%ChatAnthropic{bedrock: %BedrockConfig{}}, %{
      bedrock_exception: exceptions
    }) do
  {:error, "Stream exception received: #{inspect(exceptions)}"}
end

def do_process_response(_model, other) do
Logger.error("Trying to process an unexpected response. #{inspect(other)}")
{:error, "Unexpected response"}
Expand Down Expand Up @@ -597,7 +632,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end

@doc false
def decode_stream({chunk, buffer}) do
def decode_stream(%ChatAnthropic{bedrock: nil}, {chunk, buffer}) do
# Combine the incoming data with the buffered incomplete data
combined_data = buffer <> chunk
# Split data by double newline to find complete messages
Expand Down Expand Up @@ -665,6 +700,18 @@ defmodule LangChain.ChatModels.ChatAnthropic do
# assumed the response is JSON. Return as-is
defp extract_data(json), do: json

@doc false
# Bedrock wraps Anthropic's SSE events in its own binary event-stream framing.
# BedrockStreamDecoder unwraps those frames; we then keep only the chunks that
# matter downstream: decoder-reported exceptions, plus event types accepted by
# relevant_event?/1 (reusing the non-Bedrock filter by reconstructing the
# "event: <type>" line it expects).
def decode_stream(%ChatAnthropic{bedrock: %BedrockConfig{}}, {chunk, buffer}, chunks \\ []) do
  {chunks, remaining} = BedrockStreamDecoder.decode_stream({chunk, buffer}, chunks)

  chunks =
    Enum.filter(chunks, fn chunk ->
      Map.has_key?(chunk, :bedrock_exception) || relevant_event?("event: #{chunk["type"]}\n")
    end)

  {chunks, remaining}
end

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
"""
Expand Down
73 changes: 48 additions & 25 deletions lib/chat_models/chat_google_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -138,28 +138,42 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end

def for_api(%ChatGoogleAI{} = google_ai, messages, functions) do
{system, messages} =
Utils.split_system_message(messages, "Google AI only supports a single System message")

system_instruction =
case system do
nil ->
nil

%Message{role: :system, content: content} ->
%{"parts" => [%{"text" => content}]}
end

messages_for_api =
messages
|> Enum.map(&for_api/1)
|> List.flatten()
|> List.wrap()

req = %{
"contents" => messages_for_api,
"generationConfig" => %{
"temperature" => google_ai.temperature,
"topP" => google_ai.top_p,
"topK" => google_ai.top_k
req =
%{
"contents" => messages_for_api,
"generationConfig" => %{
"temperature" => google_ai.temperature,
"topP" => google_ai.top_p,
"topK" => google_ai.top_k
}
}
}
|> LangChain.Utils.conditionally_add_to_map("system_instruction", system_instruction)

if functions && not Enum.empty?(functions) do
req
|> Map.put("tools", [
%{
# Google AI functions use an OpenAI compatible format.
# See: https://ai.google.dev/docs/function_calling#how_it_works
"functionDeclarations" => Enum.map(functions, &ChatOpenAI.for_api/1)
"functionDeclarations" => Enum.map(functions, &for_api/1)
}
])
else
Expand Down Expand Up @@ -188,21 +202,6 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
}
end

def for_api(%Message{role: :system} = message) do
# No system messages support means we need to fake a prompt and response
# to pretend like it worked.
[
%{
"role" => :user,
"parts" => [%{"text" => message.content}]
},
%{
"role" => :model,
"parts" => [%{"text" => ""}]
}
]
end

def for_api(%Message{} = message) do
%{
"role" => map_role(message.role),
Expand Down Expand Up @@ -249,6 +248,18 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
}
end

# Encode a Function for Google AI, which accepts the OpenAI-compatible format.
def for_api(%Function{} = function) do
  encoded = ChatOpenAI.for_api(function)

  # Google AI rejects an empty OBJECT-typed `parameters` map with
  # "parameters.properties: should be non-empty for OBJECT type", so for a
  # function declaring no parameters the field must be removed entirely.
  case encoded do
    %{"parameters" => %{"properties" => %{}, "type" => "object"}} ->
      Map.delete(encoded, "parameters")

    _ ->
      encoded
  end
end

@doc """
Calls the Google AI API passing the ChatGoogleAI struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
Expand Down Expand Up @@ -426,6 +437,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
text_part =
parts
|> filter_parts_for_types(["text"])
|> filter_text_parts()
|> Enum.map(fn part ->
ContentPart.new!(%{type: :text, content: part["text"]})
end)
Expand Down Expand Up @@ -479,6 +491,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do

parts
|> filter_parts_for_types(["text"])
|> filter_text_parts()
|> Enum.map(fn part ->
ContentPart.new!(%{type: :text, content: part["text"]})
end)
Expand Down Expand Up @@ -597,6 +610,16 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end)
end

@doc false
# Keep only parts carrying usable text: a "text" key whose value is neither
# nil, false, nor the empty string. Google AI can emit empty text parts that
# would otherwise produce blank ContentParts.
def filter_text_parts(parts) when is_list(parts) do
  Enum.filter(parts, &match?(%{"text" => text} when text not in [nil, false, ""], &1))
end

@doc """
Return the content parts for the message.
"""
Expand Down Expand Up @@ -660,8 +683,8 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
defp get_token_usage(%{"usageMetadata" => usage} = _response_body) do
# extract out the reported response token usage
TokenUsage.new!(%{
input: Map.get(usage, "promptTokenCount"),
output: Map.get(usage, "candidatesTokenCount")
input: Map.get(usage, "promptTokenCount", 0),
output: Map.get(usage, "candidatesTokenCount", 0)
})
end

Expand Down
Loading

0 comments on commit fa17272

Please sign in to comment.