Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for UUID keys comparison #86

Merged
merged 14 commits into from
Dec 7, 2020
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added

* Support for UUID field types and UUID values

## [0.4.0] - 2020-12-02

### Fixed
Expand Down
34 changes: 27 additions & 7 deletions crud/common/utils.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local errors = require('errors')
local ffi = require('ffi')

local dev_checks = require('crud.common.dev_checks')

Expand Down Expand Up @@ -106,12 +107,9 @@ function utils.merge_primary_key_parts(key_parts, pk_parts)
return merged_parts
end

local __tarantool_supports_fieldpaths
local function tarantool_supports_fieldpaths()
if __tarantool_supports_fieldpaths ~= nil then
return __tarantool_supports_fieldpaths
end
local enabled_tarantool_features = {}

local function determine_enabled_features()
local major_minor_patch = _G._TARANTOOL:split('-', 1)[1]
local major_minor_patch_parts = major_minor_patch:split('.', 2)

Expand All @@ -120,9 +118,26 @@ local function tarantool_supports_fieldpaths()
local patch = tonumber(major_minor_patch_parts[3])

-- since Tarantool 2.3
__tarantool_supports_fieldpaths = major >= 2 and (minor > 3 or minor == 3 and patch >= 1)
enabled_tarantool_features.fieldpaths = major >= 2 and (minor > 3 or minor == 3 and patch >= 1)

-- since Tarantool 2.4
enabled_tarantool_features.uuids = major >= 2 and (minor > 4 or minor == 4 and patch >= 1)
end

local function tarantool_supports_fieldpaths()
if enabled_tarantool_features.fieldpaths == nil then
determine_enabled_features()
end

return __tarantool_supports_fieldpaths
return enabled_tarantool_features.fieldpaths
end

function utils.tarantool_supports_uuids()
if enabled_tarantool_features.uuids == nil then
determine_enabled_features()
end

return enabled_tarantool_features.uuids
end

function utils.convert_operations(user_operations, space_format)
Expand Down Expand Up @@ -206,4 +221,9 @@ function utils.get_bucket_id_fieldno(space, shard_index_name)
return bucket_id_index.parts[1].fieldno
end

local uuid_t = ffi.typeof('struct tt_uuid')
function utils.is_uuid(value)
return ffi.istype(uuid_t, value)
end

return utils
77 changes: 4 additions & 73 deletions crud/select/comparators.lua
Original file line number Diff line number Diff line change
@@ -1,74 +1,15 @@
local errors = require('errors')

local collations = require('crud.common.collations')
local select_conditions = require('crud.select.conditions')
local type_comparators = require('crud.select.type_comparators')
local operators = select_conditions.operators

local utils = require('crud.common.utils')

local LessThenError = errors.new_class('LessThenError')
local GenFuncError = errors.new_class('GenFuncError')
local ComparatorsError = errors.new_class('ComparatorsError')

local comparators = {}

local function eq(lhs, rhs)
return lhs == rhs
end

local function eq_unicode(lhs, rhs)
if type(lhs) == 'string' and type(rhs) == 'string' then
return utf8.cmp(lhs, rhs) == 0
end

return eq(lhs)
end

local function eq_unicode_ci(lhs, rhs)
if type(lhs) == 'string' and type(rhs) == 'string' then
return utf8.casecmp(lhs, rhs) == 0
end

return lhs == rhs
end

local function lt(lhs, rhs)
if lhs == nil and rhs ~= nil then
return true
elseif rhs == nil then
return false
end

-- boolean compare
local lhs_is_boolean = type(lhs) == 'boolean'
local rhs_is_boolean = type(rhs) == 'boolean'

if lhs_is_boolean and rhs_is_boolean then
return (not lhs) and rhs
elseif lhs_is_boolean or rhs_is_boolean then
LessThenError:assert(false, 'Could not compare boolean and not boolean')
end

-- general compare
return lhs < rhs
end

local function lt_unicode(lhs, rhs)
if type(lhs) == 'string' and type(rhs) == 'string' then
return utf8.cmp(lhs, rhs) == -1
end

return lt(lhs, rhs)
end

local function lt_unicode_ci(lhs, rhs)
if type(lhs) == 'string' and type(rhs) == 'string' then
return utf8.casecmp(lhs, rhs) == -1
end

return lt(lhs, rhs)
end

local function array_eq(lhs, rhs, len, _, eq_funcs)
for i = 1, len do
if not eq_funcs[i](lhs[i], rhs[i]) then
Expand Down Expand Up @@ -132,19 +73,9 @@ local function gen_array_cmp_func(target, key_parts)
local eq_funcs = {}

for _, part in ipairs(key_parts) do
local collation = collations.get(part)
if collations.is_default(collation) then
table.insert(lt_funcs, lt)
table.insert(eq_funcs, eq)
elseif collation == collations.UNICODE then
table.insert(lt_funcs, lt_unicode)
table.insert(eq_funcs, eq_unicode)
elseif collation == collations.UNICODE_CI then
table.insert(lt_funcs, lt_unicode_ci)
table.insert(eq_funcs, eq_unicode_ci)
else
return nil, GenFuncError:new('Unsupported Tarantool collation %q', collation)
end
local lt_func, eq_func = type_comparators.get_comparators_by_type(part)
table.insert(lt_funcs, lt_func)
table.insert(eq_funcs, eq_func)
end

return function(lhs, rhs)
Expand Down
63 changes: 56 additions & 7 deletions crud/select/filters.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local json = require('json')
local errors = require('errors')

local utils = require('crud.common.utils')
local dev_checks = require('crud.common.dev_checks')
local collations = require('crud.common.collations')
local select_conditions = require('crud.select.conditions')
Expand Down Expand Up @@ -163,10 +164,12 @@ local function format_value(value)
return ("%q"):format(value)
elseif type(value) == 'number' then
return tostring(value)
elseif type(value) == 'cdata' then
return tostring(value)
elseif type(value) == 'boolean' then
return tostring(value)
elseif utils.is_uuid(value) then
return ("%q"):format(value)
elseif type(value) == 'cdata' then
akudiyar marked this conversation as resolved.
Show resolved Hide resolved
return tostring(value)
end
assert(false, ('Unexpected value %s (type %s)'):format(value, type(value)))
end
Expand Down Expand Up @@ -258,13 +261,18 @@ local function format_eq(cond)
for j = 1, #cond.values do
local fieldno = cond.fieldnos[j]
local value = cond.values[j]
local value_type = cond.types[j]
local value_opts = values_opts[j] or {}

local func_name = 'eq'
func_name = add_collation_postfix(func_name, value_opts)

if collations.is_unicode(value_opts.collation) then
func_name = add_strict_postfix(func_name, value_opts)
if value_type == 'string' then
func_name = add_collation_postfix('eq', value_opts)
if collations.is_unicode(value_opts.collation) then
func_name = add_strict_postfix(func_name, value_opts)
end
elseif value_type == 'uuid' then
func_name = 'eq_uuid'
end

table.insert(cond_strings, format_comp_with_value(fieldno, func_name, value))
Expand All @@ -283,8 +291,15 @@ local function format_lt(cond)
local value_type = cond.types[j]
local value_opts = values_opts[j] or {}

local func_name = value_type ~= 'boolean' and 'lt' or 'lt_boolean'
func_name = add_collation_postfix(func_name, value_opts)
local func_name = 'lt'

if value_type == 'boolean' then
func_name = 'lt_boolean'
elseif value_type == 'string' then
func_name = add_collation_postfix('lt', value_opts)
elseif value_type == 'uuid' then
func_name = 'lt_uuid'
end
func_name = add_strict_postfix(func_name, value_opts)

table.insert(cond_strings, format_comp_with_value(fieldno, func_name, value))
Expand Down Expand Up @@ -491,6 +506,22 @@ local function lt_boolean_strict(lhs, rhs)
return (not lhs) and rhs
end

local function lt_uuid_nullable(lhs, rhs)
if lhs == nil and rhs ~= nil then
return true
elseif rhs == nil then
return false
end
return tostring(lhs) < tostring(rhs)
end

local function lt_uuid_strict(lhs, rhs)
if rhs == nil then
return false
end
return tostring(lhs) < tostring(rhs)
end

local function lt_unicode_ci_nullable(lhs, rhs)
if lhs == nil and rhs ~= nil then
return true
Expand All @@ -511,6 +542,20 @@ local function eq(lhs, rhs)
return lhs == rhs
end

local function eq_uuid(lhs, rhs)
if lhs == nil then
return rhs == nil
end
return tostring(lhs) == tostring(rhs)
end

local function eq_uuid_strict(lhs, rhs)
if rhs == nil then
return false
end
return tostring(lhs) == tostring(rhs)
end

local function eq_unicode_nullable(lhs, rhs)
if lhs == nil and rhs == nil then
return true
Expand Down Expand Up @@ -546,6 +591,8 @@ end
local library = {
-- EQ
eq = eq,
eq_uuid = eq_uuid,
eq_uuid_strict = eq_uuid_strict,
-- nullable
eq_unicode = eq_unicode_nullable,
eq_unicode_ci = eq_unicode_ci_nullable,
Expand All @@ -559,11 +606,13 @@ local library = {
lt_unicode = lt_unicode_nullable,
lt_unicode_ci = lt_unicode_ci_nullable,
lt_boolean = lt_boolean_nullable,
lt_uuid = lt_uuid_nullable,
-- strict
lt_strict = lt_strict,
lt_unicode_strict = lt_unicode_strict,
lt_unicode_ci_strict = lt_unicode_ci_strict,
lt_boolean_strict = lt_boolean_strict,
lt_uuid_strict = lt_uuid_strict,

utf8 = utf8,

Expand Down
Loading