Skip to content

Commit

Permalink
Block any BWU frames before the NC connection accepted by both side
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624066422
  • Loading branch information
realp2us authored and copybara-github committed Apr 16, 2024
1 parent 198243f commit e7a3f6b
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 22 deletions.
22 changes: 22 additions & 0 deletions connections/implementation/bwu_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "connections/implementation/bwu_handler.h"
#include "connections/implementation/client_proxy.h"
#include "connections/implementation/endpoint_channel_manager.h"
#include "connections/implementation/flags/nearby_connections_feature_flags.h"
#include "connections/implementation/offline_frames.h"
#include "connections/implementation/service_id_constants.h"
#ifdef NO_WEBRTC
Expand All @@ -36,6 +37,7 @@
#include "connections/implementation/wifi_direct_bwu_handler.h"
#include "connections/implementation/wifi_hotspot_bwu_handler.h"
#include "connections/implementation/wifi_lan_bwu_handler.h"
#include "internal/flags/nearby_flags.h"
#include "internal/platform/byte_array.h"
#include "internal/platform/count_down_latch.h"
#include "internal/platform/feature_flags.h"
Expand Down Expand Up @@ -471,6 +473,26 @@ void BwuManager::OnBwuNegotiationFrame(ClientProxy* client,
NEARBY_LOGS(INFO) << "OnBwuNegotiationFrame: processing incoming "
<< BwuNegotiationFrame::EventType_Name(frame.event_type())
<< " frame for endpoint " << endpoint_id;

if (NearbyFlags::GetInstance().GetBoolFlag(
config_package_nearby::nearby_connections_feature::
kProcessBwuFrameAfterPcpConnected) &&
!client->IsConnectedToEndpoint(endpoint_id)) {
NEARBY_LOGS(WARNING)
<< "BwuManager skips the process BANDWIDTH_UPGRADE_NEGOTIATION before "
"PCP connected, "
<< frame.event_type();

// For the case discover side not yet get the local accept from client, but
// advertise side already get, discover side should inform advertise side
// the upgrade failed, so the advertise side could have chance to initialize
// another upgrade flow again.
if (frame.event_type() == BwuNegotiationFrame::UPGRADE_PATH_AVAILABLE) {
RunUpgradeFailedProtocol(client, endpoint_id, frame.upgrade_path_info());
}
return;
}

switch (frame.event_type()) {
case BwuNegotiationFrame::UPGRADE_PATH_AVAILABLE:
ProcessBwuPathAvailableEvent(client, endpoint_id,
Expand Down
62 changes: 62 additions & 0 deletions connections/implementation/bwu_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@
#include "connections/implementation/endpoint_manager.h"
#include "connections/implementation/fake_bwu_handler.h"
#include "connections/implementation/fake_endpoint_channel.h"
#include "connections/implementation/flags/nearby_connections_feature_flags.h"
#include "connections/implementation/mediums/mediums.h"
#include "connections/implementation/offline_frames.h"
#include "connections/implementation/service_id_constants.h"
#include "internal/flags/nearby_flags.h"
#include "internal/platform/exception.h"
#include "internal/platform/feature_flags.h"
#include "internal/proto/analytics/connections_log.pb.h"
#include "proto/connections_enums.pb.h"

Expand Down Expand Up @@ -928,6 +931,65 @@ TEST_F(BwuManagerTest, OnProcessBwuEvent) {
// TODO(b/235109434): Add more unit tests coverage for BWU module
}

TEST_F(BwuManagerTest, BlockBwuFrameBeforeAccept) {
NearbyFlags::GetInstance().OverrideBoolFlagValue(
config_package_nearby::nearby_connections_feature::
kProcessBwuFrameAfterPcpConnected,
false);
CreateInitialEndpoint(kServiceIdA, kEndpointId1, Medium::BLUETOOTH);

ExceptionOr<OfflineFrame> hotspot_path_available_frame =
parser::FromBytes(parser::ForBwuWifiHotspotPathAvailable(
/*ssid=*/"Direct-357a2d8c", /*password=*/"b592f7d3",
/*port=*/1234, /*frequency=*/2412, /*gateway=*/"123.234.23.1",
true));
OfflineFrame frame = hotspot_path_available_frame.result();
frame.set_version(OfflineFrame::V1);
auto* v1_frame = frame.mutable_v1();
auto* sub_frame = v1_frame->mutable_bandwidth_upgrade_negotiation();
sub_frame->set_event_type(
BandwidthUpgradeNegotiationFrame::UPGRADE_PATH_AVAILABLE);
auto* upgrade_path_info = sub_frame->mutable_upgrade_path_info();

upgrade_path_info->set_supports_client_introduction_ack(false);
upgrade_path_info->set_supports_disabling_encryption(true);
bwu_manager_->OnIncomingFrame(frame, std::string(kEndpointId1), &client_,
Medium::BLUETOOTH, packet_meta_data_);
CountDownLatch latch(1);
// The BWU frame should not be drop, so the inProgressUpgrades should not be
// empty.
ASSERT_EQ(bwu_manager_->IsUpgradeOngoing(std::string(kEndpointId1)), true);
UnRegisterChannelForEndpoint(kEndpointId1);

NearbyFlags::GetInstance().OverrideBoolFlagValue(
config_package_nearby::nearby_connections_feature::
kProcessBwuFrameAfterPcpConnected,
true);
CreateInitialEndpoint(kServiceIdA, kEndpointId2, Medium::BLUETOOTH);

ExceptionOr<OfflineFrame> hotspot_path_available_frame2 =
parser::FromBytes(parser::ForBwuWifiHotspotPathAvailable(
/*ssid=*/"Direct-357a2d8c", /*password=*/"b592f7d3",
/*port=*/1234, /*frequency=*/2412, /*gateway=*/"123.234.23.1",
true));
OfflineFrame frame2 = hotspot_path_available_frame2.result();
frame2.set_version(OfflineFrame::V1);
auto* v1_frame2 = frame2.mutable_v1();
auto* sub_frame2 = v1_frame2->mutable_bandwidth_upgrade_negotiation();
sub_frame->set_event_type(
BandwidthUpgradeNegotiationFrame::UPGRADE_PATH_AVAILABLE);
auto* upgrade_path_info2 = sub_frame2->mutable_upgrade_path_info();

upgrade_path_info2->set_supports_client_introduction_ack(false);
upgrade_path_info2->set_supports_disabling_encryption(true);
bwu_manager_->OnIncomingFrame(frame2, std::string(kEndpointId2), &client_,
Medium::BLUETOOTH, packet_meta_data_);
CountDownLatch latch2(1);
// The BWU frame should be drop, so the inProgressUpgrades should be empty.
ASSERT_EQ(bwu_manager_->IsUpgradeOngoing(std::string(kEndpointId2)), false);
UnRegisterChannelForEndpoint(kEndpointId2);
}

INSTANTIATE_TEST_SUITE_P(BwuManagerTestParam, BwuManagerTestParam,
testing::Bool());

Expand Down
69 changes: 47 additions & 22 deletions connections/implementation/payload_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,30 @@ void PayloadManager::OnIncomingFrame(
PayloadTransferFrame& frame =
*offline_frame.mutable_v1()->mutable_payload_transfer();

// Block any payload before the connection been accepted by both sides
// to prevent vulnerability.
if (NearbyFlags::GetInstance().GetBoolFlag(
config_package_nearby::nearby_connections_feature::
kProcessBwuFrameAfterPcpConnected) &&
!to_client->IsConnectedToEndpoint(from_endpoint_id)) {
if (frame.packet_type() == PayloadTransferFrame::DATA) {
PendingPayloadHandle pending_payload =
pending_payloads_.GetPayload(frame.payload_header().id());
bool is_last = IsLastChunk(frame.payload_chunk());
// If payload need to be ack'd receiving, then send back the ACK frame.
if (pending_payload && is_last &&
IsPayloadReceivedAckEnabled(to_client, from_endpoint_id,
*pending_payload)) {
SendPayloadReceivedAck(to_client, *pending_payload, from_endpoint_id,
is_last);
}
}
NEARBY_LOGS(INFO)
<< "PayloadManager skipped process payloads before PCP connected, "
<< frame.payload_header().id();
return;
}

switch (frame.packet_type()) {
case PayloadTransferFrame::CONTROL:
NEARBY_LOGS(INFO) << "PayloadManager::OnIncomingFrame [CONTROL]: self="
Expand All @@ -534,7 +558,8 @@ void PayloadManager::OnIncomingFrame(
break;
case PayloadTransferFrame::PAYLOAD_ACK:
NEARBY_LOGS(INFO) << "[safe-to-disconnect][PAYLOAD_RECEIVED_ACK] sender "
"received payload ack from " << from_endpoint_id;
"received payload ack from "
<< from_endpoint_id;
ProcessPayloadAckPacket(from_endpoint_id, frame);
break;
default:
Expand All @@ -556,8 +581,8 @@ void PayloadManager::OnEndpointDisconnect(ClientProxy* client,
}
RunOnStatusUpdateThread(
"payload-manager-on-disconnect",
[this, client, endpoint_id,
barrier, reason]() RUN_ON_PAYLOAD_STATUS_UPDATE_THREAD() mutable {
[this, client, endpoint_id, barrier,
reason]() RUN_ON_PAYLOAD_STATUS_UPDATE_THREAD() mutable {
// Iterate through all our payloads and look for payloads associated
// with this endpoint.
MutexLock lock(&mutex_);
Expand Down Expand Up @@ -586,20 +611,19 @@ void PayloadManager::OnEndpointDisconnect(ClientProxy* client,
// Send a client notification of a payload transfer failure.
client->OnPayloadProgress(endpoint_id, update);

PayloadStatus payload_status;
switch (reason) {
case DisconnectionReason::LOCAL_DISCONNECTION:
payload_status = PayloadStatus::LOCAL_CLIENT_DISCONNECTION;
break;
case DisconnectionReason::REMOTE_DISCONNECTION:
payload_status = PayloadStatus::REMOTE_CLIENT_DISCONNECTION;
break;
case DisconnectionReason::IO_ERROR:
default:
payload_status = PayloadStatus::ENDPOINT_IO_ERROR;
break;
}

PayloadStatus payload_status;
switch (reason) {
case DisconnectionReason::LOCAL_DISCONNECTION:
payload_status = PayloadStatus::LOCAL_CLIENT_DISCONNECTION;
break;
case DisconnectionReason::REMOTE_DISCONNECTION:
payload_status = PayloadStatus::REMOTE_CLIENT_DISCONNECTION;
break;
case DisconnectionReason::IO_ERROR:
default:
payload_status = PayloadStatus::ENDPOINT_IO_ERROR;
break;
}

if (pending_payload->IsIncoming()) {
client->GetAnalyticsRecorder().OnIncomingPayloadDone(
Expand Down Expand Up @@ -843,9 +867,10 @@ void PayloadManager::SendControlMessage(
endpoint_ids);
}

void PayloadManager::SendPayloadReceivedAck(
ClientProxy* client, PendingPayload& pending_payload,
const std::string& endpoint_id, bool is_last_chunk) {
void PayloadManager::SendPayloadReceivedAck(ClientProxy* client,
PendingPayload& pending_payload,
const std::string& endpoint_id,
bool is_last_chunk) {
if (!is_last_chunk ||
!IsPayloadReceivedAckEnabled(client, endpoint_id, pending_payload)) {
return;
Expand Down Expand Up @@ -1296,8 +1321,8 @@ void PayloadManager::ProcessDataPacket(
packet_meta_data.StopFileIo();
bool is_last_chunk = (payload_chunk.flags() &
PayloadTransferFrame::PayloadChunk::LAST_CHUNK) != 0;
SendPayloadReceivedAck(
to_client, *pending_payload, from_endpoint_id, is_last_chunk);
SendPayloadReceivedAck(to_client, *pending_payload, from_endpoint_id,
is_last_chunk);

HandleSuccessfulIncomingChunk(to_client, from_endpoint_id, payload_header,
payload_chunk.flags(), payload_chunk.offset(),
Expand Down
48 changes: 48 additions & 0 deletions connections/implementation/payload_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,28 @@
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "connections/implementation/analytics/packet_meta_data.h"
#include "connections/implementation/flags/nearby_connections_feature_flags.h"
#include "connections/implementation/offline_frames.h"
#include "connections/implementation/simulation_user.h"
#include "connections/listeners.h"
#include "connections/medium_selector.h"
#include "connections/payload.h"
#include "connections/status.h"
#include "internal/flags/nearby_flags.h"
#include "internal/platform/byte_array.h"
#include "internal/platform/count_down_latch.h"
#include "internal/platform/exception.h"
#include "internal/platform/logging.h"
#include "internal/platform/medium_environment.h"
#include "internal/platform/pipe.h"

namespace nearby {
namespace connections {
namespace {
using ::location::nearby::connections::OfflineFrame;
using ::nearby::analytics::PacketMetaData;
using ::location::nearby::proto::connections::Medium;

constexpr size_t kChunkSize = 64 * 1024;
constexpr absl::string_view kServiceId = "service-id";
Expand Down Expand Up @@ -90,6 +98,29 @@ class PayloadSimulationUser : public SimulationUser {
pm_.SendPayload(&client_, {discovered_.endpoint_id}, std::move(payload));
}

void ReceivePayload(Payload payload, std::string from_payload_id) {
PayloadTransferFrame::PayloadHeader header;
header.set_id(payload.GetId());
header.set_type(PayloadTransferFrame::PayloadHeader::FILE);
header.set_total_size(payload.AsBytes().size());
header.set_file_name("test_file.txt");
header.set_parent_folder("");
PayloadTransferFrame::PayloadChunk chunk;
chunk.set_body(payload.AsBytes().data());
chunk.set_offset(payload.GetOffset());
chunk.set_flags(1);

OfflineFrame offline_frame;

ByteArray bytes = parser::ForDataPayloadTransfer(header, chunk);
offline_frame.ParseFromString(std::string(bytes));

PacketMetaData packet_meta_data;

pm_.OnIncomingFrame(offline_frame, from_payload_id, &client_,
Medium::WIFI_HOTSPOT, packet_meta_data);
}

Status CancelPayload() {
if (sender_payload_id_) {
return pm_.CancelPayload(&client_, sender_payload_id_);
Expand Down Expand Up @@ -369,6 +400,23 @@ TEST_P(PayloadManagerTest, SendPayloadWithSkip_StreamPayload) {
env_.Stop();
}

TEST_P(PayloadManagerTest, OfflineFrame_BeforeConnected_ShouldDrop) {
NearbyFlags::GetInstance().OverrideBoolFlagValue(
config_package_nearby::nearby_connections_feature::
kProcessBwuFrameAfterPcpConnected,
false);
env_.Start();
PayloadSimulationUser user(kDeviceB, GetParam());
auto [input, tx] = CreatePipe();
const ByteArray message{std::string(kMessage)};
tx->Write(message);
Payload payload(std::move(input));
user.ReceivePayload(std::move(payload), "1234");
ASSERT_EQ(user.GetPayload().AsStream(), nullptr);
user.Stop();
env_.Stop();
}

INSTANTIATE_TEST_SUITE_P(ParametrisedPayloadManagerTest, PayloadManagerTest,
::testing::ValuesIn(kTestCases));

Expand Down

0 comments on commit e7a3f6b

Please sign in to comment.