Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/custom oauth2 header #2928

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions kong/plugins/oauth2/access.lua
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ local function authorize(conf)
})
end

local function retrieve_client_credentials(parameters)
local function retrieve_client_credentials(parameters, conf)
local client_id, client_secret, from_authorization_header
local authorization_header = ngx.req.get_headers()["authorization"]
local authorization_header = ngx.req.get_headers()[conf.auth_header_name]
if parameters[CLIENT_ID] and parameters[CLIENT_SECRET] then
client_id = parameters[CLIENT_ID]
client_secret = parameters[CLIENT_SECRET]
Expand Down Expand Up @@ -271,7 +271,7 @@ local function issue_token(conf)
response_params = {[ERROR] = "unsupported_grant_type", error_description = "Invalid " .. GRANT_TYPE}
end

local client_id, client_secret, from_authorization_header = retrieve_client_credentials(parameters)
local client_id, client_secret, from_authorization_header = retrieve_client_credentials(parameters, conf)

-- Check client_id and redirect_uri
local allowed_redirect_uris, client = get_redirect_uri(client_id)
Expand Down Expand Up @@ -405,7 +405,7 @@ local function parse_access_token(conf)
local found_in = {}
local result = retrieve_parameters()["access_token"]
if not result then
local authorization = ngx.req.get_headers()["authorization"]
local authorization = ngx.req.get_headers()[conf.auth_header_name]
if authorization then
local parts = {}
for v in authorization:gmatch("%S+") do -- Split by space
Expand All @@ -420,7 +420,7 @@ local function parse_access_token(conf)

if conf.hide_credentials then
if found_in.authorization_header then
ngx.req.clear_header("authorization")
ngx.req.clear_header(conf.auth_header_name)
else
-- Remove from querystring
local parameters = ngx.req.get_uri_args()
Expand Down
22 changes: 21 additions & 1 deletion kong/plugins/oauth2/migrations/cassandra.lua
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
local plugin_config_iterator = require("kong.dao.migrations.helpers").plugin_config_iterator

return {
{
name = "2015-08-03-132400_init_oauth2",
Expand Down Expand Up @@ -151,5 +153,23 @@ return {
end
end
end
}
},
{
name = "2017-10-19-set_auth_header_name_default",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to lexically sort, so please add the time in hhmmss format to the key (see examples above)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm, just saw the older ones did this neither.

up = function(_, _, dao)
for ok, config, update in plugin_config_iterator(dao, "oauth2") do
if not ok then
return config
end
if config.auth_header_name == nil then
config.auth_header_name = "authorization"
local _, err = update(config)
if err then
return err
end
end
end
end,
down = function(_, _, dao) end -- not implemented
},
}
22 changes: 21 additions & 1 deletion kong/plugins/oauth2/migrations/postgres.lua
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
local plugin_config_iterator = require("kong.dao.migrations.helpers").plugin_config_iterator

return {
{
name = "2015-08-03-132400_init_oauth2",
Expand Down Expand Up @@ -164,5 +166,23 @@ return {
down = [[
ALTER TABLE oauth2_credentials ADD CONSTRAINT oauth2_credentials_client_secret_key UNIQUE(client_secret);
]],
}
},
{
name = "2017-10-19-set_auth_header_name_default",
up = function(_, _, dao)
for ok, config, update in plugin_config_iterator(dao, "oauth2") do
if not ok then
return config
end
if config.auth_header_name == nil then
config.auth_header_name = "authorization"
local _, err = update(config)
if err then
return err
end
end
end
end,
down = function(_, _, dao) end -- not implemented
},
}
1 change: 1 addition & 0 deletions kong/plugins/oauth2/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ return {
accept_http_if_already_terminated = { required = false, type = "boolean", default = false },
anonymous = {type = "string", default = "", func = check_user},
global_credentials = {type = "boolean", default = false},
auth_header_name = {required = false, type = "string", default = "authorization"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will require a migration. Existing records in the datastore will not have this field, nor the default. (The default is not applied when loading from the datastore, but only when adding the plugin config)

This also means that this PR cannot go into master but has to go into next (migrations go in the next major release, which is the next branch)

As an example, look at the migrations in this PR: #2883
(you need to rebase on next to get those files)

},
self_check = function(schema, plugin_t, dao, is_update)
if not plugin_t.enable_authorization_code and not plugin_t.enable_implicit_grant
Expand Down
20 changes: 19 additions & 1 deletion spec/03-plugins/26-oauth2/01-schema_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,25 @@ describe("Plugin: oauth2 (schema)", function()
assert.truthy(t.provision_key)
assert.equal("hello", t.provision_key)
end)

it("sets default `auth_header_name` when not given", function()
local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}}
local ok, errors = validate_entity(t, oauth2_schema)
assert.True(ok)
assert.is_nil(errors)
assert.truthy(t.provision_key)
assert.equal(32, t.provision_key:len())
assert.equal("authorization", t.auth_header_name)
end)
it("does not set default value for `auth_header_name` when it is given", function()
local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}, provision_key = "hello",
auth_header_name="custom_header_name"}
local ok, errors = validate_entity(t, oauth2_schema)
assert.True(ok)
assert.is_nil(errors)
assert.truthy(t.provision_key)
assert.equal("hello", t.provision_key)
assert.equal("custom_header_name", t.auth_header_name)
end)
describe("errors", function()
it("requires at least one flow", function()
local ok, _, err = validate_entity({}, oauth2_schema)
Expand Down
130 changes: 129 additions & 1 deletion spec/03-plugins/26-oauth2/03-access_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ describe("Plugin: oauth2 (access)", function()
name = "testapp3",
consumer_id = consumer.id
})
assert(helpers.dao.oauth2_credentials:insert {
client_id = "clientid1011",
client_secret = "secret1011",
redirect_uri = "http://google.com/kong",
name = "testapp31",
consumer_id = consumer.id
})

local api1 = assert(helpers.dao.apis:insert {
name = "api-1",
Expand Down Expand Up @@ -302,7 +309,45 @@ describe("Plugin: oauth2 (access)", function()
anonymous = utils.uuid(), -- a non existing consumer
},
})

local api11 = assert(helpers.dao.apis:insert {
name = "oauth2_11.com",
hosts = { "oauth2_11.com" },
upstream_url = helpers.mock_upstream_url,
})
assert(helpers.dao.plugins:insert {
name = "oauth2",
api_id = api11.id,
config = {
scopes = { "email", "profile", "user.email" },
enable_authorization_code = true,
mandatory_scope = true,
provision_key = "provision123",
token_expiration = 7,
enable_implicit_grant = true,
global_credentials = true,
auth_header_name = "custom_header_name",
},
})
local api12 = assert(helpers.dao.apis:insert {
name = "oauth2_12.com",
hosts = { "oauth2_12.com" },
upstream_url = helpers.mock_upstream_url,
})
assert(helpers.dao.plugins:insert {
name = "oauth2",
api_id = api12.id,
config = {
scopes = { "email", "profile", "user.email" },
enable_authorization_code = true,
mandatory_scope = true,
provision_key = "provision123",
token_expiration = 7,
enable_implicit_grant = true,
global_credentials = true,
auth_header_name = "custom_header_name",
hide_credentials = true,
},
})
assert(helpers.start_kong({
trusted_ips = "127.0.0.1",
nginx_conf = "spec/fixtures/custom_nginx.template",
Expand Down Expand Up @@ -805,6 +850,35 @@ describe("Plugin: oauth2 (access)", function()
assert.are.equal(5, data[1].expires_in)
assert.falsy(data[1].refresh_token)
end)
it("returns success and the token should have the right expiration when a custom header is passed", function()
local res = assert(proxy_ssl_client:send {
method = "POST",
path = "/oauth2/authorize",
body = {
provision_key = "provision123",
authenticated_userid = "id123",
client_id = "clientid1011",
scope = "email",
response_type = "token"
},
headers = {
["Host"] = "oauth2_11.com",
["Content-Type"] = "application/json"
}
})
local body = cjson.decode(assert.res_status(200, res))
assert.is_table(ngx.re.match(body.redirect_uri, "^http://google\\.com/kong\\#access_token=[\\w]{32,32}&expires_in=[\\d]+&token_type=bearer$"))

local iterator, err = ngx.re.gmatch(body.redirect_uri, "^http://google\\.com/kong\\#access_token=([\\w]{32,32})&expires_in=[\\d]+&token_type=bearer$")
assert.is_nil(err)
local m, err = iterator()
assert.is_nil(err)
local data = helpers.dao.oauth2_tokens:find_all {access_token = m[1]}
assert.are.equal(1, #data)
assert.are.equal(m[1], data[1].access_token)
assert.are.equal(7, data[1].expires_in)
assert.falsy(data[1].refresh_token)
end)
it("returns success and store authenticated user properties", function()
local res = assert(proxy_ssl_client:send {
method = "POST",
Expand Down Expand Up @@ -1695,6 +1769,32 @@ describe("Plugin: oauth2 (access)", function()
})
assert.res_status(200, res)
end)
it("work when a correct access_token is being sent in the custom header", function()
local token = provision_token("oauth2_11.com",nil,"clientid1011","secret1011")

local res = assert(proxy_ssl_client:send {
method = "GET",
path = "/request",
headers = {
["Host"] = "oauth2_11.com",
["custom_header_name"] = "bearer " .. token.access_token,
}
})
assert.res_status(200, res)
end)
it("fail when a correct access_token is being sent in the wrong header", function()
local token = provision_token("oauth2_11.com",nil,"clientid1011","secret1011")

local res = assert(proxy_ssl_client:send {
method = "GET",
path = "/request",
headers = {
["Host"] = "oauth2_11.com",
["authorization"] = "bearer " .. token.access_token,
}
})
assert.res_status(401, res)
end)
it("does not work when requesting a different API", function()
local token = provision_token()

Expand Down Expand Up @@ -2124,6 +2224,19 @@ describe("Plugin: oauth2 (access)", function()
local body = cjson.decode(assert.res_status(200, res))
assert.is_nil(body.uri_args.access_token)
end)
it("hides credentials in the querystring for api with custom header", function()
local token = provision_token("oauth2_12.com",nil,"clientid1011","secret1011")

local res = assert(proxy_client:send {
method = "GET",
path = "/request?access_token=" .. token.access_token,
headers = {
["Host"] = "oauth2_12.com"
}
})
local body = cjson.decode(assert.res_status(200, res))
assert.is_nil(body.uri_args.access_token)
end)
it("does not hide credentials in the header", function()
local token = provision_token()

Expand Down Expand Up @@ -2152,6 +2265,21 @@ describe("Plugin: oauth2 (access)", function()
local body = cjson.decode(assert.res_status(200, res))
assert.is_nil(body.headers.authorization)
end)
it("hides credentials in the custom header", function()
local token = provision_token("oauth2_12.com",nil,"clientid1011","secret1011")

local res = assert(proxy_client:send {
method = "GET",
path = "/request",
headers = {
["Host"] = "oauth2_12.com",
["custom_header_name"] = "bearer " .. token.access_token
}
})
local body = cjson.decode(assert.res_status(200, res))
assert.is_nil(body.headers.authorization)
assert.is_nil(body.headers.custom_header_name)
end)
it("does not abort when the request body is a multipart form upload", function()
local token = provision_token("oauth2_3.com")

Expand Down