Skip to content

Commit

Permalink
refactor(compression): treat bit rate optimization as a degree of freedom problem
Browse files Browse the repository at this point in the history

Fixes #494
  • Loading branch information
nfrechette committed Feb 10, 2024
1 parent 27a5186 commit 7d4a648
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 131 deletions.
176 changes: 64 additions & 112 deletions includes/acl/compression/impl/quantize.transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,19 @@ namespace acl
// until our error is acceptable.
// We try permutations from the lowest memory footprint to the highest.

// Permutation tables keyed by degree of freedom count: entry [N - 1] points at the first
// byte of the table used when N of the rotation/translation/scale bit rates of a bone
// are not k_invalid_bit_rate. Each 2D table is exposed through a flat uint8_t pointer
// because the search loop reads the bit rates sequentially with a running offset.
const uint8_t* const bit_rate_permutations_per_dofs[] =
{
&acl_impl::k_local_bit_rate_permutations_1_dof[0][0],
&acl_impl::k_local_bit_rate_permutations_2_dof[0][0],
&acl_impl::k_local_bit_rate_permutations_3_dof[0][0],
};
// Number of permutations in the k_local_bit_rate_permutations_{1,2,3}_dof tables,
// indexed by [N - 1] for N degrees of freedom; used as the search loop's upper bound.
// NOTE(review): assumes get_array_size() yields the outer (row) dimension of a 2D array,
// i.e. the permutation count and not the total byte count - confirm against its definition.
const size_t num_bit_rate_permutations_per_dofs[] =
{
get_array_size(acl_impl::k_local_bit_rate_permutations_1_dof),
get_array_size(acl_impl::k_local_bit_rate_permutations_2_dof),
get_array_size(acl_impl::k_local_bit_rate_permutations_3_dof),
};

const uint32_t num_bones = context.num_bones;
for (uint32_t bone_index = 0; bone_index < num_bones; ++bone_index)
{
Expand All @@ -897,134 +910,73 @@ namespace acl
uint32_t prev_transform_size = ~0U;
bool is_error_good_enough = false;

if (context.has_scale)
{
const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations);
for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index)
{
const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][0];
if (bone_bit_rates.rotation == 1)
{
if (rotation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.rotation == k_invalid_bit_rate)
{
if (rotation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
// Determine how many degrees of freedom we have to optimize our bit rates
uint32_t num_dof = 0;
num_dof += bone_bit_rates.rotation != k_invalid_bit_rate ? 1 : 0;
num_dof += bone_bit_rates.translation != k_invalid_bit_rate ? 1 : 0;
num_dof += bone_bit_rates.scale != k_invalid_bit_rate ? 1 : 0;

const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][1];
if (bone_bit_rates.translation == 1)
{
if (translation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.translation == k_invalid_bit_rate)
{
if (translation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
const uint8_t* bit_rate_permutations_per_dof = bit_rate_permutations_per_dofs[num_dof - 1];
const size_t num_bit_rate_permutations = num_bit_rate_permutations_per_dofs[num_dof - 1];

const uint8_t scale_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][2];
if (bone_bit_rates.scale == 1)
{
if (scale_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.scale == k_invalid_bit_rate)
{
if (scale_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}

const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate);
const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate);
const uint32_t scale_size = get_num_bits_at_bit_rate(scale_bit_rate);
const uint32_t transform_size = rotation_size + translation_size + scale_size;

if (transform_size != prev_transform_size && is_error_good_enough)
{
// We already found the lowest transform size and we tried every permutation with that same size
break;
}

prev_transform_size = transform_size;
// Our desired bit rates start with the initial value
transform_bit_rates desired_bit_rates = bone_bit_rates;

context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].scale = bone_bit_rates.scale != k_invalid_bit_rate ? scale_bit_rate : k_invalid_bit_rate;

const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);
size_t permutation_offset = 0;
for (size_t permutation_index = 0; permutation_index < num_bit_rate_permutations; ++permutation_index)
{
// If a bit rate is variable, grab a permutation for it
// We'll only consume as many bit rates as we have degrees of freedom

#if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size, error);
#endif
uint32_t transform_size = 0; // In bits

if (error < best_error)
{
best_error = error;
best_bit_rates = context.bit_rate_per_bone[bone_index];
is_error_good_enough = error < error_threshold;
}
}
}
else
{
const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations_no_scale);
for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index)
if (desired_bit_rates.rotation != k_invalid_bit_rate)
{
const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][0];
if (bone_bit_rates.rotation == 1)
{
if (rotation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.rotation == k_invalid_bit_rate)
{
if (rotation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
desired_bit_rates.rotation = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.rotation);
}

const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][1];
if (bone_bit_rates.translation == 1)
{
if (translation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.translation == k_invalid_bit_rate)
{
if (translation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
if (desired_bit_rates.translation != k_invalid_bit_rate)
{
desired_bit_rates.translation = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.translation);
}

const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate);
const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate);
const uint32_t transform_size = rotation_size + translation_size;
if (desired_bit_rates.scale != k_invalid_bit_rate)
{
desired_bit_rates.scale = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.scale);
}

if (transform_size != prev_transform_size && is_error_good_enough)
{
// We already found the lowest transform size and we tried every permutation with that same size
break;
}
// If our inputs aren't normalized per segment, we can't store them on 0 bits because we'll have no
// segment range information. This occurs when we have a single segment. Skip those permutations.
if (bone_bit_rates.rotation == k_lowest_bit_rate && desired_bit_rates.rotation == 0)
continue;
else if (bone_bit_rates.translation == k_lowest_bit_rate && desired_bit_rates.translation == 0)
continue;
else if (bone_bit_rates.scale == k_lowest_bit_rate && desired_bit_rates.scale == 0)
continue;

// If we already found a permutation that is good enough, we test all the others
// that have the same size. Once the size changes, we stop.
if (is_error_good_enough && transform_size != prev_transform_size)
break;

prev_transform_size = transform_size;
prev_transform_size = transform_size;

context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index] = desired_bit_rates;

const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);
const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);

#if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, k_invalid_bit_rate, transform_size, error);
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, desired_bit_rates.rotation, desired_bit_rates.translation, desired_bit_rates.scale, transform_size, error);
#endif

if (error < best_error)
{
best_error = error;
best_bit_rates = context.bit_rate_per_bone[bone_index];
is_error_good_enough = error < error_threshold;
}
if (error < best_error)
{
best_error = error;
best_bit_rates = desired_bit_rates;
is_error_good_enough = error < error_threshold;
}
}

Expand Down
36 changes: 34 additions & 2 deletions includes/acl/compression/impl/transform_bit_rate_permutations.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,38 @@ namespace acl

namespace acl_impl
{
constexpr uint8_t k_local_bit_rate_permutations_no_scale[625][2] =
// Bit rate permutations to try when a transform has a single degree of freedom.
// Each row holds one bit rate and the rows are ordered by increasing transform size.
// The trailing comment on a row is that size in bits, which the generator script
// computes as the bit rate's bit count multiplied by 3. Note the jump on the last
// row: bit rate 24 costs 96 bits where bit rate 23 costs 69.
// This table matches the output format of tools/calc_local_bit_rates.py; regenerate
// it with that script rather than editing rows by hand.
// Buffer size in bytes: 25
constexpr uint8_t k_local_bit_rate_permutations_1_dof[25][1] =
{
{ 0 }, // 0 bits per transform
{ 1 }, // 3 bits per transform
{ 2 }, // 6 bits per transform
{ 3 }, // 9 bits per transform
{ 4 }, // 12 bits per transform
{ 5 }, // 15 bits per transform
{ 6 }, // 18 bits per transform
{ 7 }, // 21 bits per transform
{ 8 }, // 24 bits per transform
{ 9 }, // 27 bits per transform
{ 10 }, // 30 bits per transform
{ 11 }, // 33 bits per transform
{ 12 }, // 36 bits per transform
{ 13 }, // 39 bits per transform
{ 14 }, // 42 bits per transform
{ 15 }, // 45 bits per transform
{ 16 }, // 48 bits per transform
{ 17 }, // 51 bits per transform
{ 18 }, // 54 bits per transform
{ 19 }, // 57 bits per transform
{ 20 }, // 60 bits per transform
{ 21 }, // 63 bits per transform
{ 22 }, // 66 bits per transform
{ 23 }, // 69 bits per transform
{ 24 }, // 96 bits per transform
};

// Buffer size in bytes: 1250
constexpr uint8_t k_local_bit_rate_permutations_2_dof[625][2] =
{
{ 0, 0 }, // 0 bits per transform
{ 0, 1 }, // 3 bits per transform
Expand Down Expand Up @@ -668,7 +699,8 @@ namespace acl
{ 24, 24 }, // 192 bits per transform
};

constexpr uint8_t k_local_bit_rate_permutations[15625][3] =
// Buffer size in bytes: 46875
constexpr uint8_t k_local_bit_rate_permutations_3_dof[15625][3] =
{
{ 0, 0, 0 }, // 0 bits per transform
{ 0, 0, 1 }, // 3 bits per transform
Expand Down
49 changes: 32 additions & 17 deletions tools/calc_local_bit_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,45 @@
print('Python 3.4 or higher needed to run this script')
sys.exit(1)

permutation_tries = []
permutation_tries_no_scale = []
permutation_dof_1 = []
permutation_dof_2 = []
permutation_dof_3 = []

for rotation_bit_rate in range(k_num_bit_rates):
for translation_bit_rate in range(k_num_bit_rates):
transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3
permutation_tries_no_scale.append((transform_size, rotation_bit_rate, translation_bit_rate))
for dof_1 in range(k_num_bit_rates):
dof_1_size = k_bit_rate_num_bits[dof_1] * 3;
permutation_dof_1.append((dof_1_size, dof_1))

for scale_bit_rate in range(k_num_bit_rates):
transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3 + k_bit_rate_num_bits[scale_bit_rate] * 3
permutation_tries.append((transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate))
for dof_2 in range(k_num_bit_rates):
dof_2_size = dof_1_size + k_bit_rate_num_bits[dof_2] * 3
permutation_dof_2.append((dof_2_size, dof_1, dof_2))

for dof_3 in range(k_num_bit_rates):
dof_3_size = dof_2_size + k_bit_rate_num_bits[dof_3] * 3
permutation_dof_3.append((dof_3_size, dof_1, dof_2, dof_3))

# Sort by transform size, then by each bit rate
permutation_tries.sort()
permutation_tries_no_scale.sort()
permutation_dof_1.sort()
permutation_dof_2.sort()
permutation_dof_3.sort()

print('constexpr uint8_t k_local_bit_rate_permutations_no_scale[{}][2] ='.format(len(permutation_tries_no_scale)))
print('// Buffer size in bytes: {}'.format(len(permutation_dof_1) * 1));
print('constexpr uint8_t k_local_bit_rate_permutations_1_dof[{}][1] ='.format(len(permutation_dof_1)))
print('{')
for transform_size, rotation_bit_rate, translation_bit_rate in permutation_tries_no_scale:
print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, transform_size))
for transform_size, dof_1 in permutation_dof_1:
print('\t{{ {} }},\t\t// {} bits per transform'.format(dof_1, transform_size))
print('};')
print()
print('constexpr uint8_t k_local_bit_rate_permutations[{}][3] ='.format(len(permutation_tries)))
print('// Buffer size in bytes: {}'.format(len(permutation_dof_2) * 2));
print('constexpr uint8_t k_local_bit_rate_permutations_2_dof[{}][2] ='.format(len(permutation_dof_2)))
print('{')
for transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate in permutation_tries:
print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size))
for transform_size, dof_1, dof_2 in permutation_dof_2:
print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, transform_size))
print('};')
print()
print('// Buffer size in bytes: {}'.format(len(permutation_dof_3) * 3));
print('constexpr uint8_t k_local_bit_rate_permutations_3_dof[{}][3] ='.format(len(permutation_dof_3)))
print('{')
for transform_size, dof_1, dof_2, dof_3 in permutation_dof_3:
print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, dof_3, transform_size))
print('};')
print()

0 comments on commit 7d4a648

Please sign in to comment.