Skip to content

Commit

Permalink
Add per-message ZSTD compression (ros2#418)
Browse files Browse the repository at this point in the history
* Add per-message ZSTD compression

This implements the per-messages compression and decompression
functions for the ZSTD compressor and also adds unit tests
for them.

Distro A, OPSEC #2893

Signed-off-by: P. J. Reed <[email protected]>
  • Loading branch information
pjreed authored Jul 13, 2020
1 parent c06cb47 commit 9856f5b
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 15 deletions.
4 changes: 2 additions & 2 deletions ros2bag/ros2bag/verb/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def add_arguments(self, parser, cli_name): # noqa: D102
)
parser.add_argument(
'--compression-mode', type=str, default='none',
choices=['none', 'file'],
help='Determine whether to compress bag files. Default is "none".'
choices=['none', 'file', 'message'],
help="Determine whether to compress by file or message. Default is 'none'."
)
parser.add_argument(
'--compression-format', type=str, default='', choices=['zstd'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

namespace rosbag2_compression
{

SequentialCompressionReader::SequentialCompressionReader(
std::unique_ptr<rosbag2_compression::CompressionFactory> compression_factory,
std::unique_ptr<rosbag2_storage::StorageFactoryInterface> storage_factory,
Expand All @@ -47,9 +46,12 @@ void SequentialCompressionReader::setup_decompression()
compression_mode_ = rosbag2_compression::compression_mode_from_string(metadata_.compression_mode);
if (compression_mode_ != rosbag2_compression::CompressionMode::NONE) {
decompressor_ = compression_factory_->create_decompressor(metadata_.compression_format);
// Decompress the first file so that it is readable.
ROSBAG2_COMPRESSION_LOG_DEBUG_STREAM("Decompressing " << get_current_file().c_str());
*current_file_iterator_ = decompressor_->decompress_uri(get_current_file());
// Decompress the first file so that it is readable; don't need to do anything for
// per-message encryption.
if (compression_mode_ == CompressionMode::FILE) {
ROSBAG2_COMPRESSION_LOG_DEBUG_STREAM("Decompressing " << get_current_file().c_str());
*current_file_iterator_ = decompressor_->decompress_uri(get_current_file());
}
} else {
throw std::invalid_argument{
"SequentialCompressionReader requires a CompressionMode that is not NONE!"};
Expand Down
62 changes: 60 additions & 2 deletions rosbag2_compression/src/rosbag2_compression/zstd_compressor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <memory>
Expand Down Expand Up @@ -149,6 +150,35 @@ void write_output_buffer(
fclose(file_pointer);
}

/**
* Checks rcutils array resizing and throws a runtime_error if there was an error resizing.
* \param rcutils_ret_t Result of calling rcutils
*/
void throw_on_rcutils_resize_error(const rcutils_ret_t resize_result)
{
if (resize_result == RCUTILS_RET_OK) {
return;
}

std::stringstream error;
error << "rcutils_uint8_array_resize error: ";
switch (resize_result) {
case RCUTILS_RET_INVALID_ARGUMENT:
error << "Invalid Argument";
break;
case RCUTILS_RET_BAD_ALLOC:
error << "Bad Alloc";
break;
case RCUTILS_RET_ERROR:
error << "Ret Error";
break;
default:
error << "Unexpected Result";
break;
}
throw std::runtime_error(error.str());
}

/**
* Checks compression_result and throws a runtime_error if there was a ZSTD error.
* \param compression_result is the return value of ZSTD_compress.
Expand Down Expand Up @@ -218,9 +248,37 @@ std::string ZstdCompressor::compress_uri(const std::string & uri)
}

void ZstdCompressor::compress_serialized_bag_message(
rosbag2_storage::SerializedBagMessage *)
rosbag2_storage::SerializedBagMessage * message)
{
throw std::logic_error{"Not implemented"};
const auto start = std::chrono::high_resolution_clock::now();
// Allocate based on compression bound and compress
const auto uncompressed_buffer_length =
ZSTD_compressBound(message->serialized_data->buffer_length);
std::vector<uint8_t> compressed_buffer(uncompressed_buffer_length);

// Perform compression and check.
// compression_result is either the actual compressed size or an error code.
const auto compression_result = ZSTD_compress(
compressed_buffer.data(), compressed_buffer.size(),
message->serialized_data->buffer, message->serialized_data->buffer_length,
kDefaultZstdCompressionLevel);
throw_on_zstd_error(compression_result);

// Compression_buffer_length might be larger than the actual compression size
// Resize compressed_buffer so its size is the actual compression size.
compressed_buffer.resize(compression_result);

const auto resize_result =
rcutils_uint8_array_resize(message->serialized_data.get(), compression_result);
throw_on_rcutils_resize_error(resize_result);

// Note that rcutils_uint8_array_resize changes buffer_capacity but not buffer_length, we
// have to do that manually.
message->serialized_data->buffer_length = compression_result;
std::copy(compressed_buffer.begin(), compressed_buffer.end(), message->serialized_data->buffer);

const auto end = std::chrono::high_resolution_clock::now();
print_compression_statistics(start, end, uncompressed_buffer_length, compression_result);
}

std::string ZstdCompressor::get_compression_identifier() const
Expand Down
63 changes: 61 additions & 2 deletions rosbag2_compression/src/rosbag2_compression/zstd_decompressor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <sstream>
Expand Down Expand Up @@ -175,6 +176,35 @@ void throw_on_zstd_error(const ZstdDecompressReturnType compression_result)
}
}

/**
* Checks rcutils array resizing and throws a runtime_error if there was an error resizing.
* \param rcutils_ret_t Result of calling rcutils
*/
void throw_on_rcutils_resize_error(const rcutils_ret_t resize_result)
{
if (resize_result == RCUTILS_RET_OK) {
return;
}

std::stringstream error;
error << "rcutils_uint8_array_resize error: ";
switch (resize_result) {
case RCUTILS_RET_INVALID_ARGUMENT:
error << "Invalid Argument";
break;
case RCUTILS_RET_BAD_ALLOC:
error << "Bad Alloc";
break;
case RCUTILS_RET_ERROR:
error << "Ret Error";
break;
default:
error << "Unexpected Result";
break;
}
throw std::runtime_error(error.str());
}

/**
* Checks frame_content and throws a runtime_error if there was a ZSTD error
* or frame_content is invalid.
Expand Down Expand Up @@ -257,9 +287,38 @@ std::string ZstdDecompressor::decompress_uri(const std::string & uri)
}

void ZstdDecompressor::decompress_serialized_bag_message(
rosbag2_storage::SerializedBagMessage *)
rosbag2_storage::SerializedBagMessage * message)
{
throw std::logic_error{"Not implemented"};
const auto start = std::chrono::high_resolution_clock::now();
const auto compressed_buffer_length = message->serialized_data->buffer_length;

const auto decompressed_buffer_length =
ZSTD_getFrameContentSize(message->serialized_data->buffer, compressed_buffer_length);

throw_on_invalid_frame_content(decompressed_buffer_length);

// Initializes decompressed_buffer with size = decompressed_buffer_length.
// Uniform initialization cannot be used here since it will choose
// the initializer list constructor instead.
std::vector<uint8_t> decompressed_buffer(decompressed_buffer_length);

const auto decompression_result = ZSTD_decompress(
decompressed_buffer.data(), decompressed_buffer_length,
message->serialized_data->buffer, compressed_buffer_length);

throw_on_zstd_error(decompression_result);

const auto resize_result =
rcutils_uint8_array_resize(message->serialized_data.get(), decompression_result);
throw_on_rcutils_resize_error(resize_result);

message->serialized_data->buffer_length = decompression_result;
std::copy(
decompressed_buffer.begin(), decompressed_buffer.end(),
message->serialized_data->buffer);

const auto end = std::chrono::high_resolution_clock::now();
print_decompression_statistics(start, end, decompression_result, compressed_buffer_length);
}

std::string ZstdDecompressor::get_decompression_identifier() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdio>
#include <fstream>
#include <memory>
#include <string>
#include <vector>

Expand All @@ -24,6 +26,8 @@
#include "rosbag2_compression/zstd_compressor.hpp"
#include "rosbag2_compression/zstd_decompressor.hpp"

#include "rosbag2_storage/ros_helper.hpp"

#include "rosbag2_test_common/temporary_directory_fixture.hpp"

#include "gmock/gmock.h"
Expand All @@ -32,6 +36,23 @@ namespace
{
constexpr const char kGarbageStatement[] = "garbage";
constexpr const int kDefaultGarbageFileSize = 10; // MiB
constexpr const size_t kExpectedCompressedDataSize = 976; // manually calculated, could change
// if compression params change

/**
* Writes 1M * size garbage data to a stream.
* \param out The stream to write to.
* \param size The number of times to write.
*/
void write_garbage_stream(std::ostream & out, int size = kDefaultGarbageFileSize)
{
const auto output_size = size * 1024 * 1024;
const auto num_iterations = output_size / static_cast<int>(strlen(kGarbageStatement));

for (int i = 0; i < num_iterations; i++) {
out << kGarbageStatement;
}
}

/**
* Creates a text file of a certain size.
Expand All @@ -43,12 +64,19 @@ void create_garbage_file(const std::string & uri, int size = kDefaultGarbageFile
auto out = std::ofstream{uri};
out.exceptions(std::ifstream::failbit | std::ifstream::badbit);

const auto file_size = size * 1024 * 1024;
const auto num_iterations = file_size / static_cast<int>(strlen(kGarbageStatement));
write_garbage_stream(out, size);
}

for (int i = 0; i < num_iterations; i++) {
out << kGarbageStatement;
}
/**
* Creates a string of a certain size.
* \param size Size of the string in MiB.
* \return The string.
*/
std::string create_garbage_string(int size = kDefaultGarbageFileSize)
{
std::stringstream output;
write_garbage_stream(output, size);
return output.str();
}

std::vector<char> read_file(const std::string & uri)
Expand Down Expand Up @@ -76,13 +104,32 @@ class CompressionHelperFixture : public rosbag2_test_common::TemporaryDirectoryF

void SetUp() override
{
allocator_ = rcutils_get_default_allocator();
message_ = create_garbage_string();

rclcpp::init(0, nullptr);
}

void TearDown() override
{
rclcpp::shutdown();
}

std::string deserialize_message(std::shared_ptr<rcutils_uint8_array_t> serialized_message)
{
std::unique_ptr<uint8_t[]> copied(new uint8_t[serialized_message->buffer_length + 1]);
std::copy(
serialized_message->buffer,
serialized_message->buffer + serialized_message->buffer_length,
copied.get());
copied.get()[serialized_message->buffer_length] = '\0';
std::string message_content(reinterpret_cast<char *>(copied.get()));
return message_content;
}

rcutils_allocator_t allocator_;
std::string message_;
size_t compressed_length_{kExpectedCompressedDataSize};
};

TEST_F(CompressionHelperFixture, zstd_compress_file_uri)
Expand Down Expand Up @@ -218,3 +265,35 @@ TEST_F(CompressionHelperFixture, zstd_decompress_fails_on_bad_uri)
EXPECT_THROW(decompressor.decompress_uri(bad_uri), std::runtime_error) <<
"Expected decompress_uri(\"" << bad_uri << "\") to fail!";
}

TEST_F(CompressionHelperFixture, zstd_compress_serialized_bag_message)
{
auto msg = std::make_unique<rosbag2_storage::SerializedBagMessage>();
msg->serialized_data.reset(new rcutils_uint8_array_t);
msg->serialized_data = rosbag2_storage::make_serialized_message(
message_.data(), message_.length());

rosbag2_compression::ZstdCompressor compressor;
compressor.compress_serialized_bag_message(msg.get());

ASSERT_EQ(compressed_length_, msg->serialized_data->buffer_length);
}

TEST_F(CompressionHelperFixture, zstd_decompress_serialized_bag_message)
{
auto msg = std::make_unique<rosbag2_storage::SerializedBagMessage>();
msg->serialized_data.reset(new rcutils_uint8_array_t);
msg->serialized_data = rosbag2_storage::make_serialized_message(
message_.data(), message_.length());

rosbag2_compression::ZstdCompressor compressor;
compressor.compress_serialized_bag_message(msg.get());

const auto compressed_length = msg->serialized_data->buffer_length;
EXPECT_EQ(compressed_length_, compressed_length);

rosbag2_compression::ZstdDecompressor decompressor;
EXPECT_NO_THROW(decompressor.decompress_serialized_bag_message(msg.get()));
std::string new_msg = deserialize_message(msg->serialized_data);
EXPECT_EQ(new_msg, message_);
}

0 comments on commit 9856f5b

Please sign in to comment.