Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set cors allow origins by plugin metadata #6546

76 changes: 66 additions & 10 deletions apisix/plugins/cors.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,34 @@
-- limitations under the License.
--
local core = require("apisix.core")
local plugin = require("apisix.plugin")
local ngx = ngx
local plugin_name = "cors"
local str_find = core.string.find
local re_gmatch = ngx.re.gmatch
local re_compile = require("resty.core.regex").re_match_compile
local re_find = ngx.re.find
local ipairs = ipairs
local origins_pattern = [[^(\*|\*\*|null|\w+://[^,]+(,\w+://[^,]+)*)$]]


local lrucache = core.lrucache.new({
type = "plugin",
})

local metadata_schema = {
type = "object",
properties = {
allow_origins = {
type = "object",
additionalProperties = {
type = "string",
pattern = origins_pattern
}
},
},
}

local schema = {
type = "object",
properties = {
Expand All @@ -37,7 +52,7 @@ local schema = {
"'**' to allow forcefully(it will bring some security risks, be carefully)," ..
"multiple origin use ',' to split. default: *.",
type = "string",
pattern = [[^(\*|\*\*|null|\w+://[^,]+(,\w+://[^,]+)*)$]],
pattern = origins_pattern,
default = "*"
},
allow_methods = {
Expand Down Expand Up @@ -92,6 +107,18 @@ local schema = {
minItems = 1,
uniqueItems = true,
},
allow_origins_by_metadata = {
type = "array",
description =
"set allowed origins by referencing origins in plugin metadata",
items = {
type = "string",
minLength = 1,
maxLength = 4096,
},
minItems = 1,
uniqueItems = true,
},
}
}

Expand All @@ -100,15 +127,16 @@ local _M = {
priority = 4000,
name = plugin_name,
schema = schema,
metadata_schema = metadata_schema,
}


local function create_multiple_origin_cache(conf)
if not str_find(conf.allow_origins, ",") then
local function create_multiple_origin_cache(allow_origins)
if not str_find(allow_origins, ",") then
return nil
end
local origin_cache = {}
local iterator, err = re_gmatch(conf.allow_origins, "([^,]+)", "jiox")
local iterator, err = re_gmatch(allow_origins, "([^,]+)", "jiox")
if not iterator then
core.log.error("match origins failed: ", err)
return nil
Expand All @@ -128,7 +156,10 @@ local function create_multiple_origin_cache(conf)
end


function _M.check_schema(conf)
function _M.check_schema(conf, schema_type)
if schema_type == core.schema.TYPE_METADATA then
return core.schema.check(metadata_schema, conf)
end
local ok, err = core.schema.check(schema, conf)
if not ok then
return false, err
Expand Down Expand Up @@ -177,13 +208,16 @@ local function set_cors_headers(conf, ctx)
end
end

local function process_with_allow_origins(conf, ctx, req_origin)
local allow_origins = conf.allow_origins
local function process_with_allow_origins(allow_origins_conf, ctx, req_origin, cache_key, cache_version)
local allow_origins = allow_origins_conf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to declare allow_origins_conf as local allow_origins here, just change the allow_origins_conf -> allow_origins in parameter list is ok?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, changed that

if allow_origins == "**" then
allow_origins = req_origin or '*'
end
local multiple_origin, err = core.lrucache.plugin_ctx(lrucache, ctx, nil,
create_multiple_origin_cache, conf)

if not (cache_key and cache_version) then
cache_key, cache_version = core.lrucache.plugin_ctx_id(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plugin_ctx_id only returns an ID?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😂fixed

end
local multiple_origin, err = lrucache(cache_key, cache_version, create_multiple_origin_cache, allow_origins_conf)
if err then
return 500, {message = "get multiple origin cache failed: " .. err}
end
Expand Down Expand Up @@ -225,6 +259,25 @@ local function match_origins(req_origin, allow_origins)
return req_origin == allow_origins or allow_origins == '*'
end

local function process_with_allow_origins_by_metadata(allow_origins_by_metadata, ctx, req_origin)
if allow_origins_by_metadata == nil then
return
end

local metadata = plugin.plugin_metadata(plugin_name)
if metadata and metadata.value.allow_origins then
local allow_origins_map = metadata.value.allow_origins
for _, key in ipairs(allow_origins_by_metadata) do
local allow_origins_conf = allow_origins_map[key]
local allow_origins = process_with_allow_origins(allow_origins_conf, ctx, req_origin,
plugin_name .. "#" .. key, metadata.modifiedIndex)
if match_origins(req_origin, allow_origins) then
return req_origin
end
end
end
end


function _M.rewrite(conf, ctx)
-- save the original request origin as it may be changed at other phase
Expand All @@ -239,10 +292,13 @@ function _M.header_filter(conf, ctx)
local req_origin = ctx.original_request_origin
-- Try allow_origins first, if mismatched, try allow_origins_by_regex.
local allow_origins
allow_origins = process_with_allow_origins(conf, ctx, req_origin)
allow_origins = process_with_allow_origins(conf.allow_origins, ctx, req_origin)
if not match_origins(req_origin, allow_origins) then
allow_origins = process_with_allow_origins_by_regex(conf, ctx, req_origin)
end
if not match_origins(req_origin, allow_origins) then
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not match_origins(req_origin, allow_origins) then
if not allow_origins then

allow_origins = process_with_allow_origins_by_metadata(conf.allow_origins_by_metadata, ctx, req_origin)
end
if allow_origins then
ctx.cors_allow_origins = allow_origins
set_cors_headers(conf, ctx)
Expand Down
Loading