Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

Commit

Permalink
feat: throw exception if reads corrupt data (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
levy5307 authored Jul 25, 2021
1 parent 3858fdc commit 7e3eab7
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 11 deletions.
4 changes: 4 additions & 0 deletions include/dsn/cpp/rpc_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class rpc_read_stream : public binary_reader
}
}

int read(char *buffer, int sz) { return inner_read(buffer, sz); }

int read(blob &blob, int len) { return inner_read(blob, len); }

~rpc_read_stream()
{
if (_msg) {
Expand Down
4 changes: 2 additions & 2 deletions include/dsn/cpp/serialization_helper/thrift_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class binary_reader_transport : public TVirtualTransport<binary_reader_transport
uint32_t read(uint8_t *buf, uint32_t len)
{
int l = _reader.read((char *)buf, static_cast<int>(len));
if (l == 0) {
if (dsn_unlikely(l <= 0)) {
throw TTransportException(TTransportException::END_OF_FILE,
"no more data to read after end-of-buffer");
}
Expand Down Expand Up @@ -712,4 +712,4 @@ inline void unmarshall_thrift_json(binary_reader &reader, T &val)
::apache::thrift::protocol::TJSONProtocol proto(transport);
unmarshall_thrift_internal(val, &proto);
}
}
} // namespace dsn
13 changes: 10 additions & 3 deletions include/dsn/utility/binary_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstring>
#include <dsn/utility/blob.h>
#include <gtest/gtest_prod.h>

namespace dsn {
class binary_reader
Expand Down Expand Up @@ -39,21 +40,27 @@ class binary_reader
int read(/*out*/ bool &val) { return read_pod(val); }

int read(/*out*/ std::string &s);
int read(char *buffer, int sz);
virtual int read(char *buffer, int sz);
int read(blob &blob);
int read(blob &blob, int len);
virtual int read(blob &blob, int len);

blob get_buffer() const { return _blob; }
blob get_remaining_buffer() const { return _blob.range(static_cast<int>(_ptr - _blob.data())); }
bool is_eof() const { return _ptr >= _blob.data() + _size; }
int total_size() const { return _size; }
int get_remaining_size() const { return _remaining_size; }

protected:
int inner_read(blob &blob, int len);
int inner_read(char *buffer, int sz);

private:
blob _blob;
int _size;
const char *_ptr;
int _remaining_size;

FRIEND_TEST(binary_reader_test, inner_read);
};

template <typename T>
Expand All @@ -70,4 +77,4 @@ inline int binary_reader::read_pod(/*out*/ T &val)
return 0;
}
}
}
} // namespace dsn
9 changes: 9 additions & 0 deletions src/replica/replica_2pc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ void replica::on_client_write(dsn::message_ex *request, bool ignore_throttling)
}

task_spec *spec = task_spec::get(request->rpc_code());
if (dsn_unlikely(nullptr == spec || request->rpc_code() == TASK_CODE_INVALID)) {
derror_f("recv message with unhandled rpc name {} from {}, trace_id = {}",
request->rpc_code().to_string(),
request->header->from_address.to_string(),
request->header->trace_id);
response_client_write(request, ERR_HANDLER_NOT_FOUND);
return;
}

if (is_duplicating() && !spec->rpc_request_is_write_idempotent) {
// Ignore non-idempotent write, because duplication provides no guarantee of atomicity to
// make this write produce the same result on multiple clusters.
Expand Down
27 changes: 21 additions & 6 deletions src/utils/binary_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ int binary_reader::read(blob &blob)
}

int binary_reader::read(blob &blob, int len)
{
auto res = inner_read(blob, len);
if (dsn_unlikely(res < 0)) {
assert(false);
}
return res;
}

int binary_reader::read(char *buffer, int sz)
{
auto res = inner_read(buffer, sz);
if (dsn_unlikely(res < 0)) {
assert(false);
}
return res;
}

int binary_reader::inner_read(blob &blob, int len)
{
if (len <= get_remaining_size()) {
blob = _blob.range(static_cast<int>(_ptr - _blob.data()), len);
Expand All @@ -64,22 +82,19 @@ int binary_reader::read(blob &blob, int len)
_remaining_size -= len;
return len + sizeof(len);
} else {
assert(false);
return 0;
return -1;
}
}

int binary_reader::read(char *buffer, int sz)
int binary_reader::inner_read(char *buffer, int sz)
{
if (sz <= get_remaining_size()) {
memcpy((void *)buffer, _ptr, sz);
_ptr += sz;
_remaining_size -= sz;
return sz;
} else {
assert(false);
return 0;
return -1;
}
}

} // namespace dsn
73 changes: 73 additions & 0 deletions src/utils/test/binary_reader_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <dsn/utility/binary_reader.h>
#include <gtest/gtest.h>
#include <dsn/utility/defer.h>

namespace dsn {

TEST(binary_reader_test, inner_read)
{

{
blob input = blob::create_from_bytes(std::string("test10086"));
binary_reader reader(input);

blob output;
int size = 4;
auto res = reader.inner_read(output, size);
ASSERT_EQ(res, size + sizeof(size));
ASSERT_EQ(output.to_string(), "test");
}

{
blob input = blob::create_from_bytes(std::string("test10086"));
binary_reader reader(input);

blob output;
int size = 10;
auto res = reader.inner_read(output, size);
ASSERT_EQ(res, -1);
}

{

blob input = blob::create_from_bytes(std::string("test10086"));
binary_reader reader(input);

int size = 4;
char *output_str = new char[size + 1];
auto cleanup = dsn::defer([&output_str]() { delete[] output_str; });
auto res = reader.inner_read(output_str, size);
output_str[size] = '\0';
ASSERT_EQ(res, size);
ASSERT_EQ(std::string(output_str), "test");
}

{
blob input = blob::create_from_bytes(std::string("test10086"));
binary_reader reader(input);

int size = 10;
char *output_str = new char[size];
auto cleanup = dsn::defer([&output_str]() { delete[] output_str; });
auto res = reader.inner_read(output_str, size);
ASSERT_EQ(res, -1);
}
}
} // namespace dsn

0 comments on commit 7e3eab7

Please sign in to comment.