Skip to content

Commit

Permalink
[improve](ip-function)improve is_ip_address_in_range for inverted ind…
Browse files Browse the repository at this point in the history
…ex speed (apache#41768)

speed up is_ip_address_in_range with inverted index
  • Loading branch information
amorynan committed Oct 25, 2024
1 parent b88d4db commit e545a77
Show file tree
Hide file tree
Showing 6 changed files with 397 additions and 94 deletions.
2 changes: 2 additions & 0 deletions be/src/olap/rowset/segment_v2/inverted_index_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,8 @@ class InvertedIndexQueryParamFactory {
M(PrimitiveType::TYPE_STRING)
M(PrimitiveType::TYPE_DATEV2)
M(PrimitiveType::TYPE_DATETIMEV2)
M(PrimitiveType::TYPE_IPV4)
M(PrimitiveType::TYPE_IPV6)
#undef M
default:
return Status::NotSupported("Unsupported primitive type {} for inverted index reader",
Expand Down
252 changes: 158 additions & 94 deletions be/src/vec/functions/function_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,91 +652,182 @@ class FunctionIsIPAddressInRange : public IFunction {
// Fixed arity: this function always reports exactly two arguments
// (get_return_type_impl names them the address and the CIDR).
size_t get_number_of_arguments() const override { return 2; }

/// Validates the argument types and reports the result type of the function.
///
/// The first argument (the address) may be String, IPv4 or IPv6, because execute_impl
/// dispatches on is_ipv4()/is_ipv6() of the address type and otherwise treats it as a
/// string column. The second argument (the CIDR) must be String. Nullable wrappers are
/// stripped before the check.
///
/// @param arguments the data types of the call arguments; exactly two are required.
/// @return a DataTypeUInt8 instance.
/// @throws Exception with ErrorCode::INVALID_ARGUMENT when the argument count is not 2
///         or when an argument has an unsupported type.
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
    if (arguments.size() != 2) {
        throw Exception(
                ErrorCode::INVALID_ARGUMENT,
                "Number of arguments for function {} doesn't match: passed {}, should be 2",
                get_name(), arguments.size());
    }
    const auto addr_type = remove_nullable(arguments[0]);
    const auto cidr_type = remove_nullable(arguments[1]);
    // Previously only String was accepted here, which made the native IPv4/IPv6 paths
    // of execute_impl unreachable.
    const bool addr_supported =
            is_string(addr_type) || is_ipv4(addr_type) || is_ipv6(addr_type);
    if (!addr_supported || !is_string(cidr_type)) {
        throw Exception(ErrorCode::INVALID_ARGUMENT,
                        "The address argument of function {} must be String, IPv4 or IPv6 "
                        "and the CIDR argument must be String",
                        get_name());
    }
    return std::make_shared<DataTypeUInt8>();
}

// Always returns false.
// NOTE(review): presumably this opts out of the framework's default NULL handling so that
// this function receives nullable columns as-is — confirm against the IFunction contract.
bool use_default_implementation_for_nulls() const override { return false; }

template <PrimitiveType PT, typename ColumnType>
void execute_impl_with_ip(size_t input_rows_count, bool addr_const, bool cidr_const,
const ColumnString* str_cidr_column, const ColumnPtr addr_column,
ColumnUInt8* col_res) const {
auto& col_res_data = col_res->get_data();
const auto& ip_data = assert_cast<const ColumnType*>(addr_column.get())->get_data();
for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
if constexpr (PT == PrimitiveType::TYPE_IPV4) {
if (cidr._address.as_v4()) {
col_res_data[i] = match_ipv4_subnet(ip_data[addr_idx], cidr._address.as_v4(),
cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
} else if constexpr (PT == PrimitiveType::TYPE_IPV6) {
if (cidr._address.as_v6()) {
col_res_data[i] = match_ipv6_subnet((uint8*)(&ip_data[addr_idx]),
cidr._address.as_v6(), cidr._prefix)
? 1
: 0;
} else {
col_res_data[i] = 0;
}
}
}
}

/// Answers is_ip_address_in_range(column, cidr) from the column's inverted index.
///
/// The constant CIDR string is turned into a [min_ip, max_ip] range, the index is queried
/// once for ">= min_ip" and once for "<= max_ip", and the two row bitmaps are intersected.
///
/// @param arguments            the non-column arguments; exactly one, the CIDR.
/// @param data_type_with_names name and type of the indexed column; exactly one.
/// @param iterators            inverted index iterator for that column; exactly one.
/// @param num_rows             passed through to read_from_inverted_index.
/// @param bitmap_result        receives the matching rows on success.
/// @return Status::OK() on success, or when no iterator is available (bitmap_result is then
///         left untouched); INVERTED_INDEX_EVALUATE_SKIPPED when the index cannot answer
///         the predicate; any error returned by the index reader otherwise.
Status evaluate_inverted_index(
        const ColumnsWithTypeAndName& arguments,
        const std::vector<vectorized::IndexFieldNameAndTypePair>& data_type_with_names,
        std::vector<segment_v2::InvertedIndexIterator*> iterators, uint32_t num_rows,
        segment_v2::InvertedIndexResultBitmap& bitmap_result) const override {
    DCHECK(arguments.size() == 1);
    DCHECK(data_type_with_names.size() == 1);
    DCHECK(iterators.size() == 1);
    auto* iter = iterators[0];
    // Bind by reference: the pair is only read, so there is no need to copy it.
    const auto& data_type_with_name = data_type_with_names[0];
    if (iter == nullptr) {
        return Status::OK();
    }

    if (iter->get_inverted_index_reader_type() != segment_v2::InvertedIndexReaderType::BKD) {
        // Range queries on IP values are only supported by the BKD index reader.
        return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
                "Inverted index evaluate skipped, ip range reader can only support by bkd "
                "reader");
    }
    // The only non-column argument of is_ip_address_in_range is the CIDR.
    const auto& cidr_column_with_type_and_name = arguments[0];
    ColumnPtr arg_column = cidr_column_with_type_and_name.column;
    DataTypePtr arg_type = cidr_column_with_type_and_name.type;
    // The CIDR must be a constant; otherwise fall back to expression evaluation.
    if ((is_column_nullable(*arg_column) && !is_column_const(*remove_nullable(arg_column))) ||
        (!is_column_nullable(*arg_column) && !is_column_const(*arg_column))) {
        return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
                "Inverted index evaluate skipped, is_ip_address_in_range only support const "
                "value");
    }
    // The CIDR must be a string.
    if (!WhichDataType(*arg_type).is_string()) {
        return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
                "Inverted index evaluate skipped, is_ip_address_in_range only support string "
                "type");
    }

    // Lowest and highest address covered by the CIDR, in the column's own IP type.
    Field min_ip, max_ip;
    IPAddressCIDR cidr = parse_ip_with_cidr(arg_column->get_data_at(0));
    const WhichDataType column_type(remove_nullable(data_type_with_name.second));
    if (column_type.is_ipv4() && cidr._address.as_v4()) {
        auto range = apply_cidr_mask(cidr._address.as_v4(), cidr._prefix);
        min_ip = range.first;
        max_ip = range.second;
    } else if (column_type.is_ipv6() && cidr._address.as_v6()) {
        auto cidr_range_ipv6_col = ColumnIPv6::create(2, 0);
        auto& cidr_range_ipv6_data = cidr_range_ipv6_col->get_data();
        apply_cidr_mask(reinterpret_cast<const char*>(cidr._address.as_v6()),
                        reinterpret_cast<char*>(&cidr_range_ipv6_data[0]),
                        reinterpret_cast<char*>(&cidr_range_ipv6_data[1]), cidr._prefix);
        min_ip = cidr_range_ipv6_data[0];
        max_ip = cidr_range_ipv6_data[1];
    } else {
        // The CIDR does not yield an address of the column's IP family, so no range was
        // computed. Querying the index with the default-constructed Fields would be
        // meaningless; let the regular expression evaluation decide instead.
        return Status::Error<ErrorCode::INVERTED_INDEX_EVALUATE_SKIPPED>(
                "Inverted index evaluate skipped, is_ip_address_in_range cidr does not match "
                "the ip type of the indexed column");
    }

    // Query the index for both bounds.
    auto res_roaring = std::make_shared<roaring::Roaring>();
    auto max_roaring = std::make_shared<roaring::Roaring>();
    // NOTE(review): this stays empty, so mask_out_null() below has nothing to remove —
    // confirm whether the column's null bitmap should be read from the index here.
    auto null_bitmap = std::make_shared<roaring::Roaring>();

    auto param_type = data_type_with_name.second->get_type_as_type_descriptor().type;
    std::unique_ptr<segment_v2::InvertedIndexQueryParamFactory> query_param = nullptr;
    // >= min ip
    RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
            param_type, &min_ip, query_param));
    RETURN_IF_ERROR(iter->read_from_inverted_index(
            data_type_with_name.first, query_param->get_value(),
            segment_v2::InvertedIndexQueryType::GREATER_EQUAL_QUERY, num_rows, res_roaring));
    // <= max ip
    RETURN_IF_ERROR(segment_v2::InvertedIndexQueryParamFactory::create_query_value(
            param_type, &max_ip, query_param));
    RETURN_IF_ERROR(iter->read_from_inverted_index(
            data_type_with_name.first, query_param->get_value(),
            segment_v2::InvertedIndexQueryType::LESS_EQUAL_QUERY, num_rows, max_roaring));

    DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
        auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
                "ip.inverted_index_filtered", "req_id", 0);
        LOG(INFO) << "execute inverted index req_id: " << req_id
                  << " min: " << res_roaring->cardinality();
    });
    // Rows inside the range are those that satisfy both bounds.
    *res_roaring &= *max_roaring;
    DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
        auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
                "ip.inverted_index_filtered", "req_id", 0);
        LOG(INFO) << "execute inverted index req_id: " << req_id
                  << " max: " << max_roaring->cardinality()
                  << " result: " << res_roaring->cardinality();
    });
    segment_v2::InvertedIndexResultBitmap result(res_roaring, null_bitmap);
    bitmap_result = result;
    bitmap_result.mask_out_null();
    return Status::OK();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DBUG_EXECUTE_IF("ip.inverted_index_filtered", {
auto req_id = DebugPoints::instance()->get_debug_param_or_default<int32_t>(
"ip.inverted_index_filtered", "req_id", 0);
return Status::Error<ErrorCode::INTERNAL_ERROR>(
"{} has already execute inverted index req_id {} , should not execute expr "
"with rows: {}",
get_name(), req_id, input_rows_count);
});
const auto& addr_column_with_type_and_name = block.get_by_position(arguments[0]);
const auto& cidr_column_with_type_and_name = block.get_by_position(arguments[1]);
WhichDataType addr_type(addr_column_with_type_and_name.type);
WhichDataType cidr_type(cidr_column_with_type_and_name.type);
const auto& [addr_column, addr_const] =
unpack_if_const(addr_column_with_type_and_name.column);
const auto& [cidr_column, cidr_const] =
unpack_if_const(cidr_column_with_type_and_name.column);
const ColumnString* str_addr_column = nullptr;
const ColumnString* str_cidr_column = nullptr;
const NullMap* null_map_addr = nullptr;
const NullMap* null_map_cidr = nullptr;

if (addr_type.is_nullable()) {
const auto* addr_column_nullable =
assert_cast<const ColumnNullable*>(addr_column.get());
str_addr_column =
check_and_get_column<ColumnString>(addr_column_nullable->get_nested_column());
null_map_addr = &addr_column_nullable->get_null_map_data();
} else {
str_addr_column = check_and_get_column<ColumnString>(addr_column.get());
}

if (cidr_type.is_nullable()) {
const auto* cidr_column_nullable =
assert_cast<const ColumnNullable*>(cidr_column.get());
str_cidr_column =
check_and_get_column<ColumnString>(cidr_column_nullable->get_nested_column());
null_map_cidr = &cidr_column_nullable->get_null_map_data();
} else {
str_cidr_column = check_and_get_column<ColumnString>(cidr_column.get());
}

if (!str_addr_column) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal column {} of argument of function {}, expected String",
addr_column->get_name(), get_name());
}

if (!str_cidr_column) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal column {} of argument of function {}, expected String",
cidr_column->get_name(), get_name());
}

auto col_res = ColumnUInt8::create(input_rows_count, 0);
auto& col_res_data = col_res->get_data();

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);
if (null_map_addr && (*null_map_addr)[addr_idx]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
}
if (null_map_cidr && (*null_map_cidr)[cidr_idx]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The arguments of function {} must be String, not NULL",
get_name());
if (is_ipv4(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV4, ColumnIPv4>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else if (is_ipv6(addr_column_with_type_and_name.type)) {
execute_impl_with_ip<PrimitiveType::TYPE_IPV6, ColumnIPv6>(
input_rows_count, addr_const, cidr_const,
assert_cast<const ColumnString*>(cidr_column.get()), addr_column, col_res);
} else {
const auto* str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
const auto* str_cidr_column = assert_cast<const ColumnString*>(cidr_column.get());

for (size_t i = 0; i < input_rows_count; ++i) {
auto addr_idx = index_check_const(i, addr_const);
auto cidr_idx = index_check_const(i, cidr_const);

const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}
const auto addr =
IPAddressVariant(str_addr_column->get_data_at(addr_idx).to_string_view());
const auto cidr =
parse_ip_with_cidr(str_cidr_column->get_data_at(cidr_idx).to_string_view());
col_res_data[i] = is_address_in_range(addr, cidr) ? 1 : 0;
}

block.replace_by_position(result, std::move(col_res));
Expand Down Expand Up @@ -839,21 +930,6 @@ class FunctionIPv4CIDRToRange : public IFunction {
std::move(col_upper_range_output)}));
return Status::OK();
}

private:
/// Computes the inclusive [lowest, highest] IPv4 address range of a CIDR block.
///
/// @param src           an address inside the block.
/// @param bits_to_keep  prefix length; values of 32 or more yield {src, src} and 0 yields
///                      the whole address space.
/// @return {src with the host bits cleared, src with the host bits set}.
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
    constexpr UInt32 all_ones = static_cast<UInt32>(-1);
    constexpr auto total_bits = 8 * sizeof(UInt32);

    if (bits_to_keep == 0) {
        return {static_cast<UInt32>(0), all_ones};
    }
    if (bits_to_keep >= total_bits) {
        return {src, src};
    }

    // Here 1 <= bits_to_keep <= 31, so the shift amount is always in range.
    const UInt32 network_mask = all_ones << (total_bits - bits_to_keep);
    const UInt32 first_addr = src & network_mask;
    const UInt32 last_addr = first_addr | ~network_mask;
    return {first_addr, last_addr};
}
};

class FunctionIPv6CIDRToRange : public IFunction {
Expand Down Expand Up @@ -991,18 +1067,6 @@ class FunctionIPv6CIDRToRange : public IFunction {
return ColumnStruct::create(
Columns {std::move(col_res_lower_range), std::move(col_res_upper_range)});
}

private:
/// Computes the lower and upper bound of an IPv6 CIDR block, byte by byte.
///
/// For every one of the IPV6_BINARY_LENGTH bytes, dst_lower receives the source byte
/// ANDed with the mask byte and dst_upper receives that result ORed with the inverted
/// mask byte. The mask is obtained from get_cidr_mask_ipv6(bits_to_keep).
///
/// @param src           the address bytes.
/// @param dst_lower     output buffer for the lower bound.
/// @param dst_upper     output buffer for the upper bound.
/// @param bits_to_keep  prefix length handed to get_cidr_mask_ipv6.
static void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
                            char* __restrict dst_upper, UInt8 bits_to_keep) {
    // little-endian mask
    const auto& mask = get_cidr_mask_ipv6(bits_to_keep);

    // Walk the bytes from the last one down to index 0.
    int8_t pos = IPV6_BINARY_LENGTH;
    while (pos-- > 0) {
        dst_lower[pos] = src[pos] & mask[pos];
        dst_upper[pos] = dst_lower[pos] | ~mask[pos];
    }
}
};

class FunctionIsIPv4Compat : public IFunction {
Expand Down
27 changes: 27 additions & 0 deletions be/src/vec/runtime/ip_address_cidr.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@

namespace doris {

namespace vectorized {
/// Returns the inclusive address range {lower, upper} described by an IPv4 CIDR.
///
/// @param src           an IPv4 address that belongs to the block.
/// @param bits_to_keep  number of leading network bits. A value of 32 or more selects the
///                      single address src; 0 selects every address.
/// @return lower = src with all host bits zero, upper = src with all host bits one.
static inline std::pair<UInt32, UInt32> apply_cidr_mask(UInt32 src, UInt8 bits_to_keep) {
    constexpr auto address_bits = 8 * sizeof(UInt32);
    constexpr UInt32 everything = static_cast<UInt32>(-1);

    const bool single_address = bits_to_keep >= address_bits;
    if (single_address) {
        return {src, src};
    }
    const bool whole_space = bits_to_keep == 0;
    if (whole_space) {
        return {static_cast<UInt32>(0), everything};
    }

    // bits_to_keep is now within 1..31, which keeps the shift below well defined.
    const UInt32 prefix_mask = everything << (address_bits - bits_to_keep);
    const UInt32 range_begin = src & prefix_mask;
    return {range_begin, range_begin | ~prefix_mask};
}

/// Writes the lower and upper bound of an IPv6 CIDR block into the two output buffers.
///
/// Each of the IPV6_BINARY_LENGTH bytes is processed independently:
/// dst_lower[i] = src[i] & mask[i] and dst_upper[i] = dst_lower[i] | ~mask[i], where mask
/// is the value returned by get_cidr_mask_ipv6(bits_to_keep).
///
/// @param src           the address bytes.
/// @param dst_lower     receives the lower bound.
/// @param dst_upper     receives the upper bound.
/// @param bits_to_keep  prefix length handed to get_cidr_mask_ipv6.
static inline void apply_cidr_mask(const char* __restrict src, char* __restrict dst_lower,
                                   char* __restrict dst_upper, UInt8 bits_to_keep) {
    // little-endian mask
    const auto& mask = get_cidr_mask_ipv6(bits_to_keep);

    // Same descending byte order as a countdown from the last index to 0.
    int8_t byte_idx = IPV6_BINARY_LENGTH;
    while (byte_idx-- > 0) {
        dst_lower[byte_idx] = src[byte_idx] & mask[byte_idx];
        dst_upper[byte_idx] = dst_lower[byte_idx] | ~mask[byte_idx];
    }
}
} // namespace vectorized

class IPAddressVariant {
public:
explicit IPAddressVariant(std::string_view address_str) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.IPv4Type;
import org.apache.doris.nereids.types.IPv6Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

Expand All @@ -39,6 +41,10 @@ public class IsIpAddressInRange extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {

/**
 * Accepted signatures of this function; every one returns Boolean.
 * The (first, second) argument pairs are: (IPv4, Varchar), (IPv4, String),
 * (IPv6, Varchar), (IPv6, String), (Varchar, Varchar) and (String, String).
 */
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv4Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(IPv6Type.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(BooleanType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(BooleanType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE));

Expand Down
Loading

0 comments on commit e545a77

Please sign in to comment.