diff --git a/packages/react-native/ReactCommon/jsinspector-modern/InspectorInterfaces.h b/packages/react-native/ReactCommon/jsinspector-modern/InspectorInterfaces.h index e4f1c221033120..576fd9a662c529 100644 --- a/packages/react-native/ReactCommon/jsinspector-modern/InspectorInterfaces.h +++ b/packages/react-native/ReactCommon/jsinspector-modern/InspectorInterfaces.h @@ -93,7 +93,15 @@ class JSINSPECTOR_EXPORT IInspector : public IDestructible { virtual ~IInspector() = 0; - /// addPage is called by the VM to add a page to the list of debuggable pages. + /** + * Add a page to the list of inspectable pages. + * Callers are responsible for calling removePage when the page is no longer + * expecting connections. + * \param connectFunc a function that will be called to establish a + * connection. \c connectFunc may return nullptr to reject the connection + * (e.g. if the page is in the process of shutting down). + * \returns the ID assigned to the new page. + */ virtual int addPage( const std::string& title, const std::string& vm, @@ -107,8 +115,12 @@ class JSINSPECTOR_EXPORT IInspector : public IDestructible { /// getPages is called by the client to list all debuggable pages. virtual std::vector getPages() const = 0; - /// connect is called by the client to initiate a debugging session on the - /// given page. + /** + * Called by InspectorPackagerConnection to initiate a debugging session with + * the given page. + * \returns an ILocalConnection that can be used to send messages to the + * page, or nullptr if the connection has been rejected. + */ virtual std::unique_ptr connect( int pageId, std::unique_ptr remote) = 0; diff --git a/packages/react-native/ReactCommon/jsinspector-modern/InspectorPackagerConnection.cpp b/packages/react-native/ReactCommon/jsinspector-modern/InspectorPackagerConnection.cpp index 2ef3c40df978d1..b98c01a6586157 100644 --- a/packages/react-native/ReactCommon/jsinspector-modern/InspectorPackagerConnection.cpp +++ b/packages/react-native/ReactCommon/jsinspector-modern/InspectorPackagerConnection.cpp @@ -101,6 +101,16 @@ void InspectorPackagerConnection::Impl::handleConnect( auto& inspector = getInspectorInstance(); auto inspectorConnection = inspector.connect(pageIdInt, std::move(remoteConnection)); + if (!inspectorConnection) { + LOG(INFO) << "Connection to page " << pageId << " rejected"; + + // RemoteConnection::onDisconnect(), if the connection even calls it, will + // be a no op (because the session is not added to `inspectorSessions_`), so + // let's always notify the remote client of the disconnection ourselves. + sendToPackager(folly::dynamic::object("event", "disconnect")( + "payload", folly::dynamic::object("pageId", pageId))); + return; + } inspectorSessions_.emplace( pageId, Session{ diff --git a/packages/react-native/ReactCommon/jsinspector-modern/tests/InspectorPackagerConnectionTest.cpp b/packages/react-native/ReactCommon/jsinspector-modern/tests/InspectorPackagerConnectionTest.cpp index d0b62cfabbcc2f..562f0310f89939 100644 --- a/packages/react-native/ReactCommon/jsinspector-modern/tests/InspectorPackagerConnectionTest.cpp +++ b/packages/react-native/ReactCommon/jsinspector-modern/tests/InspectorPackagerConnectionTest.cpp @@ -1196,4 +1196,144 @@ TEST_F( EXPECT_CALL(*localConnections_[1], disconnect()).RetiresOnSaturation(); getInspectorInstance().removePage(pageId); } + +TEST_F(InspectorPackagerConnectionTest, TestRejectedPageConnection) { + // Configure gmock to expect calls in a specific order. + InSequence mockCallsMustBeInSequence; + + enum { + Accept, + RejectSilently, + RejectWithDisconnect + } mockNextConnectionBehavior; + + auto pageId = getInspectorInstance().addPage( + "mock-title", + "mock-vm", + [&mockNextConnectionBehavior, + this](auto remoteConnection) -> std::unique_ptr { + switch (mockNextConnectionBehavior) { + case Accept: + return localConnections_.make_unique(std::move(remoteConnection)); + case RejectSilently: + return nullptr; + case RejectWithDisconnect: + remoteConnection->onDisconnect(); + return nullptr; + } + }); + + packagerConnection_->connect(); + + ASSERT_TRUE(webSockets_[0]); + + // Reject the connection by returning nullptr. + mockNextConnectionBehavior = RejectSilently; + + EXPECT_CALL( + *webSockets_[0], + send(JsonParsed(AllOf( + AtJsonPtr("/event", Eq("disconnect")), + AtJsonPtr("/payload/pageId", Eq(std::to_string(pageId))))))) + .RetiresOnSaturation(); + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "connect", + "payload": {{ + "pageId": {0} + }} + }})", + toJson(std::to_string(pageId)))); + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "wrappedEvent", + "payload": {{ + "pageId": {0}, + "wrappedEvent": {1} + }} + }})", + toJson(std::to_string(pageId)), + toJson(R"({ + "method": "FakeDomain.fakeMethod", + "id": 1, + "params": ["arg1", "arg2"] + })"))); + + // Reject the connection by explicitly calling onDisconnect(), then returning + // nullptr. + mockNextConnectionBehavior = RejectWithDisconnect; + + EXPECT_CALL( + *webSockets_[0], + send(JsonParsed(AllOf( + AtJsonPtr("/event", Eq("disconnect")), + AtJsonPtr("/payload/pageId", Eq(std::to_string(pageId))))))) + .RetiresOnSaturation(); + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "connect", + "payload": {{ + "pageId": {0} + }} + }})", + toJson(std::to_string(pageId)))); + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "wrappedEvent", + "payload": {{ + "pageId": {0}, + "wrappedEvent": {1} + }} + }})", + toJson(std::to_string(pageId)), + toJson(R"({ + "method": "FakeDomain.fakeMethod", + "id": 2, + "params": ["arg1", "arg2"] + })"))); + + // Accept a connection after previously rejecting connections to the same + // page. + mockNextConnectionBehavior = Accept; + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "connect", + "payload": {{ + "pageId": {0} + }} + }})", + toJson(std::to_string(pageId)))); + + EXPECT_CALL( + *localConnections_[0], + sendMessage(JsonParsed(AllOf( + AtJsonPtr("/method", Eq("FakeDomain.fakeMethod")), + AtJsonPtr("/id", Eq(3)), + AtJsonPtr("/params", ElementsAre("arg1", "arg2")))))) + .RetiresOnSaturation(); + + webSockets_[0]->getDelegate().didReceiveMessage(sformat( + R"({{ + "event": "wrappedEvent", + "payload": {{ + "pageId": {0}, + "wrappedEvent": {1} + }} + }})", + toJson(std::to_string(pageId)), + toJson(R"({ + "method": "FakeDomain.fakeMethod", + "id": 3, + "params": ["arg1", "arg2"] + })"))); + + EXPECT_CALL(*localConnections_[0], disconnect()).RetiresOnSaturation(); + getInspectorInstance().removePage(pageId); +} + } // namespace facebook::react::jsinspector_modern