From 7d4a6480ef9c518960c71e035d2c2e56dacb4657 Mon Sep 17 00:00:00 2001 From: Nicholas Frechette Date: Sat, 10 Feb 2024 10:20:32 -0500 Subject: [PATCH] refactor(compression): treat bit rate optimization as a degree of freedom problem Fixes #494 --- .../acl/compression/impl/quantize.transform.h | 176 +++++++----------- .../impl/transform_bit_rate_permutations.h | 36 +++- tools/calc_local_bit_rates.py | 49 +++-- 3 files changed, 130 insertions(+), 131 deletions(-) diff --git a/includes/acl/compression/impl/quantize.transform.h b/includes/acl/compression/impl/quantize.transform.h index d8f05a84..2444baba 100644 --- a/includes/acl/compression/impl/quantize.transform.h +++ b/includes/acl/compression/impl/quantize.transform.h @@ -872,6 +872,19 @@ namespace acl // until our error is acceptable. // We try permutations from the lowest memory footprint to the highest. + const uint8_t* const bit_rate_permutations_per_dofs[] = + { + &acl_impl::k_local_bit_rate_permutations_1_dof[0][0], + &acl_impl::k_local_bit_rate_permutations_2_dof[0][0], + &acl_impl::k_local_bit_rate_permutations_3_dof[0][0], + }; + const size_t num_bit_rate_permutations_per_dofs[] = + { + get_array_size(acl_impl::k_local_bit_rate_permutations_1_dof), + get_array_size(acl_impl::k_local_bit_rate_permutations_2_dof), + get_array_size(acl_impl::k_local_bit_rate_permutations_3_dof), + }; + const uint32_t num_bones = context.num_bones; for (uint32_t bone_index = 0; bone_index < num_bones; ++bone_index) { @@ -897,134 +910,73 @@ namespace acl uint32_t prev_transform_size = ~0U; bool is_error_good_enough = false; - if (context.has_scale) - { - const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations); - for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index) - { - const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][0]; - if (bone_bit_rates.rotation == 1) - { - if (rotation_bit_rate == 0) - continue; // Skip permutations we aren't interested in - } - else if (bone_bit_rates.rotation == k_invalid_bit_rate) - { - if (rotation_bit_rate != 0) - continue; // Skip permutations we aren't interested in - } + // Determine how many degrees of freedom we have to optimize our bit rates + uint32_t num_dof = 0; + num_dof += bone_bit_rates.rotation != k_invalid_bit_rate ? 1 : 0; + num_dof += bone_bit_rates.translation != k_invalid_bit_rate ? 1 : 0; + num_dof += bone_bit_rates.scale != k_invalid_bit_rate ? 1 : 0; - const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][1]; - if (bone_bit_rates.translation == 1) - { - if (translation_bit_rate == 0) - continue; // Skip permutations we aren't interested in - } - else if (bone_bit_rates.translation == k_invalid_bit_rate) - { - if (translation_bit_rate != 0) - continue; // Skip permutations we aren't interested in - } + const uint8_t* bit_rate_permutations_per_dof = bit_rate_permutations_per_dofs[num_dof - 1]; + const size_t num_bit_rate_permutations = num_bit_rate_permutations_per_dofs[num_dof - 1]; - const uint8_t scale_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][2]; - if (bone_bit_rates.scale == 1) - { - if (scale_bit_rate == 0) - continue; // Skip permutations we aren't interested in - } - else if (bone_bit_rates.scale == k_invalid_bit_rate) - { - if (scale_bit_rate != 0) - continue; // Skip permutations we aren't interested in - } - - const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate); - const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate); - const uint32_t scale_size = get_num_bits_at_bit_rate(scale_bit_rate); - const uint32_t transform_size = rotation_size + translation_size + scale_size; - - if (transform_size != prev_transform_size && is_error_good_enough) - { - // We already found the lowest transform size and we tried every permutation with that same size - break; - } - - prev_transform_size = transform_size; + // Our desired bit rates start with the initial value + transform_bit_rates desired_bit_rates = bone_bit_rates; - context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate; - context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate; - context.bit_rate_per_bone[bone_index].scale = bone_bit_rates.scale != k_invalid_bit_rate ? scale_bit_rate : k_invalid_bit_rate; - - const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high); + size_t permutation_offset = 0; + for (size_t permutation_index = 0; permutation_index < num_bit_rate_permutations; ++permutation_index) + { + // If a bit rate is variable, grab a permutation for it + // We'll only consume as many bit rates as we have degrees of freedom -#if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO - printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size, error); -#endif + uint32_t transform_size = 0; // In bits - if (error < best_error) - { - best_error = error; - best_bit_rates = context.bit_rate_per_bone[bone_index]; - is_error_good_enough = error < error_threshold; - } - } - } - else - { - const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations_no_scale); - for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index) + if (desired_bit_rates.rotation != k_invalid_bit_rate) { - const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][0]; - if (bone_bit_rates.rotation == 1) - { - if (rotation_bit_rate == 0) - continue; // Skip permutations we aren't interested in - } - else if (bone_bit_rates.rotation == k_invalid_bit_rate) - { - if (rotation_bit_rate != 0) - continue; // Skip permutations we aren't interested in - } + desired_bit_rates.rotation = bit_rate_permutations_per_dof[permutation_offset++]; + transform_size += get_num_bits_at_bit_rate(desired_bit_rates.rotation); + } - const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][1]; - if (bone_bit_rates.translation == 1) - { - if (translation_bit_rate == 0) - continue; // Skip permutations we aren't interested in - } - else if (bone_bit_rates.translation == k_invalid_bit_rate) - { - if (translation_bit_rate != 0) - continue; // Skip permutations we aren't interested in - } + if (desired_bit_rates.translation != k_invalid_bit_rate) + { + desired_bit_rates.translation = bit_rate_permutations_per_dof[permutation_offset++]; + transform_size += get_num_bits_at_bit_rate(desired_bit_rates.translation); + } - const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate); - const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate); - const uint32_t transform_size = rotation_size + translation_size; + if (desired_bit_rates.scale != k_invalid_bit_rate) + { + desired_bit_rates.scale = bit_rate_permutations_per_dof[permutation_offset++]; + transform_size += get_num_bits_at_bit_rate(desired_bit_rates.scale); + } - if (transform_size != prev_transform_size && is_error_good_enough) - { - // We already found the lowest transform size and we tried every permutation with that same size - break; - } + // If our inputs aren't normalized per segment, we can't store them on 0 bits because we'll have no + // segment range information. This occurs when we have a single segment. Skip those permutations. + if (bone_bit_rates.rotation == k_lowest_bit_rate && desired_bit_rates.rotation == 0) + continue; + else if (bone_bit_rates.translation == k_lowest_bit_rate && desired_bit_rates.translation == 0) + continue; + else if (bone_bit_rates.scale == k_lowest_bit_rate && desired_bit_rates.scale == 0) + continue; + + // If we already found a permutation that is good enough, we test all the others + // that have the same size. Once the size changes, we stop. + if (is_error_good_enough && transform_size != prev_transform_size) + break; - prev_transform_size = transform_size; + prev_transform_size = transform_size; - context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate; - context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate; + context.bit_rate_per_bone[bone_index] = desired_bit_rates; - const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high); + const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high); #if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO - printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, k_invalid_bit_rate, transform_size, error); + printf("%u: %u | %u | %u (%u) = %f\n", bone_index, desired_bit_rates.rotation, desired_bit_rates.translation, desired_bit_rates.scale, transform_size, error); #endif - if (error < best_error) - { - best_error = error; - best_bit_rates = context.bit_rate_per_bone[bone_index]; - is_error_good_enough = error < error_threshold; - } + if (error < best_error) + { + best_error = error; + best_bit_rates = desired_bit_rates; + is_error_good_enough = error < error_threshold; } } diff --git a/includes/acl/compression/impl/transform_bit_rate_permutations.h b/includes/acl/compression/impl/transform_bit_rate_permutations.h index f6be23a8..0f63adfe 100644 --- a/includes/acl/compression/impl/transform_bit_rate_permutations.h +++ b/includes/acl/compression/impl/transform_bit_rate_permutations.h @@ -39,7 +39,38 @@ namespace acl namespace acl_impl { - constexpr uint8_t k_local_bit_rate_permutations_no_scale[625][2] = + // Buffer size in bytes: 25 + constexpr uint8_t k_local_bit_rate_permutations_1_dof[25][1] = + { + { 0 }, // 0 bits per transform + { 1 }, // 3 bits per transform + { 2 }, // 6 bits per transform + { 3 }, // 9 bits per transform + { 4 }, // 12 bits per transform + { 5 }, // 15 bits per transform + { 6 }, // 18 bits per transform + { 7 }, // 21 bits per transform + { 8 }, // 24 bits per transform + { 9 }, // 27 bits per transform + { 10 }, // 30 bits per transform + { 11 }, // 33 bits per transform + { 12 }, // 36 bits per transform + { 13 }, // 39 bits per transform + { 14 }, // 42 bits per transform + { 15 }, // 45 bits per transform + { 16 }, // 48 bits per transform + { 17 }, // 51 bits per transform + { 18 }, // 54 bits per transform + { 19 }, // 57 bits per transform + { 20 }, // 60 bits per transform + { 21 }, // 63 bits per transform + { 22 }, // 66 bits per transform + { 23 }, // 69 bits per transform + { 24 }, // 96 bits per transform + }; + + // Buffer size in bytes: 1250 + constexpr uint8_t k_local_bit_rate_permutations_2_dof[625][2] = { { 0, 0 }, // 0 bits per transform { 0, 1 }, // 3 bits per transform @@ -668,7 +699,8 @@ namespace acl { 24, 24 }, // 192 bits per transform }; - constexpr uint8_t k_local_bit_rate_permutations[15625][3] = + // Buffer size in bytes: 46875 + constexpr uint8_t k_local_bit_rate_permutations_3_dof[15625][3] = { { 0, 0, 0 }, // 0 bits per transform { 0, 0, 1 }, // 3 bits per transform diff --git a/tools/calc_local_bit_rates.py b/tools/calc_local_bit_rates.py index 69834161..7e1c4798 100644 --- a/tools/calc_local_bit_rates.py +++ b/tools/calc_local_bit_rates.py @@ -17,30 +17,45 @@ print('Python 3.4 or higher needed to run this script') sys.exit(1) - permutation_tries = [] - permutation_tries_no_scale = [] + permutation_dof_1 = [] + permutation_dof_2 = [] + permutation_dof_3 = [] - for rotation_bit_rate in range(k_num_bit_rates): - for translation_bit_rate in range(k_num_bit_rates): - transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3 - permutation_tries_no_scale.append((transform_size, rotation_bit_rate, translation_bit_rate)) + for dof_1 in range(k_num_bit_rates): + dof_1_size = k_bit_rate_num_bits[dof_1] * 3; + permutation_dof_1.append((dof_1_size, dof_1)) - for scale_bit_rate in range(k_num_bit_rates): - transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3 + k_bit_rate_num_bits[scale_bit_rate] * 3 - permutation_tries.append((transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate)) + for dof_2 in range(k_num_bit_rates): + dof_2_size = dof_1_size + k_bit_rate_num_bits[dof_2] * 3 + permutation_dof_2.append((dof_2_size, dof_1, dof_2)) + + for dof_3 in range(k_num_bit_rates): + dof_3_size = dof_2_size + k_bit_rate_num_bits[dof_3] * 3 + permutation_dof_3.append((dof_3_size, dof_1, dof_2, dof_3)) # Sort by transform size, then by each bit rate - permutation_tries.sort() - permutation_tries_no_scale.sort() + permutation_dof_1.sort() + permutation_dof_2.sort() + permutation_dof_3.sort() - print('constexpr uint8_t k_local_bit_rate_permutations_no_scale[{}][2] ='.format(len(permutation_tries_no_scale))) + print('// Buffer size in bytes: {}'.format(len(permutation_dof_1) * 1)); + print('constexpr uint8_t k_local_bit_rate_permutations_1_dof[{}][1] ='.format(len(permutation_dof_1))) print('{') - for transform_size, rotation_bit_rate, translation_bit_rate in permutation_tries_no_scale: - print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, transform_size)) + for transform_size, dof_1 in permutation_dof_1: + print('\t{{ {} }},\t\t// {} bits per transform'.format(dof_1, transform_size)) print('};') print() - print('constexpr uint8_t k_local_bit_rate_permutations[{}][3] ='.format(len(permutation_tries))) + print('// Buffer size in bytes: {}'.format(len(permutation_dof_2) * 2)); + print('constexpr uint8_t k_local_bit_rate_permutations_2_dof[{}][2] ='.format(len(permutation_dof_2))) print('{') - for transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate in permutation_tries: - print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size)) + for transform_size, dof_1, dof_2 in permutation_dof_2: + print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, transform_size)) print('};') + print() + print('// Buffer size in bytes: {}'.format(len(permutation_dof_3) * 3)); + print('constexpr uint8_t k_local_bit_rate_permutations_3_dof[{}][3] ='.format(len(permutation_dof_3))) + print('{') + for transform_size, dof_1, dof_2, dof_3 in permutation_dof_3: + print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, dof_3, transform_size)) + print('};') + print()