diff --git a/gateway/src/apicast/policy/request_unbuffered/apicast-policy.json b/gateway/src/apicast/policy/request_unbuffered/apicast-policy.json new file mode 100644 index 000000000..9b26e5ea7 --- /dev/null +++ b/gateway/src/apicast/policy/request_unbuffered/apicast-policy.json @@ -0,0 +1,13 @@ +{ + "$schema": "http://apicast.io/policy-v1/schema#manifest#", + "name": "Request Unbuffered", + "summary": "Disable request buffering", + "description": [ + "Disable request buffering. This is useful when proxying big payloads with HTTP/1.1 chunked encoding" + ], + "version": "builtin", + "configuration": { + "type": "object", + "properties": {} + } +} diff --git a/gateway/src/apicast/policy/request_unbuffered/init.lua b/gateway/src/apicast/policy/request_unbuffered/init.lua new file mode 100644 index 000000000..b5a678161 --- /dev/null +++ b/gateway/src/apicast/policy/request_unbuffered/init.lua @@ -0,0 +1 @@ +return require('request_unbuffered') diff --git a/gateway/src/apicast/policy/request_unbuffered/request_unbuffered.lua b/gateway/src/apicast/policy/request_unbuffered/request_unbuffered.lua new file mode 100644 index 000000000..a3113dcce --- /dev/null +++ b/gateway/src/apicast/policy/request_unbuffered/request_unbuffered.lua @@ -0,0 +1,22 @@ +-- Request Unbuffered policy +-- This policy will disable request buffering + +local policy = require('apicast.policy') +local _M = policy.new('request_unbuffered') + +local new = _M.new + +--- Initialize a buffering +-- @tparam[opt] table config Policy configuration. +function _M.new(config) + local self = new(config) + return self +end + +function _M:export() + return { + request_unbuffered = true, + } +end + +return _M diff --git a/gateway/src/apicast/upstream.lua b/gateway/src/apicast/upstream.lua index 23bbb2716..0aff47359 100644 --- a/gateway/src/apicast/upstream.lua +++ b/gateway/src/apicast/upstream.lua @@ -210,6 +210,15 @@ function _M:set_keepalive_key(context) end end +local function get_upstream_location_name(context) + if context.upstream_location_name then + return context.upstream_location_name + end + if context.request_unbuffered then + return "@upstream_request_unbuffered" + end +end + --- Execute the upstream. --- @tparam table context any table (policy context, ngx.ctx) to store the upstream for later use by balancer function _M:call(context) @@ -242,9 +251,9 @@ function _M:call(context) self:set_keepalive_key(context or {}) if not self.servers then self:resolve() end - if context.upstream_location_name then - self.location_name = context.upstream_location_name - end + + local upstream_location_name = get_upstream_location_name(context) + self:update_location(upstream_location_name) context[self.upstream_name] = self return exec(self) diff --git a/spec/upstream_spec.lua b/spec/upstream_spec.lua index 343f4bceb..6ab12e432 100644 --- a/spec/upstream_spec.lua +++ b/spec/upstream_spec.lua @@ -217,6 +217,22 @@ describe('Upstream', function() assert.spy(ngx.exec).was_called_with(upstream.location_name) end) + it('executes the upstream location when request_unbuffered provided in the context', function() + local contexts = { + ["buffered_request"] = {ctx={}, upstream_location="@upstream"}, + ["unbuffered_request"] = {ctx={request_unbuffered=true}, upstream_location="@upstream_request_unbuffered"}, + ["upstream_location and buffered_request"] = {ctx={upstream_location_name="@grpc", request_unbuffered=true}, upstream_location="@grpc"}, + ["upstream_location and unbuffered_request"] = {ctx={upstream_location_name="@grpc"}, upstream_location="@grpc"}, + } + + for _, value in pairs(contexts) do + local upstream = Upstream.new('http://localhost') + upstream:call(value.ctx) + + assert.spy(ngx.exec).was_called_with(value.upstream_location) + end + end) + it('skips executing the upstream location when missing', function() local upstream = Upstream.new('http://localhost') upstream.location_name = nil diff --git a/t/apicast-policy-request-unbuffered.t b/t/apicast-policy-request-unbuffered.t new file mode 100644 index 000000000..64d05ff9b --- /dev/null +++ b/t/apicast-policy-request-unbuffered.t @@ -0,0 +1,210 @@ +use lib 't'; +use Test::APIcast::Blackbox 'no_plan'; + +sub large_body { + my $res = ""; + for (my $i=0; $i <= 1024; $i++) { + $res = $res . "1111111 1111111 1111111 1111111\n"; + } + return $res; +} + +$ENV{'LARGE_BODY'} = large_body(); + +require("policies.pl"); + +run_tests(); + +__DATA__ + +=== TEST 1: request_unbuffered policy with big file +--- configuration +{ + "services": [ + { + "backend_version": 1, + "proxy": { + "api_backend": "http://test-upstream.lvh.me:$TEST_NGINX_SERVER_PORT/", + "proxy_rules": [ + { "pattern": "/", "http_method": "POST", "metric_system_name": "hits", "delta": 2 } + ], + "policy_chain": [ + { + "name": "request_unbuffered", + "version": "builtin", + "configuration": {} + }, + { + "name": "apicast", + "version": "builtin", + "configuration": {} + } + ] + } + } + ] +} +--- backend +location /transactions/authrep.xml { + content_by_lua_block { + ngx.exit(200) + } +} +--- upstream +server_name test-upstream.lvh.me; + location / { + echo_read_request_body; + echo_request_body; + } +--- request eval +"POST /?user_key= \n" . $ENV{LARGE_BODY} +--- response_body eval chomp +$ENV{LARGE_BODY} +--- error_code: 200 +--- grep_error_log +a client request body is buffered to a temporary file +--- grep_error_log_out +--- no_error_log +[error] + + + +=== TEST 2: with small chunked request +--- configuration +{ + "services": [ + { + "backend_version": 1, + "proxy": { + "api_backend": "http://test-upstream.lvh.me:$TEST_NGINX_SERVER_PORT/", + "proxy_rules": [ + { "pattern": "/", "http_method": "POST", "metric_system_name": "hits", "delta": 2 } + ], + "policy_chain": [ + { + "name": "request_unbuffered", + "version": "builtin", + "configuration": {} + }, + { + "name": "apicast", + "version": "builtin", + "configuration": {} + } + ] + } + } + ] +} +--- backend +location /transactions/authrep.xml { + content_by_lua_block { + ngx.exit(200) + } +} +--- upstream +server_name test-upstream.lvh.me; + location / { + access_by_lua_block { + assert = require('luassert') + ngx.say("yay, api backend") + + -- Nginx will read the entire body in one chunk, the upstream request will not be chunked + -- and Content-Length header will be added. + local content_length = ngx.req.get_headers()["Content-Length"] + local encoding = ngx.req.get_headers()["Transfer-Encoding"] + assert.equal('12', content_length) + assert.falsy(encoding, "chunked") + } + } +--- more_headers +Transfer-Encoding: chunked +--- request eval +"POST /test?user_key=value +7\r +hello, \r +5\r +world\r +0\r +\r +" +--- error_code: 200 +--- no_error_log +[error] + + + +=== TEST 3: With big chunked request +--- configuration +{ + "services": [ + { + "backend_version": 1, + "proxy": { + "api_backend": "http://test-upstream.lvh.me:$TEST_NGINX_SERVER_PORT/", + "proxy_rules": [ + { "pattern": "/", "http_method": "POST", "metric_system_name": "hits", "delta": 2 } + ], + "policy_chain": [ + { + "name": "request_unbuffered", + "version": "builtin", + "configuration": {} + }, + { + "name": "apicast", + "version": "builtin", + "configuration": {} + } + ] + } + } + ] +} +--- backend +location /transactions/authrep.xml { + content_by_lua_block { + ngx.exit(200) + } +} +--- upstream +server_name test-upstream.lvh.me; + location / { + access_by_lua_block { + assert = require('luassert') + local content_length = ngx.req.get_headers()["Content-Length"] + local encoding = ngx.req.get_headers()["Transfer-Encoding"] + assert.equal('chunked', encoding) + assert.falsy(content_length) + } + echo_read_request_body; + echo_request_body; + } +--- more_headers +Transfer-Encoding: chunked +--- request eval +$::data = ''; +for (my $i = 0; $i < 16384; $i++) { + my $c = chr int rand 128; + $::data .= $c; +} +my $s = "POST https://localhost/test?user_key=value +". +sprintf("%x\r\n", length $::data). +$::data +."\r +0\r +\r +"; +open my $out, '>/tmp/out.txt' or die $!; +print $out $s; +close $out; +$s +--- response_body eval +$::data +--- error_code: 200 +--- grep_error_log +a client request body is buffered to a temporary file +--- grep_error_log_out +--- no_error_log +[error]