diff --git a/kong/dao/cassandra/apis.lua b/kong/dao/cassandra/apis.lua index 03a3465e12c7..36ef4d5ef757 100644 --- a/kong/dao/cassandra/apis.lua +++ b/kong/dao/cassandra/apis.lua @@ -1,5 +1,6 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local constants = require "kong.constants" +local PluginsConfigurations = require "kong.dao.cassandra.plugins_configurations" local SCHEMA = { id = { type = constants.DATABASE_TYPES.ID }, @@ -50,4 +51,36 @@ function Apis:new(properties) Apis.super.new(self, properties) end +-- @override +function Apis:delete(api_id) + local ok, err = Apis.super.delete(self, api_id) + if not ok then + return err + end + + -- delete all related plugins configurations + local plugins_dao = PluginsConfigurations(self._properties) + local query, args_keys, errors = plugins_dao:_build_where_query(plugins_dao._queries.select.query, { + api_id = api_id + }) + if errors then + return nil, errors + end + + for _, rows, page, err in plugins_dao:_execute_kong_query({query=query, args_keys=args_keys}, {api_id=api_id}, {auto_paging=true}) do + if err then + return nil, err + end + + for _, row in ipairs(rows) do + local ok_del_plugin, err = plugins_dao:delete(row.id) + if not ok_del_plugin then + return nil, err + end + end + end + + return ok +end + return Apis diff --git a/kong/dao/cassandra/base_dao.lua b/kong/dao/cassandra/base_dao.lua index 5b5f0e43e8f4..65b2612bc535 100644 --- a/kong/dao/cassandra/base_dao.lua +++ b/kong/dao/cassandra/base_dao.lua @@ -223,6 +223,33 @@ local function encode_cassandra_args(schema, t, args_keys) return args_to_bind, errors end +function BaseDao:_build_where_query(query, t) + local args_keys = {} + local where_str = "" + local errors + + -- if t is an args_keys, compute a WHERE statement + if t and utils.table_size(t) > 0 then + local where = {} + for k, v in pairs(t) do + if self._schema[k] and self._schema[k].queryable or k == "id" then + table.insert(where, string.format("%s = ?", k)) + table.insert(args_keys, k) + else + errors = utils.add_error(errors, k, k.." is not queryable.") + end + end + + if errors then + return nil, nil, DaoError(errors, error_types.SCHEMA) + end + + where_str = "WHERE "..table.concat(where, " AND ").." ALLOW FILTERING" + end + + return string.format(query, where_str), args_keys +end + -- Get a statement from the cache or prepare it (and thus insert it in the cache). -- The cache key will be the plain string query representation. -- @param `kong_query` A kong query from the _queries property. @@ -269,6 +296,14 @@ function BaseDao:_execute(statement, args, options, keyspace) return nil, err end + if options and options.auto_paging then + local _, rows, page, err = session:execute(statement, args, options) + for i, row in ipairs(rows) do + rows[i] = self:_unmarshall(row) + end + return _, rows, page, err + end + local results, err = session:execute(statement, args, options) if err then err = DaoError(err, error_types.DATABASE) @@ -523,31 +558,12 @@ end -- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. -- @return _execute_kong_query() function BaseDao:find_by_keys(t, page_size, paging_state) - local where, keys = {}, {} - local where_str = "" - local errors - - -- if keys are passed, compute a WHERE statement - if t and utils.table_size(t) > 0 then - for k,v in pairs(t) do - if self._schema[k] and self._schema[k].queryable or k == "id" then - table.insert(where, string.format("%s = ?", k)) - table.insert(keys, k) - else - errors = utils.add_error(errors, k, k.." is not queryable.") - end - end - - if errors then - return nil, DaoError(errors, error_types.SCHEMA) - end - - where_str = "WHERE "..table.concat(where, " AND ").." ALLOW FILTERING" + local select_where_query, args_keys, errors = self:_build_where_query(self._queries.select.query, t) + if errors then + return nil, errors end - local select_query = string.format(self._queries.select.query, where_str) - - return self:_execute_kong_query({ query = select_query, args_keys = keys }, t, { + return self:_execute_kong_query({ query = select_where_query, args_keys = args_keys }, t, { page_size = page_size, paging_state = paging_state }) diff --git a/kong/dao/cassandra/consumers.lua b/kong/dao/cassandra/consumers.lua index 89e615ca5e6a..89e37508c214 100644 --- a/kong/dao/cassandra/consumers.lua +++ b/kong/dao/cassandra/consumers.lua @@ -1,6 +1,7 @@ local BaseDao = require "kong.dao.cassandra.base_dao" -local constants = require "kong.constants" local stringy = require "stringy" +local constants = require "kong.constants" +local PluginsConfigurations = require "kong.dao.cassandra.plugins_configurations" local function check_custom_id_and_username(value, consumer_t) if (consumer_t.custom_id == nil or stringy.strip(consumer_t.custom_id) == "") @@ -56,4 +57,36 @@ function Consumers:new(properties) Consumers.super.new(self, properties) end +-- @override +function Consumers:delete(consumer_id) + local ok, err = Consumers.super.delete(self, consumer_id) + if not ok then + return err + end + + -- delete all related plugins configurations + local plugins_dao = PluginsConfigurations(self._properties) + local query, args_keys, errors = plugins_dao:_build_where_query(plugins_dao._queries.select.query, { + consumer_id = consumer_id + }) + if errors then + return nil, errors + end + + for _, rows, page, err in plugins_dao:_execute_kong_query({query=query, args_keys=args_keys}, {consumer_id=consumer_id}, {auto_paging=true}) do + if err then + return nil, err + end + + for _, row in ipairs(rows) do + local ok_del_plugin, err = plugins_dao:delete(row.id) + if not ok_del_plugin then + return nil, err + end + end + end + + return ok +end + return Consumers diff --git a/kong/dao/cassandra/plugins_configurations.lua b/kong/dao/cassandra/plugins_configurations.lua index 4b2995040aef..ebefef2ef69d 100644 --- a/kong/dao/cassandra/plugins_configurations.lua +++ b/kong/dao/cassandra/plugins_configurations.lua @@ -121,7 +121,7 @@ function PluginsConfigurations:find_distinct() end local result = {} - for k,_ in pairs(distinct_names) do + for k, _ in pairs(distinct_names) do table.insert(result, k) end diff --git a/kong/tools/faker.lua b/kong/tools/faker.lua index fbf42b99127f..f6af6229497e 100644 --- a/kong/tools/faker.lua +++ b/kong/tools/faker.lua @@ -53,7 +53,7 @@ Faker.FIXTURES = { { name = "keyauth", value = { key_names = { "apikey" }}, __api = 1 }, { name = "tcplog", value = { host = "127.0.0.1", port = 7777 }, __api = 1 }, { name = "udplog", value = { host = "127.0.0.1", port = 8888 }, __api = 1 }, - { name = "filelog", value = { }, __api = 1 }, + { name = "filelog", value = {}, __api = 1 }, -- API 2 { name = "basicauth", value = {}, __api = 2 }, -- API 3 @@ -73,9 +73,9 @@ Faker.FIXTURES = { -- API 6 { name = "cors", value = {}, __api = 6 }, -- API 7 - { name = "cors", value = { origin = "example.com", - methods = "GET", - headers = "origin, type, accepts", + { name = "cors", value = { origin = "example.com", + methods = "GET", + headers = "origin, type, accepts", exposed_headers = "x-auth-token", max_age = 23, credentials = true }, __api = 7 } diff --git a/spec/integration/admin_api/admin_api_spec.lua b/spec/integration/admin_api/admin_api_spec.lua index 7e026a6a32e3..0a98d97058ee 100644 --- a/spec/integration/admin_api/admin_api_spec.lua +++ b/spec/integration/admin_api/admin_api_spec.lua @@ -106,106 +106,145 @@ describe("Admin API", function() end) - for i, v in ipairs(ENDPOINTS) do - describe("#"..v.collection.." entity", function() - - it("should not create on POST with invalid parameters", function() - if v.collection ~= "consumers" then - local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", {}) - assert.are.equal(400, status) - assert.are.equal(v.error_message, response) - end - end) - - it("should create an entity from valid paremeters", function() - -- Replace the IDs - for k, p in pairs(v.entity) do - if type(p) == "function" then - v.entity[k] = p() + describe("POST", function() + for i, v in ipairs(ENDPOINTS) do + describe(v.collection.." entity", function() + + it("should not create with invalid parameters", function() + if v.collection ~= "consumers" then + local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", {}) + assert.are.equal(400, status) + assert.are.equal(v.error_message, response) + end + end) + + it("should create an entity from valid paremeters", function() + -- Replace the IDs + for k, p in pairs(v.entity) do + if type(p) == "function" then + v.entity[k] = p() + end end - end - local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", v.entity) - local body = cjson.decode(response) - assert.are.equal(201, status) - assert.truthy(body) + local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", v.entity) + local body = cjson.decode(response) + assert.are.equal(201, status) + assert.truthy(body) - -- Save the ID for later use - created_ids[v.collection] = body.id - end) + -- Save the ID for later use + created_ids[v.collection] = body.id + end) - it("should GET all entities", function() - local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/") - local body = cjson.decode(response) - assert.are.equal(200, status) - assert.truthy(body.data) - --assert.truthy(body.total) - --assert.are.equal(v.total, body.total) - assert.are.equal(v.total, table.getn(body.data)) - end) + it("should not create when the content-type is wrong", function() + local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", v.entity, { ["content-type"] = "application/json"}) + assert.are.equal(415, status) + assert.are.equal("{\"message\":\"Unsupported Content-Type. Use \\\"application\\/x-www-form-urlencoded\\\"\"}\n", response) + end) - it("should GET one entity", function() - local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection]) - local body = cjson.decode(response) - assert.are.equal(200, status) - assert.truthy(body) - assert.are.equal(created_ids[v.collection], body.id) end) + end + end) + + describe("GET", function() + for i, v in ipairs(ENDPOINTS) do + describe(v.collection.." entity", function() + + it("should return not retrieve any entity with an invalid parameter", function() + local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection].."blah") + local body = cjson.decode(response) + assert.are.equal(404, status) + assert.truthy(body) + assert.are.equal('{"id":"'..created_ids[v.collection]..'blah is an invalid uuid"}\n', response) + end) + + it("should retrieve all entities", function() + local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/") + local body = cjson.decode(response) + assert.are.equal(200, status) + assert.truthy(body.data) + --assert.truthy(body.total) + --assert.are.equal(v.total, body.total) + assert.are.equal(v.total, table.getn(body.data)) + end) + + it("should retrieve one entity", function() + local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection]) + local body = cjson.decode(response) + assert.are.equal(200, status) + assert.truthy(body) + assert.are.equal(created_ids[v.collection], body.id) + end) - it("should return not found on GET", function() - local response, status, headers = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection].."blah") - local body = cjson.decode(response) - assert.are.equal(404, status) - assert.truthy(body) - assert.are.equal('{"id":"'..created_ids[v.collection]..'blah is an invalid uuid"}\n', response) end) + end + end) - it("should update a created entity on PUT", function() - local data = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection]) - local body = cjson.decode(data) + describe("PUT", function() + for i, v in ipairs(ENDPOINTS) do + describe(v.collection.." entity", function() - -- Create new body - for k,v in pairs(v.update_fields) do - body[k] = v - end + it("should not update when the content-type is wrong", function() + local response, status, headers = http_client.put(kWebURL.."/"..v.collection.."/"..created_ids[v.collection], body, { ["content-type"] = "application/x-www-form-urlencoded"}) + assert.are.equal(415, status) + assert.are.equal("{\"message\":\"Unsupported Content-Type. Use \\\"application\\/json\\\"\"}\n", response) + end) - local response, status, headers = http_client.put(kWebURL.."/"..v.collection.."/"..created_ids[v.collection], body) - local new_body = cjson.decode(response) - assert.are.equal(200, status) - assert.truthy(new_body) - assert.are.equal(created_ids[v.collection], new_body.id) + it("should update an entity if valid parameters", function() + local data = http_client.get(kWebURL.."/"..v.collection.."/"..created_ids[v.collection]) + local body = cjson.decode(data) - for k,v in pairs(v.update_fields) do - assert.are.equal(v, new_body[k]) - end + -- Create new body + for k,v in pairs(v.update_fields) do + body[k] = v + end - assert.are.same(body, new_body) - end) + local response, status, headers = http_client.put(kWebURL.."/"..v.collection.."/"..created_ids[v.collection], body) + local new_body = cjson.decode(response) + assert.are.equal(200, status) + assert.truthy(new_body) + assert.are.equal(created_ids[v.collection], new_body.id) + + for k,v in pairs(v.update_fields) do + assert.are.equal(v, new_body[k]) + end + + assert.are.same(body, new_body) + end) - it("should not update when the content-type is wrong", function() - local response, status, headers = http_client.put(kWebURL.."/"..v.collection.."/"..created_ids[v.collection], body, { ["content-type"] = "application/x-www-form-urlencoded"}) - assert.are.equal(415, status) - assert.are.equal("{\"message\":\"Unsupported Content-Type. Use \\\"application\\/json\\\"\"}\n", response) end) + end + end) - it("should not save when the content-type is wrong", function() - local response, status, headers = http_client.post(kWebURL.."/"..v.collection.."/", v.entity, { ["content-type"] = "application/json"}) - assert.are.equal(415, status) - assert.are.equal("{\"message\":\"Unsupported Content-Type. Use \\\"application\\/x-www-form-urlencoded\\\"\"}\n", response) + -- Tests on DELETE must run in that order: + -- 1. plugins_configurations + -- 2. APIs/Consumers + -- Since deleting APIs and Consumers delete related plugins_configurations. + describe("DELETE", function() + describe("plugins_configurations", function() + + it("should delete a plugin_configuration", function() + local response, status, headers = http_client.delete(kWebURL.."/plugins_configurations/"..created_ids.plugins_configurations) + assert.are.equal(204, status) end) end) - end - for i,v in ipairs(ENDPOINTS) do - describe("#"..v.collection, function() + describe("APIs", function() - it("should delete an entity on DELETE", function() - local response, status, headers = http_client.delete(kWebURL.."/"..v.collection.."/"..created_ids[v.collection]) + it("should delete an API", function() + local response, status, headers = http_client.delete(kWebURL.."/apis/"..created_ids.apis) assert.are.equal(204, status) end) end) - end + describe("Consumers", function() + + it("should delete a Consumer", function() + local response, status, headers = http_client.delete(kWebURL.."/consumers/"..created_ids.consumers) + assert.are.equal(204, status) + end) + + end) + end) end) diff --git a/spec/unit/dao/cassandra_spec.lua b/spec/unit/dao/cassandra_spec.lua index e9af3ed02d70..100de63d709d 100644 --- a/spec/unit/dao/cassandra_spec.lua +++ b/spec/unit/dao/cassandra_spec.lua @@ -481,9 +481,9 @@ describe("Cassandra DAO", function() assert.truthy(entities) assert.True(#entities > 0) - local success, err = dao_factory[collection]:delete(entities[1].id) + local ok, err = dao_factory[collection]:delete(entities[1].id) assert.falsy(err) - assert.True(success) + assert.True(ok) local entities, err = session:execute("SELECT * FROM "..collection.." WHERE id = "..entities[1].id ) assert.falsy(err) @@ -492,6 +492,162 @@ describe("Cassandra DAO", function() end) end) + + describe("APIs", function() + local api, untouched_api + + setup(function() + spec_helper.drop_db() + + -- Insert an API + local _, err + api, err = dao_factory.apis:insert { + name = "cascade delete test", + public_dns = "cascade.com", + target_url = "http://mockbin.com" + } + assert.falsy(err) + + -- Insert some plugins_configurations + _, err = dao_factory.plugins_configurations:insert { + name = "keyauth", value = { key_names = {"apikey"} }, api_id = api.id + } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "ratelimiting", value = { period = "minute", limit = 6 }, api_id = api.id + } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "filelog", value = {}, api_id = api.id + } + assert.falsy(err) + + -- Insert an unrelated API + plugin + untouched_api, err = dao_factory.apis:insert { + name = "untouched cascade test api", + public_dns = "untouched.com", + target_url = "http://mockbin.com" + } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "filelog", value = {}, api_id = untouched_api.id + } + + -- Make sure we have 3 matches + local results, err = dao_factory.plugins_configurations:find_by_keys { + api_id = api.id + } + assert.falsy(err) + assert.are.same(3, #results) + end) + + teardown(function() + spec_helper.drop_db() + end) + + it("should delete all related plugins_configurations when deleting an API", function() + local ok, err = dao_factory.apis:delete(api.id) + assert.falsy(err) + assert.True(ok) + + -- Make sure we have 0 matches + local results, err = dao_factory.plugins_configurations:find_by_keys { + api_id = api.id + } + assert.falsy(err) + assert.are.same(0, #results) + + -- Make sure the untouched API still has its plugin + local results, err = dao_factory.plugins_configurations:find_by_keys { + api_id = untouched_api.id + } + assert.falsy(err) + assert.are.same(1, #results) + end) + + end) + + describe("Consumers", function() + local consumer, untouched_consumer + + setup(function() + spec_helper.drop_db() + + local _, err + + -- Insert a Consumer + consumer, err = dao_factory.consumers:insert { username = "king kong" } + assert.falsy(err) + + -- Insert an API + api, err = dao_factory.apis:insert { + name = "cascade delete test", + public_dns = "cascade.com", + target_url = "http://mockbin.com" + } + + -- Insert some plugins_configurations + _, err = dao_factory.plugins_configurations:insert { + name="keyauth", value = { key_names = {"apikey"} }, api_id = api.id, + consumer_id = consumer.id + } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "ratelimiting", value = { period = "minute", limit = 6 }, api_id = api.id, + consumer_id = consumer.id + } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "filelog", value = {}, api_id = api.id, + consumer_id = consumer.id + } + assert.falsy(err) + + -- Inser an untouched consumer + plugin + untouched_consumer, err = dao_factory.consumers:insert { username = "untouched consumer" } + assert.falsy(err) + + _, err = dao_factory.plugins_configurations:insert { + name = "filelog", value = {}, api_id = api.id, + consumer_id = untouched_consumer.id + } + + local results, err = dao_factory.plugins_configurations:find_by_keys { + consumer_id = consumer.id + } + assert.falsy(err) + assert.are.same(3, #results) + end) + + teardown(function() + spec_helper.drop_db() + end) + + it("should delete all related plugins_configurations when deleting an API", function() + local ok, err = dao_factory.consumers:delete(consumer.id) + assert.True(ok) + assert.falsy(err) + + local results, err = dao_factory.plugins_configurations:find_by_keys { + consumer_id = consumer.id + } + assert.falsy(err) + assert.are.same(0, #results) + + -- Make sure the untouched Consumer still has its plugin + local results, err = dao_factory.plugins_configurations:find_by_keys { + consumer_id = untouched_consumer.id + } + assert.falsy(err) + assert.are.same(1, #results) + end) + + end) end) -- describe :delete() describe(":find()", function()