diff --git a/lib/grpc/client/adapter.ex b/lib/grpc/client/adapter.ex index 6cec2b6f..f623be62 100644 --- a/lib/grpc/client/adapter.ex +++ b/lib/grpc/client/adapter.ex @@ -9,15 +9,21 @@ defmodule GRPC.Client.Adapter do @typedoc "Determines if the headers have finished being read." @type fin :: :fin | :nofin - @callback connect(Channel.t(), map()) :: {:ok, Channel.t()} | {:error, any()} + @callback connect(channel :: Channel.t(), opts :: keyword()) :: + {:ok, Channel.t()} | {:error, any()} - @callback disconnect(Channel.t()) :: {:ok, Channel.t()} | {:error, any()} + @callback disconnect(channel :: Channel.t()) :: {:ok, Channel.t()} | {:error, any()} - @callback send_request(Stream.t(), binary(), map()) :: Stream.t() + @callback send_request(stream :: Stream.t(), contents :: binary(), opts :: keyword()) :: + Stream.t() - @callback recv_headers(map(), map(), map()) :: + @callback recv_headers(stream :: map(), headers :: map(), opts :: keyword()) :: {:ok, %{String.t() => String.t()}, fin()} | {:error, GRPC.RPCError.t()} - @callback recv_data_or_trailers(map(), map(), map()) :: + @callback recv_data_or_trailers( + stream :: map(), + trailers_or_metadata :: map(), + opts :: keyword() + ) :: {:data, binary()} | {:trailers, binary()} | {:error, GRPC.RPCError.t()} end diff --git a/lib/grpc/client/adapters/gun.ex b/lib/grpc/client/adapters/gun.ex index 1dd6de38..088abeed 100644 --- a/lib/grpc/client/adapters/gun.ex +++ b/lib/grpc/client/adapters/gun.ex @@ -11,9 +11,15 @@ defmodule GRPC.Client.Adapters.Gun do @max_retries 100 @impl true - def connect(channel, opts \\ %{}) - def connect(%{scheme: "https"} = channel, opts), do: connect_securely(channel, opts) - def connect(channel, opts), do: connect_insecurely(channel, opts) + def connect(channel, opts) when is_list(opts) do + # handle opts as a map due to :gun.open + opts = Map.new(opts) + + case channel do + %{scheme: "https"} -> connect_securely(channel, opts) + _ -> connect_insecurely(channel, opts) + end + end defp connect_securely(%{cred: %{ssl: ssl}} = channel, opts) do transport_opts = Map.get(opts, :transport_opts) || [] diff --git a/lib/grpc/message.ex b/lib/grpc/message.ex index 2ba3b227..0f4d782b 100644 --- a/lib/grpc/message.ex +++ b/lib/grpc/message.ex @@ -39,7 +39,7 @@ defmodule GRPC.Message do {:error, "Encoded message is too large (9 bytes)"} """ - @spec to_data(iodata, keyword() | map()) :: + @spec to_data(iodata, keyword()) :: {:ok, iodata, non_neg_integer} | {:error, String.t()} def to_data(message, opts \\ []) do compressor = opts[:compressor] diff --git a/lib/grpc/server/adapter.ex b/lib/grpc/server/adapter.ex index f2f30093..9884ac9f 100644 --- a/lib/grpc/server/adapter.ex +++ b/lib/grpc/server/adapter.ex @@ -11,12 +11,17 @@ defmodule GRPC.Server.Adapter do pending_reader: nil } - @callback start(atom(), %{String.t() => [module()]}, non_neg_integer(), Keyword.t()) :: + @callback start( + atom(), + %{String.t() => [module()]}, + port :: non_neg_integer(), + opts :: keyword() + ) :: {atom(), any(), non_neg_integer()} @callback stop(atom(), %{String.t() => [module()]}) :: :ok | {:error, :not_found} - @callback send_reply(state(), binary(), Keyword.t()) :: any() + @callback send_reply(state, content :: binary(), opts :: keyword()) :: any() - @callback send_headers(state(), map()) :: any() + @callback send_headers(state, headers :: map()) :: any() end diff --git a/lib/grpc/server/adapters/cowboy/handler.ex b/lib/grpc/server/adapters/cowboy/handler.ex index 433b6fda..ec87568c 100644 --- a/lib/grpc/server/adapters/cowboy/handler.ex +++ b/lib/grpc/server/adapters/cowboy/handler.ex @@ -11,7 +11,10 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do @adapter GRPC.Server.Adapters.Cowboy @default_trailers HTTP2.server_trailers() - @spec init(map(), {atom(), %{String.t() => [module()]}, map()}) :: {:cowboy_loop, map(), map()} + @spec init( + map(), + state :: {endpoint :: atom(), servers :: %{String.t() => [module()]}, opts :: keyword()} + ) :: {:cowboy_loop, map(), map()} def init(req, {endpoint, servers, opts} = state) do path = :cowboy_req.path(req) @@ -472,8 +475,8 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do end defp async_read_body(req, opts) do - length = Map.get(opts, :length, 8_000_000) - period = Map.get(opts, :period, 15000) + length = opts[:length] || 8_000_000 + period = opts[:period] || 15000 ref = make_ref() :cowboy_req.cast({:read_body, self(), ref, length, period}, req) diff --git a/lib/grpc/stub.ex b/lib/grpc/stub.ex index eb5c7f09..848c39c0 100644 --- a/lib/grpc/stub.ex +++ b/lib/grpc/stub.ex @@ -125,7 +125,7 @@ defmodule GRPC.Stub do * `:accepted_compressors` - tell servers accepted compressors, this can be used without `:compressor` * `:headers` - headers to attach to each request """ - @spec connect(String.t(), Keyword.t()) :: {:ok, Channel.t()} | {:error, any()} + @spec connect(String.t(), keyword()) :: {:ok, Channel.t()} | {:error, any()} def connect(addr, opts \\ []) when is_binary(addr) and is_list(opts) do {host, port} = case String.split(addr, ":") do @@ -136,7 +136,7 @@ defmodule GRPC.Stub do connect(host, port, opts) end - @spec connect(String.t(), binary() | non_neg_integer(), Keyword.t()) :: + @spec connect(String.t(), binary() | non_neg_integer(), keyword()) :: {:ok, Channel.t()} | {:error, any()} def connect(host, port, opts) when is_binary(port) do connect(host, String.to_integer(port), opts) @@ -162,6 +162,12 @@ defmodule GRPC.Stub do accepted_compressors end + adapter_opts = opts[:adapter_opts] || [] + + unless is_list(adapter_opts) do + raise ArgumentError, ":adapter_opts must be a keyword list if present" + end + %Channel{ host: host, port: port, @@ -174,7 +180,7 @@ defmodule GRPC.Stub do accepted_compressors: accepted_compressors, headers: headers } - |> adapter.connect(opts[:adapter_opts] || %{}) + |> adapter.connect(adapter_opts) end def retry_timeout(curr) when curr < 11 do @@ -231,7 +237,7 @@ defmodule GRPC.Stub do with the last elem being a map of headers `%{headers: headers, trailers: trailers}`(unary) or `%{headers: headers}`(server streaming) """ - @spec call(atom(), tuple(), GRPC.Client.Stream.t(), struct() | nil, Keyword.t()) :: rpc_return + @spec call(atom(), tuple(), GRPC.Client.Stream.t(), struct() | nil, keyword()) :: rpc_return def call(_service_mod, rpc, %{channel: channel} = stream, request, opts) do {_, {req_mod, req_stream}, {res_mod, response_stream}} = rpc @@ -239,13 +245,17 @@ defmodule GRPC.Stub do opts = if req_stream || response_stream do - parse_req_opts([{:timeout, :infinity} | opts]) + opts + |> parse_req_opts() + |> Keyword.put_new(:timeout, :infinity) else - parse_req_opts([{:timeout, @default_timeout} | opts]) + opts + |> parse_req_opts() + |> Keyword.put_new(:timeout, @default_timeout) end - compressor = Map.get(opts, :compressor, channel.compressor) - accepted_compressors = Map.get(opts, :accepted_compressors, []) + compressor = Keyword.get(opts, :compressor, channel.compressor) + accepted_compressors = Keyword.get(opts, :accepted_compressors, []) accepted_compressors = if compressor do @@ -256,8 +266,8 @@ defmodule GRPC.Stub do stream = %{ stream - | codec: Map.get(opts, :codec, channel.codec), - compressor: Map.get(opts, :compressor, channel.compressor), + | codec: Keyword.get(opts, :codec, channel.codec), + compressor: Keyword.get(opts, :compressor, channel.compressor), accepted_compressors: accepted_compressors } @@ -272,7 +282,7 @@ defmodule GRPC.Stub do ) do last = fn %{codec: codec, compressor: compressor} = s, _ -> message = codec.encode(request) - opts = Map.put(opts, :compressor, compressor) + opts = Keyword.put(opts, :compressor, compressor) s |> channel.adapter.send_request(message, opts) @@ -319,7 +329,7 @@ defmodule GRPC.Stub do * `:end_stream` - indicates it's the last one request, then the stream will be in half_closed state. Default is false. """ - @spec send_request(GRPC.Client.Stream.t(), struct, Keyword.t()) :: GRPC.Client.Stream.t() + @spec send_request(GRPC.Client.Stream.t(), struct, keyword()) :: GRPC.Client.Stream.t() def send_request(%{__interface__: interface} = stream, request, opts \\ []) do interface[:send_request].(stream, request, opts) end @@ -376,7 +386,7 @@ defmodule GRPC.Stub do * `:deadline` - when the request is timeout, will override timeout * `:return_headers` - when true, headers will be returned. """ - @spec recv(GRPC.Client.Stream.t(), Keyword.t() | map()) :: + @spec recv(GRPC.Client.Stream.t(), keyword()) :: {:ok, struct()} | {:ok, struct(), map()} | {:ok, Enumerable.t()} @@ -391,7 +401,7 @@ defmodule GRPC.Stub do def recv(%{__interface__: interface} = stream, opts) do opts = if is_list(opts) do - parse_recv_opts(opts) + parse_recv_opts(Keyword.put_new(opts, :timeout, @default_timeout)) else opts end @@ -588,72 +598,39 @@ defmodule GRPC.Stub do end end - defp parse_req_opts(list) when is_list(list) do - parse_req_opts(list, %{}) - end - - defp parse_req_opts([{:timeout, timeout} | t], acc) do - parse_req_opts(t, Map.put(acc, :timeout, timeout)) - end - - defp parse_req_opts([{:deadline, deadline} | t], acc) do - parse_req_opts(t, Map.put(acc, :timeout, GRPC.TimeUtils.to_relative(deadline))) - end - - defp parse_req_opts([{:compressor, compressor} | t], acc) do - parse_req_opts(t, Map.put(acc, :compressor, compressor)) - end - - defp parse_req_opts([{:accepted_compressors, compressors} | t], acc) do - parse_req_opts(t, Map.put(acc, :accepted_compressors, compressors)) - end - - defp parse_req_opts([{:grpc_encoding, grpc_encoding} | t], acc) do - parse_req_opts(t, Map.put(acc, :grpc_encoding, grpc_encoding)) - end - - defp parse_req_opts([{:metadata, metadata} | t], acc) do - parse_req_opts(t, Map.put(acc, :metadata, metadata)) - end - - defp parse_req_opts([{:content_type, content_type} | t], acc) do - Logger.warn(":content_type has been deprecated, please use :codec") - parse_req_opts(t, Map.put(acc, :content_type, content_type)) - end - - defp parse_req_opts([{:codec, codec} | t], acc) do - parse_req_opts(t, Map.put(acc, :codec, codec)) - end - - defp parse_req_opts([{:return_headers, return_headers} | t], acc) do - parse_req_opts(t, Map.put(acc, :return_headers, return_headers)) - end - - defp parse_req_opts([{key, _} | _], _) do - raise ArgumentError, "option #{inspect(key)} is not supported" + @valid_req_opts [ + :timeout, + :deadline, + :compressor, + :accepted_compressors, + :grpc_encoding, + :metadata, + :codec, + :return_headers + ] + defp parse_req_opts(opts) when is_list(opts) do + Enum.map(opts, fn + {:deadline, deadline} -> + {:timeout, GRPC.TimeUtils.to_relative(deadline)} + + {key, value} when key in @valid_req_opts -> + {key, value} + + {key, _} -> + raise ArgumentError, "option #{inspect(key)} is not supported" + end) end - defp parse_req_opts(_, acc), do: acc - defp parse_recv_opts(list) when is_list(list) do - parse_recv_opts(list, %{timeout: @default_timeout}) - end - - defp parse_recv_opts([{:timeout, timeout} | t], acc) do - parse_recv_opts(t, Map.put(acc, :timeout, timeout)) - end - - defp parse_recv_opts([{:deadline, deadline} | t], acc) do - parse_recv_opts(t, Map.put(acc, :deadline, GRPC.TimeUtils.to_relative(deadline))) - end + Enum.map(list, fn + {:deadline, deadline} -> + {:deadline, GRPC.TimeUtils.to_relative(deadline)} - defp parse_recv_opts([{:return_headers, return_headers} | t], acc) do - parse_recv_opts(t, Map.put(acc, :return_headers, return_headers)) - end + {key, _} when key not in @valid_req_opts -> + raise ArgumentError, "option #{inspect(key)} is not supported" - defp parse_recv_opts([{key, _} | _], _) do - raise ArgumentError, "option #{inspect(key)} is not supported" + kv -> + kv + end) end - - defp parse_recv_opts(_, acc), do: acc end diff --git a/lib/grpc/transport/http2.ex b/lib/grpc/transport/http2.ex index 189830fc..353a0ba4 100644 --- a/lib/grpc/transport/http2.ex +++ b/lib/grpc/transport/http2.ex @@ -23,8 +23,8 @@ defmodule GRPC.Transport.HTTP2 do @doc """ Now we may not need this because gun already handles the pseudo headers. """ - @spec client_headers(GRPC.Client.Stream.t(), map) :: [{String.t(), String.t()}] - def client_headers(%{channel: channel, path: path} = s, opts \\ %{}) do + @spec client_headers(GRPC.Client.Stream.t(), keyword()) :: [{String.t(), String.t()}] + def client_headers(%{channel: channel, path: path} = s, opts \\ []) do [ {":method", "POST"}, {":scheme", channel.scheme}, @@ -33,8 +33,10 @@ defmodule GRPC.Transport.HTTP2 do ] ++ client_headers_without_reserved(s, opts) end - @spec client_headers_without_reserved(GRPC.Client.Stream.t(), map) :: [{String.t(), String.t()}] - def client_headers_without_reserved(%{codec: codec} = stream, opts \\ %{}) do + @spec client_headers_without_reserved(GRPC.Client.Stream.t(), keyword()) :: [ + {String.t(), String.t()} + ] + def client_headers_without_reserved(%{codec: codec} = stream, opts \\ []) do [ # It seems only gRPC implemenations only support "application/grpc", so we support :content_type now. {"content-type", content_type(opts[:content_type], codec)}, diff --git a/test/grpc/adapter/gun_test.exs b/test/grpc/adapter/gun_test.exs index 04840eb5..b3d089e0 100644 --- a/test/grpc/adapter/gun_test.exs +++ b/test/grpc/adapter/gun_test.exs @@ -26,7 +26,7 @@ defmodule GRPC.Client.Adapters.GunTest do test "connects insecurely (default options)", %{port: port, credential: credential} do channel = build(:channel, port: port, host: "localhost", cred: credential) - assert {:ok, result} = Gun.connect(channel) + assert {:ok, result} = Gun.connect(channel, []) assert %{channel | adapter_payload: %{conn_pid: result.adapter_payload.conn_pid}} == result end @@ -35,12 +35,12 @@ defmodule GRPC.Client.Adapters.GunTest do channel = build(:channel, port: port, host: "localhost", cred: credential) # Ensure that it works - assert {:ok, result} = Gun.connect(channel, %{transport_opts: [ip: :loopback]}) + assert {:ok, result} = Gun.connect(channel, transport_opts: [ip: :loopback]) assert %{channel | adapter_payload: %{conn_pid: result.adapter_payload.conn_pid}} == result # Ensure that changing one of the options breaks things assert {:error, {:down, :badarg}} == - Gun.connect(channel, %{transport_opts: [ip: "256.0.0.0"]}) + Gun.connect(channel, transport_opts: [ip: "256.0.0.0"]) end test "connects securely (default options)", %{port: port, credential: credential} do @@ -52,7 +52,7 @@ defmodule GRPC.Client.Adapters.GunTest do cred: credential ) - assert {:ok, result} = Gun.connect(channel, %{tls_opts: channel.cred.ssl}) + assert {:ok, result} = Gun.connect(channel, tls_opts: channel.cred.ssl) assert %{channel | adapter_payload: %{conn_pid: result.adapter_payload.conn_pid}} == result end @@ -68,20 +68,20 @@ defmodule GRPC.Client.Adapters.GunTest do # Ensure that it works assert {:ok, result} = - Gun.connect(channel, %{ + Gun.connect(channel, transport_opts: [certfile: credential.ssl[:certfile], ip: :loopback] - }) + ) assert %{channel | adapter_payload: %{conn_pid: result.adapter_payload.conn_pid}} == result # Ensure that changing one of the options breaks things assert {:error, :timeout} == - Gun.connect(channel, %{ + Gun.connect(channel, transport_opts: [ certfile: credential.ssl[:certfile] <> "invalidsuffix", ip: :loopback ] - }) + ) end end end diff --git a/test/grpc/integration/connection_test.exs b/test/grpc/integration/connection_test.exs index bdd25cfc..fe68d005 100644 --- a/test/grpc/integration/connection_test.exs +++ b/test/grpc/integration/connection_test.exs @@ -9,7 +9,7 @@ defmodule GRPC.Integration.ConnectionTest do server = FeatureServer {:ok, _, port} = GRPC.Server.start(server, 0) point = Routeguide.Point.new(latitude: 409_146_138, longitude: -746_188_906) - {:ok, channel} = GRPC.Stub.connect("localhost:#{port}", adapter_opts: %{retry_timeout: 10}) + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}", adapter_opts: [retry_timeout: 10]) assert {:ok, _} = channel |> Routeguide.RouteGuide.Stub.get_feature(point) :ok = GRPC.Server.stop(server) {:ok, _, _} = reconnect_server(server, port) @@ -23,7 +23,7 @@ defmodule GRPC.Integration.ConnectionTest do File.rm(socket_path) {:ok, _, _} = GRPC.Server.start(server, 0, ip: {:local, socket_path}) - {:ok, channel} = GRPC.Stub.connect(socket_path, adapter_opts: %{retry_timeout: 10}) + {:ok, channel} = GRPC.Stub.connect(socket_path, adapter_opts: [retry_timeout: 10]) point = Routeguide.Point.new(latitude: 409_146_138, longitude: -746_188_906) assert {:ok, _} = channel |> Routeguide.RouteGuide.Stub.get_feature(point)