[#22301] docdb: Improve handling of large responses
Summary:
The presence of large rows can result in very large RPC responses to read requests:
 - By default, we return `yb_fetch_row_limit` rows regardless of row size.
 - When `yb_fetch_size_limit` is set, rows are added until the response exceeds the limit, so the
   response can overshoot `yb_fetch_size_limit` by up to one row.
 - We always return at least one full row, so regardless of what `yb_fetch_size_limit` and
   `yb_fetch_row_limit` are set to, it is possible to have very large RPC responses.

When the RPC response is sufficiently large, we neither fail gracefully nor return an error to the
user:
 - When an RPC response exceeds `rpc_max_message_size`, we attempt to return an error, but the
   error response still carries the oversized payload, so the error response itself fails. The
   result is a DFATAL that never responds to the RPC and leaves clients hanging.
 - When an RPC response exceeds 2^32 bytes, we instead trigger a FATAL from `narrow_cast`, due to
   assumptions that RPC responses do not exceed 2^32 bytes.
 - We also consume (unbounded) large amounts of memory to generate and process this response,
   which may trigger FATAL from checked mallocs failing.
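
As a rough illustration of the `narrow_cast` failure mode (a standalone sketch, not the YugabyteDB implementation; `FitsInUint32` and `CappedToUint32` are hypothetical helpers): a checked narrowing cast can only succeed when the value survives a round trip through the narrower type, so capping the value first is one way to make the cast safe.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: true when `value` survives a round trip through uint32_t.
// A checked cast in the style of narrow_cast would abort where this returns false.
constexpr bool FitsInUint32(std::uint64_t value) {
  return static_cast<std::uint64_t>(static_cast<std::uint32_t>(value)) == value;
}

// Hypothetical helper: cap the value before narrowing so the cast cannot lose bits.
constexpr std::uint32_t CappedToUint32(std::uint64_t value) {
  constexpr std::uint64_t kMax = std::numeric_limits<std::uint32_t>::max();
  return static_cast<std::uint32_t>(value < kMax ? value : kMax);
}
```

Under these assumptions, `FitsInUint32(1ULL << 32)` is false, which is the condition under which a checked cast would abort, while `CappedToUint32` saturates at the `uint32_t` maximum instead.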

This diff makes the following changes:
 - Change `narrow_cast`s that may FATAL to either not use `narrow_cast`, or to cap the value
   appropriately before performing `narrow_cast`.
 - Detect oversized responses and return an error earlier, avoiding some unnecessary allocations.
 - Do not attempt to send the sidecars with an error response, to avoid the error response itself
   failing due to being too large.
 - Impose a maximum of `rpc_max_message_size` on `yb_fetch_size_limit` (importantly, also in the
   case where it is set to its default value of `0` for unlimited).
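
The last item can be sketched as a small pure function. This is a simplified stand-in for the clamping in `PgsqlReadOperation::ExecuteScalar`; the name `EffectiveResponseSizeLimit` is hypothetical, and the flag value is passed in as a plain argument rather than read from `FLAGS_rpc_max_message_size`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the clamping described above.
// A request limit of 0 means "unlimited", which now resolves to rpc_max_message_size.
constexpr std::uint64_t EffectiveResponseSizeLimit(
    std::uint64_t rpc_max_message_size, std::uint64_t request_size_limit) {
  if (request_size_limit == 0) {
    return rpc_max_message_size;
  }
  return std::min(rpc_max_message_size, request_size_limit);
}
```

Note that this limit only bounds the row data. Because rows are added until the limit is exceeded, the final response can still be larger than the value returned here.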

This diff also changes protobuf_message_total_bytes_limit from int32 to uint32 (this is safe
because negative values made no sense and would have prevented any RPCs from being
sent) and adds gflag validators to enforce the following relationship:
  rpc_max_message_size < protobuf_message_total_bytes_limit < 512 MB
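
The relationship can be expressed as a single predicate. This is only a sketch of the invariant, not the validator code in the diff, which checks each flag separately; `FlagsAreConsistent` and `kProtobufHardCap` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical constant: the 512 MB bound that protobuf_message_total_bytes_limit must stay under.
constexpr std::uint64_t kProtobufHardCap = 512ULL * 1024 * 1024;

// Hypothetical helper: true when both strict inequalities hold.
constexpr bool FlagsAreConsistent(
    std::uint64_t rpc_max_message_size, std::uint64_t protobuf_message_total_bytes_limit) {
  return rpc_max_message_size < protobuf_message_total_bytes_limit &&
         protobuf_message_total_bytes_limit < kProtobufHardCap;
}
```

With the defaults shown in this diff (255 MB and 511 MB) the predicate holds; setting both flags to the same value, or raising the protobuf limit to 512 MB, would violate it.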

**Upgrade/Rollback safety:**
This diff only touches test-only protos.

Jira: DB-11216

Test Plan:
Jenkins.

Added test cases:
- `./yb_build.sh --gtest_filter 'PgMiniTest.ReadHugeRow'` to test that the "RPC too large" error is propagated back to the client.
- `./yb_build.sh --gtest_filter TestRpc.MaxSizeResponse --cxx-test rpc_rpc-test` to test a maximum-sized RPC response; run before and after the changes.

No tests were added for the narrow_cast case due to memory limitations running unit tests on Jenkins, but a modified version of the above test that runs into the narrow_cast case was run locally to confirm its fix.
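
The sizing arithmetic behind `TestRpc.MaxSizeResponse` can be sketched as follows. This is a simplified standalone version: `VarintSize` and `MaxDataLength` are hypothetical names, and the 4-byte length prefix is an assumption standing in for `kMsgLengthPrefixLength`. Like the test, it assumes the data field and the whole message need the same varint width as `rpc_max_message_size`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Number of bytes a protobuf varint needs to encode `value` (7 payload bits per byte).
constexpr std::size_t VarintSize(std::uint64_t value) {
  std::size_t size = 1;
  while (value >= 0x80) {
    value >>= 7;
    ++size;
  }
  return size;
}

// Assumed stand-in for kMsgLengthPrefixLength (fixed-width total-length prefix).
constexpr std::size_t kAssumedLengthPrefix = 4;

// Length prefix + varint delimiter for the header PB + the header PB itself.
constexpr std::size_t HeaderTotalLength(std::size_t header_pb_len) {
  return kAssumedLengthPrefix + VarintSize(header_pb_len) + header_pb_len;
}

// Largest data payload for which the whole response is exactly rpc_max_size bytes.
constexpr std::size_t MaxDataLength(std::size_t rpc_max_size, std::size_t header_pb_len) {
  const std::size_t msg_len_without_data = 1       // Field tag.
      + VarintSize(rpc_max_size)                   // Length of data field.
      + VarintSize(rpc_max_size);                  // Length of entire message.
  return rpc_max_size - HeaderTotalLength(header_pb_len) - msg_len_without_data;
}
```

A payload of `MaxDataLength(...)` bytes is expected to succeed and one more byte is expected to fail, which is the boundary the test exercises.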

Reviewers: qhu, sergei

Reviewed By: sergei

Subscribers: rthallam, yyan, yql, ybase

Differential Revision: https://phorge.dev.yugabyte.com/D37548
es1024 committed Sep 6, 2024
1 parent 6cc5f6a commit a28b3ec
Showing 15 changed files with 200 additions and 13 deletions.
13 changes: 10 additions & 3 deletions src/yb/docdb/pgsql_operation.cc
@@ -132,6 +132,9 @@ DEFINE_RUNTIME_AUTO_bool(ysql_skip_row_lock_for_update, kExternal, true, false,
"take finer column-level locks instead of locking the whole row. This may cause issues with "
"data integrity for operations with implicit dependencies between columns.");


DECLARE_uint64(rpc_max_message_size);

namespace yb::docdb {

using dockv::DocKey;
@@ -2067,11 +2070,15 @@ Result<size_t> PgsqlReadOperation::ExecuteScalar(
row_count_limit = request_.limit();
}

-// We also limit the response's size.
-auto response_size_limit = std::numeric_limits<std::size_t>::max();
+// We also limit the response's size. Responses that exceed rpc_max_message_size will error
+// anyways, so we use that as an upper bound for the limit. This limit only applies on the data
+// in the response, and excludes headers, etc., but since we add rows until we *exceed*
+// the limit, this already won't avoid hitting rpc max size and is just an effort to limit the
+// damage.
+auto response_size_limit = GetAtomicFlag(&FLAGS_rpc_max_message_size);

if (request_.has_size_limit() && request_.size_limit() > 0) {
-response_size_limit = request_.size_limit();
+response_size_limit = std::min(response_size_limit, request_.size_limit());
}

VLOG(4) << "Row count limit: " << row_count_limit << ", size limit: " << response_size_limit;
22 changes: 21 additions & 1 deletion src/yb/rpc/lightweight_message.cc
@@ -31,7 +31,27 @@ DEFINE_UNKNOWN_uint64(rpc_max_message_size, 255_MB,
"The maximum size of a message of any RPC that the server will accept. The sum of "
"consensus_max_batch_size_bytes and 1KB should be less than rpc_max_message_size");

-DECLARE_int32(protobuf_message_total_bytes_limit);
+DECLARE_uint32(protobuf_message_total_bytes_limit);

namespace {

bool RpcMaxMessageSizeValidator(const char* flag_name, uint64_t value) {
// This validation depends on the value of: protobuf_message_total_bytes_limit.
DELAY_FLAG_VALIDATION_ON_STARTUP(flag_name);

if (value >= FLAGS_protobuf_message_total_bytes_limit) {
LOG_FLAG_VALIDATION_ERROR(flag_name, value)
<< "Must be less than protobuf_message_total_bytes_limit "
<< FLAGS_protobuf_message_total_bytes_limit;
return false;
}

return true;
}

} // namespace

DEFINE_validator(rpc_max_message_size, &RpcMaxMessageSizeValidator);

using google::protobuf::internal::WireFormatLite;
using google::protobuf::io::CodedOutputStream;
14 changes: 9 additions & 5 deletions src/yb/rpc/outbound_call.cc
@@ -296,6 +296,14 @@ void OutboundCall::Serialize(ByteBlocks* output) {
buffer_consumption_ = ScopedTrackedConsumption();
}

size_t OutboundCall::HeaderTotalLength(size_t header_pb_len) {
return
kMsgLengthPrefixLength // Int prefix for the total length.
+ CodedOutputStream::VarintSize32(
narrow_cast<uint32_t>(header_pb_len)) // Varint delimiter for header PB.
+ header_pb_len; // Length for the header PB itself.
}

Status OutboundCall::SetRequestParam(
AnyMessageConstPtr req, std::unique_ptr<Sidecars> sidecars, const MemTrackerPtr& mem_tracker) {
auto req_size = req.SerializedSize();
@@ -320,11 +328,7 @@ Status OutboundCall::SetRequestParam(
encoded_sidecars_len = sidecar_offsets->size() * sizeof(uint32_t);
header_pb_len += 1 + Output::VarintSize64(encoded_sidecars_len) + encoded_sidecars_len;
}
-size_t header_size =
-kMsgLengthPrefixLength // Int prefix for the total length.
-+ CodedOutputStream::VarintSize32(
-narrow_cast<uint32_t>(header_pb_len)) // Varint delimiter for header PB.
-+ header_pb_len; // Length for the header PB itself.
+size_t header_size = HeaderTotalLength(header_pb_len);
size_t buffer_size = header_size + message_size;

buffer_ = RefCntBuffer(buffer_size);
2 changes: 2 additions & 0 deletions src/yb/rpc/outbound_call.h
@@ -364,6 +364,8 @@ class OutboundCall : public RpcCall {
expires_at_.store(expires_at, std::memory_order_release);
}

static size_t HeaderTotalLength(size_t header_pb_len);

// ----------------------------------------------------------------------------------------------
// Getters
// ----------------------------------------------------------------------------------------------
23 changes: 23 additions & 0 deletions src/yb/rpc/rpc-test-base.cc
@@ -48,6 +48,8 @@ using yb::rpc_test::AddRequestPB;
using yb::rpc_test::AddResponsePB;
using yb::rpc_test::EchoRequestPB;
using yb::rpc_test::EchoResponsePB;
using yb::rpc_test::RepeatedEchoRequestPB;
using yb::rpc_test::RepeatedEchoResponsePB;
using yb::rpc_test::ForwardRequestPB;
using yb::rpc_test::ForwardResponsePB;
using yb::rpc_test::PanicRequestPB;
@@ -213,6 +215,21 @@ void GenericCalculatorService::DoEcho(InboundCall* incoming) {
down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp));
}

void GenericCalculatorService::DoRepeatedEcho(InboundCall* incoming) {
Slice param(incoming->serialized_request());
RepeatedEchoRequestPB req;
if (!req.ParseFromArray(param.data(), narrow_cast<int>(param.size()))) {
incoming->RespondFailure(
ErrorStatusPB::ERROR_INVALID_REQUEST,
STATUS(InvalidArgument, "Couldn't parse pb", req.InitializationErrorString()));
return;
}

RepeatedEchoResponsePB resp;
resp.set_data(std::string(req.count(), static_cast<char>(req.character())));
down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp));
}

namespace {

class CalculatorService: public CalculatorServiceIf {
@@ -270,6 +287,12 @@ class CalculatorService: public CalculatorServiceIf {
context.RespondSuccess();
}

void RepeatedEcho(const RepeatedEchoRequestPB* req, RepeatedEchoResponsePB* resp,
RpcContext context) override {
resp->set_data(std::string(req->count(), static_cast<char>(req->character())));
context.RespondSuccess();
}

void WhoAmI(const WhoAmIRequestPB* req, WhoAmIResponsePB* resp, RpcContext context) override {
LOG(INFO) << "Remote address: " << context.remote_address();
resp->set_address(yb::ToString(context.remote_address()));
8 changes: 8 additions & 0 deletions src/yb/rpc/rpc-test-base.h
@@ -64,6 +64,7 @@ class CalculatorServiceMethods {
static const constexpr auto kAddMethodName = "Add";
static const constexpr auto kDisconnectMethodName = "Disconnect";
static const constexpr auto kEchoMethodName = "Echo";
static const constexpr auto kRepeatedEchoMethodName = "RepeatedEcho";
static const constexpr auto kSendStringsMethodName = "SendStrings";
static const constexpr auto kSleepMethodName = "Sleep";

@@ -85,6 +86,12 @@ class CalculatorServiceMethods {
return &method;
}

static RemoteMethod* RepeatedEchoMethod() {
static RemoteMethod method(
rpc_test::CalculatorServiceIf::static_service_name(), kRepeatedEchoMethodName);
return &method;
}

static RemoteMethod* SendStringsMethod() {
static RemoteMethod method(
rpc_test::CalculatorServiceIf::static_service_name(), kSendStringsMethodName);
@@ -124,6 +131,7 @@ class GenericCalculatorService : public ServiceIf {
void DoSendStrings(InboundCall* incoming);
void DoSleep(InboundCall *incoming);
void DoEcho(InboundCall *incoming);
void DoRepeatedEcho(InboundCall *incoming);
void AddMethodToMap(
const RpcServicePtr& service, RpcEndpointMap* map, const char* method_name, Method method);

49 changes: 49 additions & 0 deletions src/yb/rpc/rpc-test.cc
@@ -90,6 +90,7 @@ DECLARE_int64(memory_limit_hard_bytes);
DECLARE_string(vmodule);
DECLARE_uint64(rpc_connection_timeout_ms);
DECLARE_uint64(rpc_read_buffer_size);
DECLARE_uint64(rpc_max_message_size);

using namespace std::chrono_literals;
using std::string;
@@ -1171,6 +1172,54 @@ TEST_F(TestRpc, YB_DISABLE_TEST_ON_MACOS(CantAllocateReadBuffer)) {
RunPlainTest(&TestCantAllocateReadBuffer, SetupServerForTestCantAllocateReadBuffer());
}

namespace {

void TestMaxSizeRpcResponse(CalculatorServiceProxy* proxy) {
using google::protobuf::io::CodedOutputStream;

const size_t rpc_max_size = FLAGS_rpc_max_message_size;
const size_t rpc_max_size_varint_size =
CodedOutputStream::VarintSize32(narrow_cast<uint32_t>(rpc_max_size));

ResponseHeader resp_header;
resp_header.set_call_id(1);
resp_header.set_is_error(false);

const size_t header_pb_len = resp_header.ByteSize();
const size_t header_tot_len = OutboundCall::HeaderTotalLength(header_pb_len);

// We assume that length of data field/entire message is close enough to rpc_max_size for
// simplicity; this should hold for the values we use for rpc_max_size.
const size_t msg_len_without_data = 1 // Field tag.
+ rpc_max_size_varint_size // Length of data field.
+ rpc_max_size_varint_size; // Length of entire message.

const size_t data_length = rpc_max_size - header_tot_len - msg_len_without_data;

RpcController controller;
controller.set_timeout(5s * kTimeMultiplier);

rpc_test::RepeatedEchoRequestPB req;
req.set_character('0');
req.set_count(data_length);

rpc_test::RepeatedEchoResponsePB resp;
ASSERT_OK(proxy->RepeatedEcho(req, &resp, &controller));

controller.Reset();

req.set_character('0');
req.set_count(data_length + 1);
ASSERT_NOK(proxy->RepeatedEcho(req, &resp, &controller));
}

} // namespace

TEST_F(TestRpc, MaxSizeResponse) {
ANNOTATE_UNPROTECTED_WRITE(FLAGS_rpc_max_message_size) = 10_MB;
RunPlainTest(&TestMaxSizeRpcResponse);
}

class TestRpcSecure : public RpcTestBase {
public:
void SetUp() override {
10 changes: 10 additions & 0 deletions src/yb/rpc/rtest.proto
@@ -90,6 +90,15 @@ message EchoResponsePB {
required string data = 1;
}

message RepeatedEchoRequestPB {
required int32 character = 1;
required uint64 count = 2;
}

message RepeatedEchoResponsePB {
required string data = 1;
}

message WhoAmIRequestPB {
}

@@ -136,6 +145,7 @@ service CalculatorService {
rpc Add(AddRequestPB) returns(AddResponsePB);
rpc Sleep(SleepRequestPB) returns(SleepResponsePB);
rpc Echo(EchoRequestPB) returns(EchoResponsePB);
rpc RepeatedEcho(RepeatedEchoRequestPB) returns(RepeatedEchoResponsePB);
rpc WhoAmI(WhoAmIRequestPB) returns (WhoAmIResponsePB);
rpc TestArgumentsInDiffPackage(yb.rpc_test_diff_package.ReqDiffPackagePB)
returns(yb.rpc_test_diff_package.RespDiffPackagePB);
9 changes: 8 additions & 1 deletion src/yb/rpc/serialization.cc
@@ -61,7 +61,10 @@ namespace rpc {

size_t SerializedMessageSize(size_t body_size, size_t additional_size) {
auto full_size = body_size + additional_size;
-return body_size + CodedOutputStream::VarintSize32(narrow_cast<uint32_t>(full_size));
+// VarintSize64 used to avoid casting errors. There is a separate constraint later enforced where
+// size <= rpc_max_message_size <= protobuf_message_total_bytes_limit < 512 MB, so we never end up
+// in a case where we actually send RPCs with size that doesn't fit in 32 bits.
+return body_size + CodedOutputStream::VarintSize64(full_size);
}

Status SerializeMessage(
@@ -104,6 +107,10 @@ Status SerializeHeader(const MessageLite& header,
+ header_pb_len; // Length for the header PB itself.
size_t total_size = header_tot_len + param_len;

if (total_size > FLAGS_rpc_max_message_size) {
return STATUS_FORMAT(InvalidArgument, "Sending too long RPC message ($0 bytes)", total_size);
}

*header_buf = RefCntBuffer(header_tot_len + reserve_for_param);
if (header_size != nullptr) {
*header_size = header_tot_len;
4 changes: 4 additions & 0 deletions src/yb/rpc/yb_rpc.cc
@@ -419,6 +419,10 @@ void YBInboundCall::RespondSuccess(AnyMessageConstPtr response) {
void YBInboundCall::RespondFailure(ErrorStatusPB::RpcErrorCodePB error_code,
const Status& status) {
TRACE_EVENT0("rpc", "InboundCall::RespondFailure");

// Release memory early and prevent building an oversized error response.
sidecars_.Reset();

ErrorStatusPB err;
err.set_message(status.ToString());
err.set_code(error_code);
2 changes: 1 addition & 1 deletion src/yb/tablet/preparer.cc
@@ -57,7 +57,7 @@ DEFINE_test_flag(double, simulate_skip_process_batch, 0.0,
"Probability that the preparer will skip invoking ProcessAndClearLeaderSideBatch "
"after processing an item.");

-DECLARE_int32(protobuf_message_total_bytes_limit);
+DECLARE_uint32(protobuf_message_total_bytes_limit);
DECLARE_uint64(rpc_max_message_size);

using namespace std::literals;
2 changes: 1 addition & 1 deletion src/yb/tablet/tablet_peer-test.cc
@@ -83,7 +83,7 @@ DECLARE_uint64(initial_log_segment_size_bytes);
DECLARE_int32(log_min_seconds_to_retain);
DECLARE_uint64(log_segment_size_bytes);
DECLARE_uint64(max_group_replicate_batch_size);
-DECLARE_int32(protobuf_message_total_bytes_limit);
+DECLARE_uint32(protobuf_message_total_bytes_limit);
DECLARE_uint64(rpc_max_message_size);
DECLARE_int32(retryable_request_timeout_secs);

10 changes: 10 additions & 0 deletions src/yb/tserver/pg_client_session.cc
@@ -89,6 +89,8 @@ DECLARE_bool(ysql_serializable_isolation_for_ddl_txn);
DECLARE_bool(ysql_yb_enable_ddl_atomicity_infra);
DECLARE_bool(yb_enable_cdc_consistent_snapshot_streams);

DECLARE_uint64(rpc_max_message_size);

namespace yb::tserver {
namespace {

@@ -395,8 +397,16 @@ struct PerformData {
if (status.ok()) {
status = ProcessResponse(used_read_time_applier ? &used_read_time : nullptr);
}

size_t max_size = GetAtomicFlag(&FLAGS_rpc_max_message_size);
if (status.ok() && sidecars.size() > max_size) {
status = STATUS_FORMAT(
InvalidArgument, "Sending too long RPC message ($0 bytes of data)", sidecars.size());
}

if (!status.ok()) {
StatusToPB(status, resp.mutable_status());
sidecars.Reset();
used_read_time = {};
}
if (cache_setter) {
20 changes: 19 additions & 1 deletion src/yb/util/pb_util.cc
@@ -121,12 +121,30 @@ COMPILE_ASSERT((arraysize(kPBContainerMagic) - 1) == kPBContainerMagicLen,
// To permit parsing of very large PB messages, we must use parse through a CodedInputStream and
// bump the byte limit. The SetTotalBytesLimit() docs say that 512MB is the shortest theoretical
// message length that may produce integer overflow warnings, so that's what we'll use.
-DEFINE_UNKNOWN_int32(
+DEFINE_UNKNOWN_uint32(
protobuf_message_total_bytes_limit, 511_MB,
"Limits single protobuf message size for deserialization.");
TAG_FLAG(protobuf_message_total_bytes_limit, advanced);
TAG_FLAG(protobuf_message_total_bytes_limit, hidden);

namespace {

bool ProtobufMessageTotalBytesLimitValidator(const char* flag_name, uint32_t value) {
constexpr uint32_t kMaxProtobufMessageTotalBytesLimit = 512_MB;

if (value >= kMaxProtobufMessageTotalBytesLimit) {
LOG_FLAG_VALIDATION_ERROR(flag_name, value) << "Must be less than "
<< kMaxProtobufMessageTotalBytesLimit;
return false;
}

return true;
}

} // namespace

DEFINE_validator(protobuf_message_total_bytes_limit, ProtobufMessageTotalBytesLimitValidator);

namespace yb {
namespace pb_util {

25 changes: 25 additions & 0 deletions src/yb/yql/pgwrapper/pg_mini-test.cc
@@ -2208,6 +2208,31 @@ TEST_F(PgMiniTest, NoWaitForRPCOnTermination) {
ASSERT_LT(termination_duration, RegularBuildVsDebugVsSanitizers(3000, 5000, 5000));
}

TEST_F(PgMiniTest, ReadHugeRow) {
constexpr size_t kNumColumns = 2;
constexpr size_t kColumnSize = 254000000;

std::string create_query = "CREATE TABLE test(pk INT PRIMARY KEY";
for (size_t i = 0; i < kNumColumns; ++i) {
create_query += Format(", text$0 TEXT", i);
}
create_query += ")";

auto conn = ASSERT_RESULT(Connect());
ASSERT_OK(conn.Execute(create_query));
ASSERT_OK(conn.Execute("INSERT INTO test(pk) VALUES(0)"));

for (size_t i = 0; i < kNumColumns; ++i) {
ASSERT_OK(conn.ExecuteFormat(
"UPDATE test SET text$0 = repeat('0', $1) WHERE pk = 0",
i, kColumnSize));
}

const auto res = conn.Fetch("SELECT * FROM test LIMIT 1");
ASSERT_NOK(res);
ASSERT_STR_CONTAINS(res.status().ToString(), "Sending too long RPC message");
}

TEST_F_EX(
PgMiniTest, CacheRefreshWithDroppedEntries, PgMiniTestSingleNode) {
auto conn = ASSERT_RESULT(Connect());