diff --git a/src/yb/docdb/pgsql_operation.cc b/src/yb/docdb/pgsql_operation.cc index fd530f5fd379..e6c71b5757ae 100644 --- a/src/yb/docdb/pgsql_operation.cc +++ b/src/yb/docdb/pgsql_operation.cc @@ -122,6 +122,9 @@ DEFINE_RUNTIME_AUTO_bool(ysql_skip_row_lock_for_update, kExternal, true, false, "take finer column-level locks instead of locking the whole row. This may cause issues with " "data integrity for operations with implicit dependencies between columns."); + +DECLARE_uint64(rpc_max_message_size); + namespace yb::docdb { using dockv::DocKey; @@ -1709,11 +1712,15 @@ Result PgsqlReadOperation::ExecuteScalar( row_count_limit = request_.limit(); } - // We also limit the response's size. - auto response_size_limit = std::numeric_limits::max(); + // We also limit the response's size. Responses that exceed rpc_max_message_size will error + // anyways, so we use that as an upper bound for the limit. This limit only applies on the data + // in the response, and excludes headers, etc., but since we add rows until we *exceed* + // the limit, this already won't avoid hitting rpc max size and is just an effort to limit the + // damage. + auto response_size_limit = GetAtomicFlag(&FLAGS_rpc_max_message_size); if (request_.has_size_limit() && request_.size_limit() > 0) { - response_size_limit = request_.size_limit(); + response_size_limit = std::min(response_size_limit, request_.size_limit()); } VLOG(4) << "Row count limit: " << row_count_limit << ", size limit: " << response_size_limit; diff --git a/src/yb/rpc/lightweight_message.cc b/src/yb/rpc/lightweight_message.cc index 954347990f75..b306935aed73 100644 --- a/src/yb/rpc/lightweight_message.cc +++ b/src/yb/rpc/lightweight_message.cc @@ -31,7 +31,7 @@ DEFINE_UNKNOWN_uint64(rpc_max_message_size, 255_MB, "The maximum size of a message of any RPC that the server will accept. The sum of " "consensus_max_batch_size_bytes and 1KB should be less than rpc_max_message_size"); -DECLARE_int32(protobuf_message_total_bytes_limit); +DECLARE_uint32(protobuf_message_total_bytes_limit); using google::protobuf::internal::WireFormatLite; using google::protobuf::io::CodedOutputStream; diff --git a/src/yb/rpc/lwproto-test.cc b/src/yb/rpc/lwproto-test.cc index a814d920007d..f3f66c053f85 100644 --- a/src/yb/rpc/lwproto-test.cc +++ b/src/yb/rpc/lwproto-test.cc @@ -23,7 +23,7 @@ #include "yb/util/size_literals.h" #include "yb/util/test_macros.h" -DECLARE_int32(protobuf_message_total_bytes_limit); +DECLARE_uint32(protobuf_message_total_bytes_limit); DECLARE_uint64(rpc_max_message_size); namespace yb { diff --git a/src/yb/rpc/outbound_call.cc b/src/yb/rpc/outbound_call.cc index f5e35f5b639f..55870a9c7739 100644 --- a/src/yb/rpc/outbound_call.cc +++ b/src/yb/rpc/outbound_call.cc @@ -296,6 +296,14 @@ void OutboundCall::Serialize(ByteBlocks* output) { buffer_consumption_ = ScopedTrackedConsumption(); } +size_t OutboundCall::HeaderTotalLength(size_t header_pb_len) { + return + kMsgLengthPrefixLength // Int prefix for the total length. + + CodedOutputStream::VarintSize32( + narrow_cast(header_pb_len)) // Varint delimiter for header PB. + + header_pb_len; // Length for the header PB itself. +} + Status OutboundCall::SetRequestParam( AnyMessageConstPtr req, std::unique_ptr sidecars, const MemTrackerPtr& mem_tracker) { auto req_size = req.SerializedSize(); @@ -320,11 +328,7 @@ Status OutboundCall::SetRequestParam( encoded_sidecars_len = sidecar_offsets->size() * sizeof(uint32_t); header_pb_len += 1 + Output::VarintSize64(encoded_sidecars_len) + encoded_sidecars_len; } - size_t header_size = - kMsgLengthPrefixLength // Int prefix for the total length. - + CodedOutputStream::VarintSize32( - narrow_cast(header_pb_len)) // Varint delimiter for header PB. - + header_pb_len; // Length for the header PB itself. + size_t header_size = HeaderTotalLength(header_pb_len); size_t buffer_size = header_size + message_size; buffer_ = RefCntBuffer(buffer_size); diff --git a/src/yb/rpc/outbound_call.h b/src/yb/rpc/outbound_call.h index 47196885832c..3a8acfd5fe02 100644 --- a/src/yb/rpc/outbound_call.h +++ b/src/yb/rpc/outbound_call.h @@ -360,6 +360,8 @@ class OutboundCall : public RpcCall { expires_at_.store(expires_at, std::memory_order_release); } + static size_t HeaderTotalLength(size_t header_pb_len); + // ---------------------------------------------------------------------------------------------- // Getters // ---------------------------------------------------------------------------------------------- diff --git a/src/yb/rpc/rpc-test-base.cc b/src/yb/rpc/rpc-test-base.cc index 4c2e78facd7c..8a0a664379a7 100644 --- a/src/yb/rpc/rpc-test-base.cc +++ b/src/yb/rpc/rpc-test-base.cc @@ -48,6 +48,8 @@ using yb::rpc_test::AddRequestPB; using yb::rpc_test::AddResponsePB; using yb::rpc_test::EchoRequestPB; using yb::rpc_test::EchoResponsePB; +using yb::rpc_test::RepeatedEchoRequestPB; +using yb::rpc_test::RepeatedEchoResponsePB; using yb::rpc_test::ForwardRequestPB; using yb::rpc_test::ForwardResponsePB; using yb::rpc_test::PanicRequestPB; @@ -205,6 +207,21 @@ void GenericCalculatorService::DoEcho(InboundCall* incoming) { down_cast(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); } +void GenericCalculatorService::DoRepeatedEcho(InboundCall* incoming) { + Slice param(incoming->serialized_request()); + RepeatedEchoRequestPB req; + if (!req.ParseFromArray(param.data(), narrow_cast(param.size()))) { + incoming->RespondFailure( + ErrorStatusPB::ERROR_INVALID_REQUEST, + STATUS(InvalidArgument, "Couldn't parse pb", req.InitializationErrorString())); + return; + } + + RepeatedEchoResponsePB resp; + resp.set_data(std::string(req.count(), static_cast(req.character()))); + down_cast(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); +} + namespace { class CalculatorService: public CalculatorServiceIf { @@ -262,6 +279,12 @@ class CalculatorService: public CalculatorServiceIf { context.RespondSuccess(); } + void RepeatedEcho(const RepeatedEchoRequestPB* req, RepeatedEchoResponsePB* resp, + RpcContext context) override { + resp->set_data(std::string(req->count(), static_cast(req->character()))); + context.RespondSuccess(); + } + void WhoAmI(const WhoAmIRequestPB* req, WhoAmIResponsePB* resp, RpcContext context) override { LOG(INFO) << "Remote address: " << context.remote_address(); resp->set_address(yb::ToString(context.remote_address())); diff --git a/src/yb/rpc/rpc-test-base.h b/src/yb/rpc/rpc-test-base.h index ef91a7939509..dca2b455d196 100644 --- a/src/yb/rpc/rpc-test-base.h +++ b/src/yb/rpc/rpc-test-base.h @@ -64,6 +64,7 @@ class CalculatorServiceMethods { static const constexpr auto kAddMethodName = "Add"; static const constexpr auto kDisconnectMethodName = "Disconnect"; static const constexpr auto kEchoMethodName = "Echo"; + static const constexpr auto kRepeatedEchoMethodName = "RepeatedEcho"; static const constexpr auto kSendStringsMethodName = "SendStrings"; static const constexpr auto kSleepMethodName = "Sleep"; @@ -85,6 +86,12 @@ class CalculatorServiceMethods { return &method; } + static RemoteMethod* RepeatedEchoMethod() { + static RemoteMethod method( + rpc_test::CalculatorServiceIf::static_service_name(), kRepeatedEchoMethodName); + return &method; + } + static RemoteMethod* SendStringsMethod() { static RemoteMethod method( rpc_test::CalculatorServiceIf::static_service_name(), kSendStringsMethodName); @@ -124,6 +131,7 @@ class GenericCalculatorService : public ServiceIf { void DoSendStrings(InboundCall* incoming); void DoSleep(InboundCall *incoming); void DoEcho(InboundCall *incoming); + void DoRepeatedEcho(InboundCall *incoming); void AddMethodToMap( const RpcServicePtr& service, RpcEndpointMap* map, const char* method_name, Method method); diff --git a/src/yb/rpc/rpc-test.cc b/src/yb/rpc/rpc-test.cc index c35fd1d1d483..4b15e96c7960 100644 --- a/src/yb/rpc/rpc-test.cc +++ b/src/yb/rpc/rpc-test.cc @@ -90,6 +90,7 @@ DECLARE_int64(memory_limit_hard_bytes); DECLARE_string(vmodule); DECLARE_uint64(rpc_connection_timeout_ms); DECLARE_uint64(rpc_read_buffer_size); +DECLARE_uint64(rpc_max_message_size); using namespace std::chrono_literals; using std::string; @@ -1141,6 +1142,54 @@ TEST_F(TestRpc, CantAllocateReadBuffer) { RunPlainTest(&TestCantAllocateReadBuffer, SetupServerForTestCantAllocateReadBuffer()); } +namespace { + +void TestMaxSizeRpcResponse(CalculatorServiceProxy* proxy) { + using google::protobuf::io::CodedOutputStream; + + const size_t rpc_max_size = FLAGS_rpc_max_message_size; + const size_t rpc_max_size_varint_size = + CodedOutputStream::VarintSize32(narrow_cast(rpc_max_size)); + + ResponseHeader resp_header; + resp_header.set_call_id(1); + resp_header.set_is_error(false); + + const size_t header_pb_len = resp_header.ByteSize(); + const size_t header_tot_len = OutboundCall::HeaderTotalLength(header_pb_len); + + // We assume that length of data field/entire message is close enough to rpc_max_size for + // simplicity; this should hold for the values we use for rpc_max_size. + const size_t msg_len_without_data = 1 // Field tag. + + rpc_max_size_varint_size // Length of data field. + + rpc_max_size_varint_size; // Length of entire message. + + const size_t data_length = rpc_max_size - header_tot_len - msg_len_without_data; + + RpcController controller; + controller.set_timeout(5s * kTimeMultiplier); + + rpc_test::RepeatedEchoRequestPB req; + req.set_character('0'); + req.set_count(data_length); + + rpc_test::RepeatedEchoResponsePB resp; + ASSERT_OK(proxy->RepeatedEcho(req, &resp, &controller)); + + controller.Reset(); + + req.set_character('0'); + req.set_count(data_length + 1); + ASSERT_NOK(proxy->RepeatedEcho(req, &resp, &controller)); +} + +} // namespace + +TEST_F(TestRpc, MaxSizeResponse) { + ANNOTATE_UNPROTECTED_WRITE(FLAGS_rpc_max_message_size) = 10_MB; + RunPlainTest(&TestMaxSizeRpcResponse); +} + class TestRpcSecure : public RpcTestBase { public: void SetUp() override { diff --git a/src/yb/rpc/rtest.proto b/src/yb/rpc/rtest.proto index c327d93bb988..393943c1e02e 100644 --- a/src/yb/rpc/rtest.proto +++ b/src/yb/rpc/rtest.proto @@ -90,6 +90,15 @@ message EchoResponsePB { required string data = 1; } +message RepeatedEchoRequestPB { + required int32 character = 1; + required uint64 count = 2; +} + +message RepeatedEchoResponsePB { + required string data = 1; +} + message WhoAmIRequestPB { } @@ -136,6 +145,7 @@ service CalculatorService { rpc Add(AddRequestPB) returns(AddResponsePB); rpc Sleep(SleepRequestPB) returns(SleepResponsePB); rpc Echo(EchoRequestPB) returns(EchoResponsePB); + rpc RepeatedEcho(RepeatedEchoRequestPB) returns(RepeatedEchoResponsePB); rpc WhoAmI(WhoAmIRequestPB) returns (WhoAmIResponsePB); rpc TestArgumentsInDiffPackage(yb.rpc_test_diff_package.ReqDiffPackagePB) returns(yb.rpc_test_diff_package.RespDiffPackagePB); diff --git a/src/yb/rpc/serialization.cc b/src/yb/rpc/serialization.cc index 3703b3d82af2..8eb39102895c 100644 --- a/src/yb/rpc/serialization.cc +++ b/src/yb/rpc/serialization.cc @@ -61,7 +61,10 @@ namespace rpc { size_t SerializedMessageSize(size_t body_size, size_t additional_size) { auto full_size = body_size + additional_size; - return body_size + CodedOutputStream::VarintSize32(narrow_cast(full_size)); + // VarintSize64 used to avoid casting errors. There is a separate constraint later enforced where + // size <= rpc_max_message_size <= protobuf_message_total_bytes_limit < 512 MB, so we never end up + // in a case where we actually send RPCs with size that doesn't fit in 32 bits. + return body_size + CodedOutputStream::VarintSize64(full_size); } Status SerializeMessage( @@ -104,6 +107,10 @@ Status SerializeHeader(const MessageLite& header, + header_pb_len; // Length for the header PB itself. size_t total_size = header_tot_len + param_len; + if (total_size > FLAGS_rpc_max_message_size) { + return STATUS_FORMAT(InvalidArgument, "Sending too long RPC message ($0 bytes)", total_size); + } + *header_buf = RefCntBuffer(header_tot_len + reserve_for_param); if (header_size != nullptr) { *header_size = header_tot_len; diff --git a/src/yb/rpc/yb_rpc.cc b/src/yb/rpc/yb_rpc.cc index b3f64967ed61..f1832c8305bb 100644 --- a/src/yb/rpc/yb_rpc.cc +++ b/src/yb/rpc/yb_rpc.cc @@ -395,6 +395,10 @@ void YBInboundCall::RespondSuccess(AnyMessageConstPtr response) { void YBInboundCall::RespondFailure(ErrorStatusPB::RpcErrorCodePB error_code, const Status& status) { TRACE_EVENT0("rpc", "InboundCall::RespondFailure"); + + // Release memory early and prevent building an oversized error response. + sidecars_.Reset(); + ErrorStatusPB err; err.set_message(status.ToString()); err.set_code(error_code); diff --git a/src/yb/tablet/preparer.cc b/src/yb/tablet/preparer.cc index fff6d5ed0a8d..5fd930688fad 100644 --- a/src/yb/tablet/preparer.cc +++ b/src/yb/tablet/preparer.cc @@ -56,7 +56,7 @@ DEFINE_test_flag(double, simulate_skip_process_batch, 0.0, "Probability that the preparer will skip invoking ProcessAndClearLeaderSideBatch " "after processing an item."); -DECLARE_int32(protobuf_message_total_bytes_limit); +DECLARE_uint32(protobuf_message_total_bytes_limit); DECLARE_uint64(rpc_max_message_size); using namespace std::literals; diff --git a/src/yb/tablet/tablet_peer-test.cc b/src/yb/tablet/tablet_peer-test.cc index b64bbf4aff0e..faaf1be03fd9 100644 --- a/src/yb/tablet/tablet_peer-test.cc +++ b/src/yb/tablet/tablet_peer-test.cc @@ -84,7 +84,7 @@ DECLARE_uint64(initial_log_segment_size_bytes); DECLARE_int32(log_min_seconds_to_retain); DECLARE_uint64(log_segment_size_bytes); DECLARE_uint64(max_group_replicate_batch_size); -DECLARE_int32(protobuf_message_total_bytes_limit); +DECLARE_uint32(protobuf_message_total_bytes_limit); DECLARE_uint64(rpc_max_message_size); DECLARE_bool(enable_flush_retryable_requests); diff --git a/src/yb/tserver/pg_client_session.cc b/src/yb/tserver/pg_client_session.cc index 13e873beba6d..3a5ed7e040cd 100644 --- a/src/yb/tserver/pg_client_session.cc +++ b/src/yb/tserver/pg_client_session.cc @@ -82,6 +82,8 @@ DEFINE_RUNTIME_string(ysql_sequence_cache_method, "connection", DECLARE_bool(ysql_serializable_isolation_for_ddl_txn); DECLARE_bool(ysql_ddl_rollback_enabled); +DECLARE_uint64(rpc_max_message_size); + namespace yb { namespace tserver { @@ -392,8 +394,16 @@ struct PerformData { if (status.ok()) { status = ProcessResponse(); } + + size_t max_size = GetAtomicFlag(&FLAGS_rpc_max_message_size); + if (status.ok() && sidecars.size() > max_size) { + status = STATUS_FORMAT( + InvalidArgument, "Sending too long RPC message ($0 bytes of data)", sidecars.size()); + } + if (!status.ok()) { StatusToPB(status, resp.mutable_status()); + sidecars.Reset(); } if (cache_setter) { cache_setter({resp, ExtractRowsSidecar(resp, sidecars)}); diff --git a/src/yb/util/pb_util.cc b/src/yb/util/pb_util.cc index d8c44c87b310..3aa42763b095 100644 --- a/src/yb/util/pb_util.cc +++ b/src/yb/util/pb_util.cc @@ -121,7 +121,7 @@ COMPILE_ASSERT((arraysize(kPBContainerMagic) - 1) == kPBContainerMagicLen, // To permit parsing of very large PB messages, we must use parse through a CodedInputStream and // bump the byte limit. The SetTotalBytesLimit() docs say that 512MB is the shortest theoretical // message length that may produce integer overflow warnings, so that's what we'll use. -DEFINE_UNKNOWN_int32( +DEFINE_UNKNOWN_uint32( protobuf_message_total_bytes_limit, 511_MB, "Limits single protobuf message size for deserialization."); TAG_FLAG(protobuf_message_total_bytes_limit, advanced); diff --git a/src/yb/yql/pgwrapper/pg_mini-test.cc b/src/yb/yql/pgwrapper/pg_mini-test.cc index 97bad5f2f9b7..39da9827e693 100644 --- a/src/yb/yql/pgwrapper/pg_mini-test.cc +++ b/src/yb/yql/pgwrapper/pg_mini-test.cc @@ -1931,6 +1931,31 @@ TEST_F(PgMiniTest, NoWaitForRPCOnTermination) { ASSERT_LT(termination_duration, RegularBuildVsDebugVsSanitizers(3000, 5000, 5000)); } +TEST_F(PgMiniTest, ReadHugeRow) { + constexpr size_t kNumColumns = 2; + constexpr size_t kColumnSize = 254000000; + + std::string create_query = "CREATE TABLE test(pk INT PRIMARY KEY"; + for (size_t i = 0; i < kNumColumns; ++i) { + create_query += Format(", text$0 TEXT", i); + } + create_query += ")"; + + auto conn = ASSERT_RESULT(Connect()); + ASSERT_OK(conn.Execute(create_query)); + ASSERT_OK(conn.Execute("INSERT INTO test(pk) VALUES(0)")); + + for (size_t i = 0; i < kNumColumns; ++i) { + ASSERT_OK(conn.ExecuteFormat( + "UPDATE test SET text$0 = repeat('0', $1) WHERE pk = 0", + i, kColumnSize)); + } + + const auto res = conn.Fetch("SELECT * FROM test LIMIT 1"); + ASSERT_NOK(res); + ASSERT_STR_CONTAINS(res.status().ToString(), "Sending too long RPC message"); +} + TEST_F_EX( PgMiniTest, CacheRefreshWithDroppedEntries, PgMiniTestSingleNode) { auto conn = ASSERT_RESULT(Connect());