diff --git a/src/Processors/Formats/Impl/ArrowFieldIndexUtil.h b/src/Processors/Formats/Impl/ArrowFieldIndexUtil.h new file mode 100644 index 000000000000..e49c6cd1b2b6 --- /dev/null +++ b/src/Processors/Formats/Impl/ArrowFieldIndexUtil.h @@ -0,0 +1,179 @@ +#pragma once +#include "config_formats.h" +#if USE_PARQUET || USE_ORC +#include +#include +#include +#include +#include +#include +#include +#include +namespace arrow +{ +class Schema; +class DataType; +class Field; +} +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} +/// For ORC format, index_nested_type = true, so a nested type itself takes one index. +/// The start index for ORC format should be 1, since index 0 indicates to select all columns. +template +class ArrowFieldIndexUtil +{ +public: + explicit ArrowFieldIndexUtil(bool ignore_case_, bool allow_missing_columns_) + : ignore_case(ignore_case_) + , allow_missing_columns(allow_missing_columns_) + { + } + + /// Recursively compute the index range of every field. Returns a map + /// - key: field name with full path, e.g. a field nested in structs is named like a.x.i + /// - value: a pair; the first value is this field's start index, the second value is how many + /// indices this field takes. E.g. + /// for a parquet schema {x: int, y: {i: int, j: int}}, the result will be + /// - x: (0, 1) + /// - y: (1, 2) + /// - y.i: (1, 1) + /// - y.j: (2, 1) + std::unordered_map> + calculateFieldIndices(const arrow::Schema & schema) + { + std::unordered_map> result; + // For formats like ORC, index = 0 indicates to select all columns, so we skip 0 and start + // from 1; index_nested_type is used directly as the start value. + int index_start = index_nested_type; + for (int i = 0; i < schema.num_fields(); ++i) + { + const auto & field = schema.field(i); + calculateFieldIndices(*field, field->name(), index_start, result); + } + return result; + } + + /// Collect only the required fields' indices, in header order and without duplicates. E.g. when reading just one field of a struct, + /// there is no need to collect all the indices of that struct. 
+ std::vector findRequiredIndices(const Block & header, const arrow::Schema & schema) + { + std::vector required_indices; + std::unordered_set added_indices; + /// Flatten all named fields' index information into a map. + auto fields_indices = calculateFieldIndices(schema); + for (size_t i = 0; i < header.columns(); ++i) + { + const auto & named_col = header.getByPosition(i); + std::string col_name = named_col.name; + if (ignore_case) + boost::to_lower(col_name); + /// Since all named fields are flattened into a map, we look the column up by its name + /// (lower-cased when ignore_case is set) in this map. NOTE(review): a miss is reported below with LOGICAL_ERROR although it depends on the input schema; confirm that this is the intended error code. + auto it = fields_indices.find(col_name); + + if (it == fields_indices.end()) + { + if (!allow_missing_columns) + throw Exception( + ErrorCodes::LOGICAL_ERROR, "Not found field({}) in arrow schema:{}.", named_col.name, schema.ToString()); + else + continue; + } + for (int j = 0; j < it->second.second; ++j) + { + auto index = it->second.first + j; + if (added_indices.insert(index).second) + required_indices.emplace_back(index); + } + } + return required_indices; + } + + /// Count the number of indices a type occupies: LIST, STRUCT and MAP are counted recursively, any other type takes 1. + /// For ORC format, index_nested_type is true, so each LIST, STRUCT or MAP level takes one index of its own. 
+ size_t countIndicesForType(std::shared_ptr type) + { + if (type->id() == arrow::Type::LIST) + { + return countIndicesForType(static_cast(type.get())->value_type()) + index_nested_type; + } + + if (type->id() == arrow::Type::STRUCT) + { + int indices = index_nested_type; + auto * struct_type = static_cast(type.get()); + for (int i = 0; i != struct_type->num_fields(); ++i) + indices += countIndicesForType(struct_type->field(i)->type()); + return indices; + } + + if (type->id() == arrow::Type::MAP) + { + auto * map_type = static_cast(type.get()); + return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()) + index_nested_type; + } + + return 1; + } + +private: + bool ignore_case; + bool allow_missing_columns; + void calculateFieldIndices(const arrow::Field & field, + std::string field_name, + int & current_start_index, + std::unordered_map> & result, const std::string & name_prefix = "") + { + const auto & field_type = field.type(); + if (field_name.empty()) + { + current_start_index += countIndicesForType(field_type); + return; + } + if (ignore_case) + { + boost::to_lower(field_name); + } + + std::string full_path_name = name_prefix.empty() ? field_name : name_prefix + "." + field_name; + auto & index_info = result[full_path_name]; + index_info.first = current_start_index; + if (field_type->id() == arrow::Type::STRUCT) + { + current_start_index += index_nested_type; + + auto * struct_type = static_cast(field_type.get()); + for (int i = 0, n = struct_type->num_fields(); i < n; ++i) + { + const auto & sub_field = struct_type->field(i); + calculateFieldIndices(*sub_field, sub_field->name(), current_start_index, result, full_path_name); + } + } + else if ( + field_type->id() == arrow::Type::LIST + && static_cast(field_type.get())->value_type()->id() == arrow::Type::STRUCT) + { + // It is a nested table: a LIST whose value type is a STRUCT. 
+ const auto * list_type = static_cast(field_type.get()); + const auto value_field = list_type->value_field(); + auto index_snapshot = current_start_index; + current_start_index += index_nested_type; + calculateFieldIndices(*value_field, field_name, current_start_index, result, name_prefix); + // The nested struct field has the same name as this list field, so the recursive call overwrote + // index_info.first; rewrite it back to the list's own start index. + index_info.first = index_snapshot; + } + else + { + current_start_index += countIndicesForType(field_type); + } + index_info.second = current_start_index - index_info.first; + } +}; +} +#endif + diff --git a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp index 5745ccc6ac8d..fb0f54a74f6c 100644 --- a/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ORCBlockInputFormat.cpp @@ -9,6 +9,7 @@ #include #include "ArrowBufferedStreams.h" #include "ArrowColumnToCHColumn.h" +#include "ArrowFieldIndexUtil.h" #include namespace DB @@ -79,29 +80,6 @@ const BlockMissingValues & ORCBlockInputFormat::getMissingValues() const return block_missing_values; } -static size_t countIndicesForType(std::shared_ptr type) -{ - if (type->id() == arrow::Type::LIST) - return countIndicesForType(static_cast(type.get())->value_type()) + 1; - - if (type->id() == arrow::Type::STRUCT) - { - int indices = 1; - auto * struct_type = static_cast(type.get()); - for (int i = 0; i != struct_type->num_fields(); ++i) - indices += countIndicesForType(struct_type->field(i)->type()); - return indices; - } - - if (type->id() == arrow::Type::MAP) - { - auto * map_type = static_cast(type.get()); - return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); - } - - return 1; -} - static void getFileReaderAndSchema( ReadBuffer & in, std::unique_ptr & file_reader, @@ -136,28 +114,10 @@ void ORCBlockInputFormat::prepareReader() arrow_column_to_ch_column = 
std::make_unique(getPort().getHeader(), schema, "ORC", format_settings); - const bool ignore_case = format_settings.orc.case_insensitive_column_matching; - std::unordered_set nested_table_names; - if (format_settings.orc.import_nested) - nested_table_names = Nested::getAllTableNames(getPort().getHeader(), ignore_case); - - /// In ReadStripe column indices should be started from 1, - /// because 0 indicates to select all columns. - int index = 1; - for (int i = 0; i < schema->num_fields(); ++i) - { - /// LIST type require 2 indices, STRUCT - the number of elements + 1, - /// so we should recursively count the number of indices we need for this type. - int indexes_count = countIndicesForType(schema->field(i)->type()); - const auto & name = schema->field(i)->name(); - if (getPort().getHeader().has(name, ignore_case) || nested_table_names.contains(ignore_case ? boost::to_lower_copy(name) : name)) - { - for (int j = 0; j != indexes_count; ++j) - include_indices.push_back(index + j); - } - - index += indexes_count; - } + ArrowFieldIndexUtil field_util( + format_settings.orc.case_insensitive_column_matching, + format_settings.orc.allow_missing_columns); + include_indices = field_util.findRequiredIndices(getPort().getHeader(), *schema); } ORCSchemaReader::ORCSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_) diff --git a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp index e11dc1e5f79f..e8d1c49a9ce0 100644 --- a/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/ParquetBlockInputFormat.cpp @@ -14,6 +14,7 @@ #include #include "ArrowBufferedStreams.h" #include "ArrowColumnToCHColumn.h" +#include "ArrowFieldIndexUtil.h" #include namespace DB @@ -80,29 +81,6 @@ const BlockMissingValues & ParquetBlockInputFormat::getMissingValues() const return block_missing_values; } -static size_t countIndicesForType(std::shared_ptr type) -{ - if (type->id() == 
arrow::Type::LIST) - return countIndicesForType(static_cast(type.get())->value_type()); - - if (type->id() == arrow::Type::STRUCT) - { - int indices = 0; - auto * struct_type = static_cast(type.get()); - for (int i = 0; i != struct_type->num_fields(); ++i) - indices += countIndicesForType(struct_type->field(i)->type()); - return indices; - } - - if (type->id() == arrow::Type::MAP) - { - auto * map_type = static_cast(type.get()); - return countIndicesForType(map_type->key_type()) + countIndicesForType(map_type->item_type()); - } - - return 1; -} - static void getFileReaderAndSchema( ReadBuffer & in, std::unique_ptr & file_reader, @@ -129,28 +107,10 @@ void ParquetBlockInputFormat::prepareReader() arrow_column_to_ch_column = std::make_unique(getPort().getHeader(), schema, "Parquet", format_settings); - const bool ignore_case = format_settings.parquet.case_insensitive_column_matching; - std::unordered_set nested_table_names; - if (format_settings.parquet.import_nested) - nested_table_names = Nested::getAllTableNames(getPort().getHeader(), ignore_case); - - int index = 0; - for (int i = 0; i < schema->num_fields(); ++i) - { - /// STRUCT type require the number of indexes equal to the number of - /// nested elements, so we should recursively - /// count the number of indices we need for this type. - int indexes_count = countIndicesForType(schema->field(i)->type()); - const auto & name = schema->field(i)->name(); - - if (getPort().getHeader().has(name, ignore_case) || nested_table_names.contains(ignore_case ? 
boost::to_lower_copy(name) : name)) - { - for (int j = 0; j != indexes_count; ++j) - column_indices.push_back(index + j); - } - - index += indexes_count; - } + ArrowFieldIndexUtil field_util( + format_settings.parquet.case_insensitive_column_matching, + format_settings.parquet.allow_missing_columns); + column_indices = field_util.findRequiredIndices(getPort().getHeader(), *schema); } ParquetSchemaReader::ParquetSchemaReader(ReadBuffer & in_, const FormatSettings & format_settings_) diff --git a/tests/queries/0_stateless/00900_orc_arrow_parquet_nested.sh b/tests/queries/0_stateless/00900_orc_arrow_parquet_nested.sh index e07c8fcff096..dbb5b698e69d 100755 --- a/tests/queries/0_stateless/00900_orc_arrow_parquet_nested.sh +++ b/tests/queries/0_stateless/00900_orc_arrow_parquet_nested.sh @@ -26,7 +26,7 @@ for ((i = 0; i < 3; i++)) do ${CLICKHOUSE_CLIENT} --query="TRUNCATE TABLE nested_nested_table" - cat $CUR_DIR/data_orc_arrow_parquet_nested/nested_nested_table.${format_files[i]} | ${CLICKHOUSE_CLIENT} -q "INSERT INTO nested_nested_table SETTINGS input_format_${format_files[i]}_import_nested = 1 FORMAT ${formats[i]}" + cat $CUR_DIR/data_orc_arrow_parquet_nested/nested_nested_table.${format_files[i]} | ${CLICKHOUSE_CLIENT} -q "INSERT INTO nested_nested_table SETTINGS input_format_${format_files[i]}_import_nested = 1, input_format_${format_files[i]}_case_insensitive_column_matching = 1 FORMAT ${formats[i]}" ${CLICKHOUSE_CLIENT} --query="SELECT * FROM nested_nested_table" done