From 6bb99f04551c168ba39c966b113e84fcb3b0d188 Mon Sep 17 00:00:00 2001 From: Hao Gao Date: Tue, 10 Nov 2020 09:56:46 -0800 Subject: [PATCH 1/5] Select hash functions in hash_partition --- cpp/include/cudf/detail/hashing.hpp | 12 +- .../cudf/detail/utilities/hash_functions.cuh | 13 +- cpp/include/cudf/partitioning.hpp | 4 +- cpp/src/hash/hashing.cu | 599 ------------------ cpp/src/partitioning/partitioning.cu | 30 +- .../partitioning/hash_partition_test.cpp | 13 +- 6 files changed, 50 insertions(+), 621 deletions(-) diff --git a/cpp/include/cudf/detail/hashing.hpp b/cpp/include/cudf/detail/hashing.hpp index c5600f0af18..35107a99fd9 100644 --- a/cpp/include/cudf/detail/hashing.hpp +++ b/cpp/include/cudf/detail/hashing.hpp @@ -16,20 +16,10 @@ #pragma once #include +#include namespace cudf { namespace detail { -/** - * @copydoc cudf::hash_partition - * - * @param stream CUDA stream used for device memory operations and kernel launches. - */ -std::pair, std::vector> hash_partition( - table_view const& input, - std::vector const& columns_to_hash, - int num_partitions, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource(), - cudaStream_t stream = 0); /** * @copydoc cudf::hash diff --git a/cpp/include/cudf/detail/utilities/hash_functions.cuh b/cpp/include/cudf/detail/utilities/hash_functions.cuh index 31c884c8320..9804dec634c 100644 --- a/cpp/include/cudf/detail/utilities/hash_functions.cuh +++ b/cpp/include/cudf/detail/utilities/hash_functions.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -578,7 +579,17 @@ struct IdentityHash { return combined; } - CUDA_HOST_DEVICE_CALLABLE result_type operator()(const Key& key) const + template + CUDA_HOST_DEVICE_CALLABLE std::enable_if_t::value, return_type> + operator()(const Key& key) const + { + release_assert(false && "IdentityHash does not support this data type"); + return 0; + } + + template + CUDA_HOST_DEVICE_CALLABLE std::enable_if_t::value, return_type> + operator()(const Key& key) const { return static_cast(key); } diff --git a/cpp/include/cudf/partitioning.hpp b/cpp/include/cudf/partitioning.hpp index d84f272760e..ac636db72bd 100644 --- a/cpp/include/cudf/partitioning.hpp +++ b/cpp/include/cudf/partitioning.hpp @@ -88,7 +88,9 @@ std::pair, std::vector> hash_partition( table_view const& input, std::vector const& columns_to_hash, int num_partitions, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + hash_id hash_function = hash_id::HASH_MURMUR3, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource(), + cudaStream_t stream = 0); /** * @brief Round-robin partition. diff --git a/cpp/src/hash/hashing.cu b/cpp/src/hash/hashing.cu index 63401ad823a..bd0bd0a4a78 100644 --- a/cpp/src/hash/hashing.cu +++ b/cpp/src/hash/hashing.cu @@ -29,583 +29,6 @@ namespace cudf { namespace { -// Launch configuration for optimized hash partition -constexpr size_type OPTIMIZED_BLOCK_SIZE = 512; -constexpr size_type OPTIMIZED_ROWS_PER_THREAD = 8; -constexpr size_type ELEMENTS_PER_THREAD = 2; -constexpr size_type THRESHOLD_FOR_OPTIMIZED_PARTITION_KERNEL = 1024; - -// Launch configuration for fallback hash partition -constexpr size_type FALLBACK_BLOCK_SIZE = 256; -constexpr size_type FALLBACK_ROWS_PER_THREAD = 1; - -/** - * @brief Functor to map a hash value to a particular 'bin' or partition number - * that uses the modulo operation. - */ -template -class modulo_partitioner { - public: - modulo_partitioner(size_type num_partitions) : divisor{num_partitions} {} - - __device__ size_type operator()(hash_value_t hash_value) const { return hash_value % divisor; } - - private: - const size_type divisor; -}; - -template -bool is_power_two(T number) -{ - return (0 == (number & (number - 1))); -} - -/** - * @brief Functor to map a hash value to a particular 'bin' or partition number - * that uses a bitwise mask. Only works when num_partitions is a power of 2. - * - * For n % d, if d is a power of two, then it can be computed more efficiently via - * a single bitwise AND as: - * n & (d - 1) - */ -template -class bitwise_partitioner { - public: - bitwise_partitioner(size_type num_partitions) : mask{(num_partitions - 1)} - { - assert(is_power_two(num_partitions)); - } - - __device__ size_type operator()(hash_value_t hash_value) const - { - return hash_value & mask; // hash_value & (num_partitions - 1) - } - - private: - const size_type mask; -}; - -/* --------------------------------------------------------------------------*/ -/** - * @brief Computes which partition each row of a device_table will belong to based - on hashing each row, and applying a partition function to the hash value. - Records the size of each partition for each thread block as well as the global - size of each partition across all thread blocks. - * - * @param[in] the_table The table whose rows will be partitioned - * @param[in] num_rows The number of rows in the table - * @param[in] num_partitions The number of partitions to divide the rows into - * @param[in] the_partitioner The functor that maps a rows hash value to a partition number - * @param[out] row_partition_numbers Array that holds which partition each row belongs to - * @param[out] row_partition_offset Array that holds the offset of each row in its partition of - * the thread block - * @param[out] block_partition_sizes Array that holds the size of each partition for each block, - * i.e., { {block0 partition0 size, block1 partition0 size, ...}, - {block0 partition1 size, block1 partition1 size, ...}, - ... - {block0 partition(num_partitions-1) size, block1 partition(num_partitions -1) size, ...} } - * @param[out] global_partition_sizes The number of rows in each partition. - */ -/* ----------------------------------------------------------------------------*/ -template -__global__ void compute_row_partition_numbers(row_hasher_t the_hasher, - const size_type num_rows, - const size_type num_partitions, - const partitioner_type the_partitioner, - size_type* __restrict__ row_partition_numbers, - size_type* __restrict__ row_partition_offset, - size_type* __restrict__ block_partition_sizes, - size_type* __restrict__ global_partition_sizes) -{ - // Accumulate histogram of the size of each partition in shared memory - extern __shared__ size_type shared_partition_sizes[]; - - size_type row_number = threadIdx.x + blockIdx.x * blockDim.x; - - // Initialize local histogram - size_type partition_number = threadIdx.x; - while (partition_number < num_partitions) { - shared_partition_sizes[partition_number] = 0; - partition_number += blockDim.x; - } - - __syncthreads(); - - // Compute the hash value for each row, store it to the array of hash values - // and compute the partition to which the hash value belongs and increment - // the shared memory counter for that partition - while (row_number < num_rows) { - const hash_value_type row_hash_value = the_hasher(row_number); - - const size_type partition_number = the_partitioner(row_hash_value); - - row_partition_numbers[row_number] = partition_number; - - row_partition_offset[row_number] = - atomicAdd(&(shared_partition_sizes[partition_number]), size_type(1)); - - row_number += blockDim.x * gridDim.x; - } - - __syncthreads(); - - // Flush shared memory histogram to global memory - partition_number = threadIdx.x; - while (partition_number < num_partitions) { - const size_type block_partition_size = shared_partition_sizes[partition_number]; - - // Update global size of each partition - atomicAdd(&global_partition_sizes[partition_number], block_partition_size); - - // Record the size of this partition in this block - const size_type write_location = partition_number * gridDim.x + blockIdx.x; - block_partition_sizes[write_location] = block_partition_size; - partition_number += blockDim.x; - } -} - -/* --------------------------------------------------------------------------*/ -/** - * @brief Given an array of partition numbers, computes the final output location - for each element in the output such that all rows with the same partition are - contiguous in memory. - * - * @param row_partition_numbers The array that records the partition number for each row - * @param num_rows The number of rows - * @param num_partitions THe number of partitions - * @param[out] block_partition_offsets Array that holds the offset of each partition for each thread - block, - * i.e., { {block0 partition0 offset, block1 partition0 offset, ...}, - {block0 partition1 offset, block1 partition1 offset, ...}, - ... - {block0 partition(num_partitions-1) offset, block1 partition(num_partitions -1) offset, - ...} } - */ -/* ----------------------------------------------------------------------------*/ -__global__ void compute_row_output_locations(size_type* __restrict__ row_partition_numbers, - const size_type num_rows, - const size_type num_partitions, - size_type* __restrict__ block_partition_offsets) -{ - // Shared array that holds the offset of this blocks partitions in - // global memory - extern __shared__ size_type shared_partition_offsets[]; - - // Initialize array of this blocks offsets from global array - size_type partition_number = threadIdx.x; - while (partition_number < num_partitions) { - shared_partition_offsets[partition_number] = - block_partition_offsets[partition_number * gridDim.x + blockIdx.x]; - partition_number += blockDim.x; - } - __syncthreads(); - - size_type row_number = threadIdx.x + blockIdx.x * blockDim.x; - - // Get each row's partition number, and get it's output location by - // incrementing block's offset counter for that partition number - // and store the row's output location in-place - while (row_number < num_rows) { - // Get partition number of this row - const size_type partition_number = row_partition_numbers[row_number]; - - // Get output location based on partition number by incrementing the corresponding - // partition offset for this block - const size_type row_output_location = - atomicAdd(&(shared_partition_offsets[partition_number]), size_type(1)); - - // Store the row's output location in-place - row_partition_numbers[row_number] = row_output_location; - - row_number += blockDim.x * gridDim.x; - } -} - -/* --------------------------------------------------------------------------*/ -/** - * @brief Move one column from the input table to the hashed table. - * - * @param[in] input_buf Data buffer of the column in the input table - * @param[out] output_buf Preallocated data buffer of the column in the output table - * @param[in] num_rows The number of rows in each column - * @param[in] num_partitions The number of partitions to divide the rows into - * @param[in] row_partition_numbers Array that holds which partition each row belongs to - * @param[in] row_partition_offset Array that holds the offset of each row in its partition of - * the thread block. - * @param[in] block_partition_sizes Array that holds the size of each partition for each block - * @param[in] scanned_block_partition_sizes The scan of block_partition_sizes - */ -/* ----------------------------------------------------------------------------*/ -template -__global__ void copy_block_partitions(InputIter input_iter, - DataType* __restrict__ output_buf, - const size_type num_rows, - const size_type num_partitions, - size_type const* __restrict__ row_partition_numbers, - size_type const* __restrict__ row_partition_offset, - size_type const* __restrict__ block_partition_sizes, - size_type const* __restrict__ scanned_block_partition_sizes) -{ - extern __shared__ char shared_memory[]; - auto block_output = reinterpret_cast(shared_memory); - auto partition_offset_shared = - reinterpret_cast(block_output + OPTIMIZED_BLOCK_SIZE * OPTIMIZED_ROWS_PER_THREAD); - auto partition_offset_global = - reinterpret_cast(partition_offset_shared + num_partitions + 1); - - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - // use ELEMENTS_PER_THREAD=2 to support up to 1024 partitions - size_type temp_histo[ELEMENTS_PER_THREAD]; - - for (int i = 0; i < ELEMENTS_PER_THREAD; ++i) { - if (ELEMENTS_PER_THREAD * threadIdx.x + i < num_partitions) { - temp_histo[i] = - block_partition_sizes[blockIdx.x + (ELEMENTS_PER_THREAD * threadIdx.x + i) * gridDim.x]; - } else { - temp_histo[i] = 0; - } - } - - __syncthreads(); - - BlockScan(temp_storage).InclusiveSum(temp_histo, temp_histo); - - __syncthreads(); - - if (threadIdx.x == 0) { partition_offset_shared[0] = 0; } - - // Calculate the offset in shared memory of each partition in this thread block - for (int i = 0; i < ELEMENTS_PER_THREAD; ++i) { - if (ELEMENTS_PER_THREAD * threadIdx.x + i < num_partitions) { - partition_offset_shared[ELEMENTS_PER_THREAD * threadIdx.x + i + 1] = temp_histo[i]; - } - } - - // Fetch the offset in the output buffer of each partition in this thread block - for (size_type ipartition = threadIdx.x; ipartition < num_partitions; ipartition += blockDim.x) { - partition_offset_global[ipartition] = - scanned_block_partition_sizes[ipartition * gridDim.x + blockIdx.x]; - } - - __syncthreads(); - - // Fetch the input data to shared memory - for (size_type row_number = threadIdx.x + blockIdx.x * blockDim.x; row_number < num_rows; - row_number += blockDim.x * gridDim.x) { - size_type const ipartition = row_partition_numbers[row_number]; - - block_output[partition_offset_shared[ipartition] + row_partition_offset[row_number]] = - input_iter[row_number]; - } - - __syncthreads(); - - // Copy data from shared memory to output using 32 threads for each partition - constexpr int nthreads_partition = 32; - static_assert(OPTIMIZED_BLOCK_SIZE % nthreads_partition == 0, - "BLOCK_SIZE must be divisible by number of threads"); - - for (size_type ipartition = threadIdx.x / nthreads_partition; ipartition < num_partitions; - ipartition += OPTIMIZED_BLOCK_SIZE / nthreads_partition) { - size_type const nelements_partition = - partition_offset_shared[ipartition + 1] - partition_offset_shared[ipartition]; - - for (size_type row_offset = threadIdx.x % nthreads_partition; row_offset < nelements_partition; - row_offset += nthreads_partition) { - output_buf[partition_offset_global[ipartition] + row_offset] = - block_output[partition_offset_shared[ipartition] + row_offset]; - } - } -} - -template -void copy_block_partitions_impl(InputIter const input, - OutputIter output, - size_type num_rows, - size_type num_partitions, - size_type const* row_partition_numbers, - size_type const* row_partition_offset, - size_type const* block_partition_sizes, - size_type const* scanned_block_partition_sizes, - size_type grid_size, - cudaStream_t stream) -{ - // We need 3 chunks of shared memory: - // 1. BLOCK_SIZE * ROWS_PER_THREAD elements of size_type for copying to output - // 2. num_partitions + 1 elements of size_type for per-block partition offsets - // 3. num_partitions + 1 elements of size_type for global partition offsets - int const smem = OPTIMIZED_BLOCK_SIZE * OPTIMIZED_ROWS_PER_THREAD * sizeof(*output) + - (num_partitions + 1) * sizeof(size_type) * 2; - - copy_block_partitions<<>>( - input, - output, - num_rows, - num_partitions, - row_partition_numbers, - row_partition_offset, - block_partition_sizes, - scanned_block_partition_sizes); -} - -rmm::device_vector compute_gather_map(size_type num_rows, - size_type num_partitions, - size_type const* row_partition_numbers, - size_type const* row_partition_offset, - size_type const* block_partition_sizes, - size_type const* scanned_block_partition_sizes, - size_type grid_size, - cudaStream_t stream) -{ - auto sequence = thrust::make_counting_iterator(0); - rmm::device_vector gather_map(num_rows); - - copy_block_partitions_impl(sequence, - gather_map.data().get(), - num_rows, - num_partitions, - row_partition_numbers, - row_partition_offset, - block_partition_sizes, - scanned_block_partition_sizes, - grid_size, - stream); - - return gather_map; -} - -struct copy_block_partitions_dispatcher { - template ()>* = nullptr> - std::unique_ptr operator()(column_view const& input, - const size_type num_partitions, - size_type const* row_partition_numbers, - size_type const* row_partition_offset, - size_type const* block_partition_sizes, - size_type const* scanned_block_partition_sizes, - size_type grid_size, - rmm::mr::device_memory_resource* mr, - cudaStream_t stream) - { - rmm::device_buffer output(input.size() * sizeof(DataType), stream, mr); - - copy_block_partitions_impl(input.data(), - static_cast(output.data()), - input.size(), - num_partitions, - row_partition_numbers, - row_partition_offset, - block_partition_sizes, - scanned_block_partition_sizes, - grid_size, - stream); - - return std::make_unique(input.type(), input.size(), std::move(output)); - } - - template ()>* = nullptr> - std::unique_ptr operator()(column_view const& input, - const size_type num_partitions, - size_type const* row_partition_numbers, - size_type const* row_partition_offset, - size_type const* block_partition_sizes, - size_type const* scanned_block_partition_sizes, - size_type grid_size, - rmm::mr::device_memory_resource* mr, - cudaStream_t stream) - { - // Use move_to_output_buffer to create an equivalent gather map - auto gather_map = compute_gather_map(input.size(), - num_partitions, - row_partition_numbers, - row_partition_offset, - block_partition_sizes, - scanned_block_partition_sizes, - grid_size, - stream); - - // Use gather instead for non-fixed width types - return type_dispatcher(input.type(), - detail::column_gatherer{}, - input, - gather_map.begin(), - gather_map.end(), - false, - stream, - mr); - } -}; - -// NOTE hash_has_nulls must be true if table_to_hash has nulls -template -std::pair, std::vector> hash_partition_table( - table_view const& input, - table_view const& table_to_hash, - size_type num_partitions, - rmm::mr::device_memory_resource* mr, - cudaStream_t stream) -{ - auto const num_rows = table_to_hash.num_rows(); - - bool const use_optimization{num_partitions <= THRESHOLD_FOR_OPTIMIZED_PARTITION_KERNEL}; - auto const block_size = use_optimization ? OPTIMIZED_BLOCK_SIZE : FALLBACK_BLOCK_SIZE; - auto const rows_per_thread = - use_optimization ? OPTIMIZED_ROWS_PER_THREAD : FALLBACK_ROWS_PER_THREAD; - auto const rows_per_block = block_size * rows_per_thread; - - // NOTE grid_size is non-const to workaround lambda capture bug in gcc 5.4 - auto grid_size = util::div_rounding_up_safe(num_rows, rows_per_block); - - // Allocate array to hold which partition each row belongs to - auto row_partition_numbers = rmm::device_vector(num_rows); - - // Array to hold the size of each partition computed by each block - // i.e., { {block0 partition0 size, block1 partition0 size, ...}, - // {block0 partition1 size, block1 partition1 size, ...}, - // ... - // {block0 partition(num_partitions-1) size, block1 partition(num_partitions -1) size, - // ...} } - auto block_partition_sizes = rmm::device_vector(grid_size * num_partitions); - - auto scanned_block_partition_sizes = rmm::device_vector(grid_size * num_partitions); - - // Holds the total number of rows in each partition - auto global_partition_sizes = rmm::device_vector(num_partitions, size_type{0}); - - auto row_partition_offset = rmm::device_vector(num_rows); - - auto const device_input = table_device_view::create(table_to_hash, stream); - auto const hasher = row_hasher(*device_input); - - // If the number of partitions is a power of two, we can compute the partition - // number of each row more efficiently with bitwise operations - if (is_power_two(num_partitions)) { - // Determines how the mapping between hash value and partition number is computed - using partitioner_type = bitwise_partitioner; - - // Computes which partition each row belongs to by hashing the row and performing - // a partitioning operator on the hash value. Also computes the number of - // rows in each partition both for each thread block as well as across all blocks - compute_row_partition_numbers<<>>(hasher, - num_rows, - num_partitions, - partitioner_type(num_partitions), - row_partition_numbers.data().get(), - row_partition_offset.data().get(), - block_partition_sizes.data().get(), - global_partition_sizes.data().get()); - } else { - // Determines how the mapping between hash value and partition number is computed - using partitioner_type = modulo_partitioner; - - // Computes which partition each row belongs to by hashing the row and performing - // a partitioning operator on the hash value. Also computes the number of - // rows in each partition both for each thread block as well as across all blocks - compute_row_partition_numbers<<>>(hasher, - num_rows, - num_partitions, - partitioner_type(num_partitions), - row_partition_numbers.data().get(), - row_partition_offset.data().get(), - block_partition_sizes.data().get(), - global_partition_sizes.data().get()); - } - - // Compute exclusive scan of all blocks' partition sizes in-place to determine - // the starting point for each blocks portion of each partition in the output - thrust::exclusive_scan(rmm::exec_policy(stream)->on(stream), - block_partition_sizes.begin(), - block_partition_sizes.end(), - scanned_block_partition_sizes.data().get()); - - // Compute exclusive scan of size of each partition to determine offset location - // of each partition in final output. - // TODO This can be done independently on a separate stream - size_type* scanned_global_partition_sizes{global_partition_sizes.data().get()}; - thrust::exclusive_scan(rmm::exec_policy(stream)->on(stream), - global_partition_sizes.begin(), - global_partition_sizes.end(), - scanned_global_partition_sizes); - - // Copy the result of the exclusive scan to the output offsets array - // to indicate the starting point for each partition in the output - std::vector partition_offsets(num_partitions); - CUDA_TRY(cudaMemcpyAsync(partition_offsets.data(), - scanned_global_partition_sizes, - num_partitions * sizeof(size_type), - cudaMemcpyDeviceToHost, - stream)); - - // When the number of partitions is less than a threshold, we can apply an - // optimization using shared memory to copy values to the output buffer. - // Otherwise, fallback to using scatter. - if (use_optimization) { - std::vector> output_cols(input.num_columns()); - - // NOTE these pointers are non-const to workaround lambda capture bug in gcc 5.4 - auto row_partition_numbers_ptr{row_partition_numbers.data().get()}; - auto row_partition_offset_ptr{row_partition_offset.data().get()}; - auto block_partition_sizes_ptr{block_partition_sizes.data().get()}; - auto scanned_block_partition_sizes_ptr{scanned_block_partition_sizes.data().get()}; - - // Copy input to output by partition per column - std::transform(input.begin(), input.end(), output_cols.begin(), [=](auto const& col) { - return cudf::type_dispatcher(col.type(), - copy_block_partitions_dispatcher{}, - col, - num_partitions, - row_partition_numbers_ptr, - row_partition_offset_ptr, - block_partition_sizes_ptr, - scanned_block_partition_sizes_ptr, - grid_size, - mr, - stream); - }); - - if (has_nulls(input)) { - // Use copy_block_partitions to compute a gather map - auto gather_map = compute_gather_map(num_rows, - num_partitions, - row_partition_numbers_ptr, - row_partition_offset_ptr, - block_partition_sizes_ptr, - scanned_block_partition_sizes_ptr, - grid_size, - stream); - - // Handle bitmask using gather to take advantage of ballot_sync - detail::gather_bitmask( - input, gather_map.begin(), output_cols, detail::gather_bitmask_op::DONT_CHECK, mr, stream); - } - - auto output{std::make_unique(std::move(output_cols))}; - return std::make_pair(std::move(output), std::move(partition_offsets)); - } else { - // Compute a scatter map from input to output such that the output rows are - // sorted by partition number - auto row_output_locations{row_partition_numbers.data().get()}; - auto scanned_block_partition_sizes_ptr{scanned_block_partition_sizes.data().get()}; - compute_row_output_locations<<>>( - row_output_locations, num_rows, num_partitions, scanned_block_partition_sizes_ptr); - - // Use the resulting scatter map to materialize the output - auto output = detail::scatter( - input, row_partition_numbers.begin(), row_partition_numbers.end(), input, false, mr, stream); - - return std::make_pair(std::move(output), std::move(partition_offsets)); - } -} // MD5 supported leaf data type check bool md5_type_check(data_type dt) @@ -616,28 +39,6 @@ bool md5_type_check(data_type dt) } // namespace namespace detail { -std::pair, std::vector> hash_partition( - table_view const& input, - std::vector const& columns_to_hash, - int num_partitions, - rmm::mr::device_memory_resource* mr, - cudaStream_t stream) -{ - CUDF_FUNC_RANGE(); - - auto table_to_hash = input.select(columns_to_hash); - - // Return empty result if there are no partitions or nothing to hash - if (num_partitions <= 0 || input.num_rows() == 0 || table_to_hash.num_columns() == 0) { - return std::make_pair(empty_like(input), std::vector{}); - } - - if (has_nulls(table_to_hash)) { - return hash_partition_table(input, table_to_hash, num_partitions, mr, stream); - } else { - return hash_partition_table(input, table_to_hash, num_partitions, mr, stream); - } -} std::unique_ptr hash(table_view const& input, hash_id hash_function, diff --git a/cpp/src/partitioning/partitioning.cu b/cpp/src/partitioning/partitioning.cu index b18c231b309..7a30aca5c50 100644 --- a/cpp/src/partitioning/partitioning.cu +++ b/cpp/src/partitioning/partitioning.cu @@ -26,6 +26,8 @@ #include #include +#include + namespace cudf { namespace { // Launch configuration for optimized hash partition @@ -446,7 +448,7 @@ struct copy_block_partitions_dispatcher { }; // NOTE hash_has_nulls must be true if table_to_hash has nulls -template +template