From fd4529d0c17be681a3c0bba37bbcbb8a8dea7748 Mon Sep 17 00:00:00 2001 From: Eugene Ostroukhov Date: Fri, 10 Nov 2017 16:01:00 -0800 Subject: [PATCH] inspector: Fix crash for WS connection Attaching WS session will now include a roundtrip onto the main thread to make sure there is no other session (e.g. JS bindings) This change also required refactoring WS socket implementation to better support scenarios like this. Fixes: https://github.com/nodejs/node/issues/16852 PR-URL: https://github.com/nodejs/node/pull/17085 Reviewed-By: James M Snell Reviewed-By: Timothy Gu --- src/inspector_agent.cc | 4 - src/inspector_agent.h | 1 - src/inspector_io.cc | 73 +- src/inspector_io.h | 8 +- src/inspector_socket.cc | 825 ++++++++++-------- src/inspector_socket.h | 96 +- src/inspector_socket_server.cc | 306 +++---- src/inspector_socket_server.h | 28 +- src/node.cc | 2 +- test/cctest/test_inspector_socket.cc | 601 ++++++------- test/cctest/test_inspector_socket_server.cc | 17 +- test/common/inspector-helper.js | 55 +- ...st-inspector-no-crash-ws-after-bindings.js | 30 + 13 files changed, 1022 insertions(+), 1024 deletions(-) create mode 100644 test/parallel/test-inspector-no-crash-ws-after-bindings.js diff --git a/src/inspector_agent.cc b/src/inspector_agent.cc index 9677dca37b5939..216b43ca6c3379 100644 --- a/src/inspector_agent.cc +++ b/src/inspector_agent.cc @@ -548,10 +548,6 @@ void Agent::Connect(InspectorSessionDelegate* delegate) { client_->connectFrontend(delegate); } -bool Agent::IsConnected() { - return io_ && io_->IsConnected(); -} - void Agent::WaitForDisconnect() { CHECK_NE(client_, nullptr); client_->contextDestroyed(parent_env_->context()); diff --git a/src/inspector_agent.h b/src/inspector_agent.h index 29b9546b514aea..a2d61d0c8db28e 100644 --- a/src/inspector_agent.h +++ b/src/inspector_agent.h @@ -48,7 +48,6 @@ class Agent { bool IsStarted() { return !!client_; } // IO thread started, and client connected - bool IsConnected(); bool IsWaitingForConnect(); void WaitForDisconnect(); diff --git a/src/inspector_io.cc b/src/inspector_io.cc index 538cbab3c9fe84..9af4458c6b20f1 100644 --- a/src/inspector_io.cc +++ b/src/inspector_io.cc @@ -136,7 +136,7 @@ class InspectorIoDelegate: public node::inspector::SocketServerDelegate { const std::string& script_name, bool wait); // Calls PostIncomingMessage() with appropriate InspectorAction: // kStartSession - bool StartSession(int session_id, const std::string& target_id) override; + void StartSession(int session_id, const std::string& target_id) override; // kSendMessage void MessageReceived(int session_id, const std::string& message) override; // kEndSession @@ -145,19 +145,22 @@ class InspectorIoDelegate: public node::inspector::SocketServerDelegate { std::vector GetTargetIds() override; std::string GetTargetTitle(const std::string& id) override; std::string GetTargetUrl(const std::string& id) override; - bool IsConnected() { return connected_; } void ServerDone() override { io_->ServerDone(); } + void AssignTransport(InspectorSocketServer* server) { + server_ = server; + } + private: InspectorIo* io_; - bool connected_; int session_id_; const std::string script_name_; const std::string script_path_; const std::string target_id_; bool waiting_; + InspectorSocketServer* server_; }; void InterruptCallback(v8::Isolate*, void* agent) { @@ -226,10 +229,6 @@ void InspectorIo::Stop() { DispatchMessages(); } -bool InspectorIo::IsConnected() { - return delegate_ != nullptr && delegate_->IsConnected(); -} - bool InspectorIo::IsStarted() { return platform_ != nullptr; } @@ -264,6 +263,7 @@ void InspectorIo::IoThreadAsyncCb(uv_async_t* async) { MessageQueue outgoing_message_queue; io->SwapBehindLock(&io->outgoing_message_queue_, &outgoing_message_queue); for (const auto& outgoing : outgoing_message_queue) { + int session_id = std::get<1>(outgoing); switch (std::get<0>(outgoing)) { case TransportAction::kKill: transport->TerminateConnections(); @@ -272,8 +272,14 @@ void InspectorIo::IoThreadAsyncCb(uv_async_t* async) { transport->Stop(nullptr); break; case TransportAction::kSendMessage: - std::string message = StringViewToUtf8(std::get<2>(outgoing)->string()); - transport->Send(std::get<1>(outgoing), message); + transport->Send(session_id, + StringViewToUtf8(std::get<2>(outgoing)->string())); + break; + case TransportAction::kAcceptSession: + transport->AcceptSession(session_id); + break; + case TransportAction::kDeclineSession: + transport->DeclineSession(session_id); break; } } @@ -293,6 +299,7 @@ void InspectorIo::ThreadMain() { wait_for_connect_); delegate_ = &delegate; Transport server(&delegate, &loop, options_.host_name(), options_.port()); + delegate.AssignTransport(&server); TransportAndIo queue_transport(&server, this); thread_req_.data = &queue_transport; if (!server.Start()) { @@ -308,6 +315,7 @@ void InspectorIo::ThreadMain() { uv_run(&loop, UV_RUN_DEFAULT); thread_req_.data = nullptr; CHECK_EQ(uv_loop_close(&loop), 0); + delegate.AssignTransport(nullptr); delegate_ = nullptr; } @@ -358,6 +366,21 @@ void InspectorIo::NotifyMessageReceived() { incoming_message_cond_.Broadcast(scoped_lock); } +TransportAction InspectorIo::Attach(int session_id) { + Agent* agent = parent_env_->inspector_agent(); + if (agent->delegate() != nullptr) + return TransportAction::kDeclineSession; + + CHECK_EQ(session_delegate_, nullptr); + session_id_ = session_id; + state_ = State::kConnected; + fprintf(stderr, "Debugger attached.\n"); + session_delegate_ = std::unique_ptr( + new IoSessionDelegate(this)); + agent->Connect(session_delegate_.get()); + return TransportAction::kAcceptSession; +} + void InspectorIo::DispatchMessages() { // This function can be reentered if there was an incoming message while // V8 was processing another inspector request (e.g. if the user is @@ -375,16 +398,14 @@ void InspectorIo::DispatchMessages() { MessageQueue::value_type task; std::swap(dispatching_message_queue_.front(), task); dispatching_message_queue_.pop_front(); + int id = std::get<1>(task); StringView message = std::get<2>(task)->string(); switch (std::get<0>(task)) { case InspectorAction::kStartSession: - CHECK_EQ(session_delegate_, nullptr); - session_id_ = std::get<1>(task); - state_ = State::kConnected; - fprintf(stderr, "Debugger attached.\n"); - session_delegate_ = std::unique_ptr( - new IoSessionDelegate(this)); - parent_env_->inspector_agent()->Connect(session_delegate_.get()); + Write(Attach(id), id, StringView()); + break; + case InspectorAction::kStartSessionUnconditionally: + Attach(id); break; case InspectorAction::kEndSession: CHECK_NE(session_delegate_, nullptr); @@ -428,22 +449,23 @@ InspectorIoDelegate::InspectorIoDelegate(InspectorIo* io, const std::string& script_name, bool wait) : io_(io), - connected_(false), session_id_(0), script_name_(script_name), script_path_(script_path), target_id_(GenerateID()), - waiting_(wait) { } + waiting_(wait), + server_(nullptr) { } -bool InspectorIoDelegate::StartSession(int session_id, +void InspectorIoDelegate::StartSession(int session_id, const std::string& target_id) { - if (connected_) - return false; - connected_ = true; - session_id_++; - io_->PostIncomingMessage(InspectorAction::kStartSession, session_id, ""); - return true; + session_id_ = session_id; + InspectorAction action = InspectorAction::kStartSession; + if (waiting_) { + action = InspectorAction::kStartSessionUnconditionally; + server_->AcceptSession(session_id); + } + io_->PostIncomingMessage(action, session_id, ""); } void InspectorIoDelegate::MessageReceived(int session_id, @@ -464,7 +486,6 @@ void InspectorIoDelegate::MessageReceived(int session_id, } void InspectorIoDelegate::EndSession(int session_id) { - connected_ = false; io_->PostIncomingMessage(InspectorAction::kEndSession, session_id, ""); } diff --git a/src/inspector_io.h b/src/inspector_io.h index 7c15466eed91ff..79ccc6095ffec3 100644 --- a/src/inspector_io.h +++ b/src/inspector_io.h @@ -36,6 +36,7 @@ class InspectorIoDelegate; enum class InspectorAction { kStartSession, + kStartSessionUnconditionally, // First attach with --inspect-brk kEndSession, kSendMessage }; @@ -44,7 +45,9 @@ enum class InspectorAction { enum class TransportAction { kKill, kSendMessage, - kStop + kStop, + kAcceptSession, + kDeclineSession }; class InspectorIo { @@ -61,7 +64,6 @@ class InspectorIo { void Stop(); bool IsStarted(); - bool IsConnected(); void WaitForDisconnect(); // Called from thread to queue an incoming message and trigger @@ -124,6 +126,8 @@ class InspectorIo { void WaitForFrontendMessageWhilePaused(); // Broadcast incoming_message_cond_ void NotifyMessageReceived(); + // Attach session to an inspector. Either kAcceptSession or kDeclineSession + TransportAction Attach(int session_id); const DebugOptions options_; diff --git a/src/inspector_socket.cc b/src/inspector_socket.cc index 49d337b70b1198..23b77f6aa5609f 100644 --- a/src/inspector_socket.cc +++ b/src/inspector_socket.cc @@ -1,4 +1,6 @@ #include "inspector_socket.h" + +#include "http_parser.h" #include "util-inl.h" #define NODE_WANT_INTERNALS 1 @@ -18,12 +20,71 @@ namespace node { namespace inspector { -static const char CLOSE_FRAME[] = {'\x88', '\x00'}; +class TcpHolder { + public: + using Pointer = std::unique_ptr; -enum ws_decode_result { - FRAME_OK, FRAME_INCOMPLETE, FRAME_CLOSE, FRAME_ERROR + static Pointer Accept(uv_stream_t* server, + InspectorSocket::DelegatePointer delegate); + void SetHandler(ProtocolHandler* handler); + int WriteRaw(const std::vector& buffer, uv_write_cb write_cb); + uv_tcp_t* tcp() { + return &tcp_; + } + InspectorSocket::Delegate* delegate(); + + private: + static TcpHolder* From(void* handle) { + return node::ContainerOf(&TcpHolder::tcp_, + reinterpret_cast(handle)); + } + static void OnClosed(uv_handle_t* handle); + static void OnDataReceivedCb(uv_stream_t* stream, ssize_t nread, + const uv_buf_t* buf); + static void DisconnectAndDispose(TcpHolder* holder); + explicit TcpHolder(InspectorSocket::DelegatePointer delegate); + ~TcpHolder() = default; + void ReclaimUvBuf(const uv_buf_t* buf, ssize_t read); + + uv_tcp_t tcp_; + const InspectorSocket::DelegatePointer delegate_; + ProtocolHandler* handler_; + std::vector buffer; +}; + + +class ProtocolHandler { + public: + ProtocolHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp); + + virtual void AcceptUpgrade(const std::string& accept_key) = 0; + virtual void OnData(std::vector* data) = 0; + virtual void OnEof() = 0; + virtual void Write(const std::vector data) = 0; + virtual void CancelHandshake() = 0; + + std::string GetHost(); + + InspectorSocket* inspector() { + return inspector_; + } + + static void Shutdown(ProtocolHandler* handler) { + handler->Shutdown(); + } + + protected: + virtual ~ProtocolHandler() = default; + virtual void Shutdown() = 0; + int WriteRaw(const std::vector& buffer, uv_write_cb write_cb); + InspectorSocket::Delegate* delegate(); + + InspectorSocket* const inspector_; + TcpHolder::Pointer tcp_; }; +namespace { + #if DUMP_READS || DUMP_WRITES static void dump_hex(const char* buf, size_t len) { const char* ptr = buf; @@ -50,64 +111,52 @@ static void dump_hex(const char* buf, size_t len) { } #endif -static void remove_from_beginning(std::vector* buffer, size_t count) { - buffer->erase(buffer->begin(), buffer->begin() + count); -} - -static void dispose_inspector(uv_handle_t* handle) { - InspectorSocket* inspector = inspector_from_stream(handle); - inspector_cb close = - inspector->ws_mode ? inspector->ws_state->close_cb : nullptr; - inspector->buffer.clear(); - delete inspector->ws_state; - inspector->ws_state = nullptr; - if (close) { - close(inspector, 0); - } -} - -static void close_connection(InspectorSocket* inspector) { - uv_handle_t* socket = reinterpret_cast(&inspector->tcp); - if (!uv_is_closing(socket)) { - uv_read_stop(reinterpret_cast(socket)); - uv_close(socket, dispose_inspector); - } -} - -struct WriteRequest { - WriteRequest(InspectorSocket* inspector, const char* data, size_t size) - : inspector(inspector) - , storage(data, data + size) - , buf(uv_buf_init(&storage[0], storage.size())) {} +class WriteRequest { + public: + WriteRequest(ProtocolHandler* handler, const std::vector& buffer) + : handler(handler) + , storage(buffer) + , buf(uv_buf_init(storage.data(), storage.size())) {} static WriteRequest* from_write_req(uv_write_t* req) { return node::ContainerOf(&WriteRequest::req, req); } - InspectorSocket* const inspector; + static void Cleanup(uv_write_t* req, int status) { + delete WriteRequest::from_write_req(req); + } + + ProtocolHandler* const handler; std::vector storage; uv_write_t req; uv_buf_t buf; }; -// Cleanup -static void write_request_cleanup(uv_write_t* req, int status) { - delete WriteRequest::from_write_req(req); +void allocate_buffer(uv_handle_t* stream, size_t len, uv_buf_t* buf) { + *buf = uv_buf_init(new char[len], len); } -static int write_to_client(InspectorSocket* inspector, - const char* msg, - size_t len, - uv_write_cb write_cb = write_request_cleanup) { -#if DUMP_WRITES - printf("%s (%ld bytes):\n", __FUNCTION__, len); - dump_hex(msg, len); -#endif +static void remove_from_beginning(std::vector* buffer, size_t count) { + buffer->erase(buffer->begin(), buffer->begin() + count); +} - // Freed in write_request_cleanup - WriteRequest* wr = new WriteRequest(inspector, msg, len); - uv_stream_t* stream = reinterpret_cast(&inspector->tcp); - return uv_write(&wr->req, stream, &wr->buf, 1, write_cb) < 0; +// Cleanup + +static const char CLOSE_FRAME[] = {'\x88', '\x00'}; + +enum ws_decode_result { + FRAME_OK, FRAME_INCOMPLETE, FRAME_CLOSE, FRAME_ERROR +}; + +static void generate_accept_string(const std::string& client_key, + char (*buffer)[ACCEPT_KEY_LENGTH]) { + // Magic string from websockets spec. + static const char ws_magic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + std::string input(client_key + ws_magic); + char hash[SHA_DIGEST_LENGTH]; + SHA1(reinterpret_cast(&input[0]), input.size(), + reinterpret_cast(hash)); + node::base64_encode(hash, sizeof(hash), *buffer, sizeof(*buffer)); } // Constants for hybi-10 frame format. @@ -134,11 +183,11 @@ const size_t kTwoBytePayloadLengthField = 126; const size_t kEightBytePayloadLengthField = 127; const size_t kMaskingKeyWidthInBytes = 4; -static std::vector encode_frame_hybi17(const char* message, - size_t data_length) { +static std::vector encode_frame_hybi17(const std::vector& message) { std::vector frame; OpCode op_code = kOpCodeText; frame.push_back(kFinalBit | op_code); + const size_t data_length = message.size(); if (data_length <= kMaxSingleBytePayloadLength) { frame.push_back(static_cast(data_length)); } else if (data_length <= 0xFFFF) { @@ -158,7 +207,7 @@ static std::vector encode_frame_hybi17(const char* message, extended_payload_length + 8); CHECK_EQ(0, remaining); } - frame.insert(frame.end(), message, message + data_length); + frame.insert(frame.end(), message.begin(), message.end()); return frame; } @@ -248,272 +297,368 @@ static ws_decode_result decode_frame_hybi17(const std::vector& buffer, return closed ? FRAME_CLOSE : FRAME_OK; } -static void invoke_read_callback(InspectorSocket* inspector, - int status, const uv_buf_t* buf) { - if (inspector->ws_state->read_cb) { - inspector->ws_state->read_cb( - reinterpret_cast(&inspector->tcp), status, buf); + +// WS protocol +class WsHandler : public ProtocolHandler { + public: + WsHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp) + : ProtocolHandler(inspector, std::move(tcp)), + OnCloseSent(&WsHandler::WaitForCloseReply), + OnCloseRecieved(&WsHandler::CloseFrameReceived), + dispose_(false) { } + + void AcceptUpgrade(const std::string& accept_key) override { } + void CancelHandshake() override {} + + void OnEof() override { + tcp_.reset(); + if (dispose_) + delete this; } -} -static void shutdown_complete(InspectorSocket* inspector) { - close_connection(inspector); -} + void OnData(std::vector* data) override { + // 1. Parse. + int processed = 0; + do { + processed = ParseWsFrames(*data); + // 2. Fix the data size & length + if (processed > 0) { + remove_from_beginning(data, processed); + } + } while (processed > 0 && !data->empty()); + } -static void on_close_frame_written(uv_write_t* req, int status) { - WriteRequest* wr = WriteRequest::from_write_req(req); - InspectorSocket* inspector = wr->inspector; - delete wr; - inspector->ws_state->close_sent = true; - if (inspector->ws_state->received_close) { - shutdown_complete(inspector); + void Write(const std::vector data) { + std::vector output = encode_frame_hybi17(data); + WriteRaw(output, WriteRequest::Cleanup); } -} -static void close_frame_received(InspectorSocket* inspector) { - inspector->ws_state->received_close = true; - if (!inspector->ws_state->close_sent) { - invoke_read_callback(inspector, 0, 0); - write_to_client(inspector, CLOSE_FRAME, sizeof(CLOSE_FRAME), - on_close_frame_written); - } else { - shutdown_complete(inspector); + protected: + void Shutdown() override { + if (tcp_) { + dispose_ = true; + SendClose(); + } else { + delete this; + } } -} -static int parse_ws_frames(InspectorSocket* inspector) { - int bytes_consumed = 0; - std::vector output; - bool compressed = false; - - ws_decode_result r = decode_frame_hybi17(inspector->buffer, - true /* client_frame */, - &bytes_consumed, &output, - &compressed); - // Compressed frame means client is ignoring the headers and misbehaves - if (compressed || r == FRAME_ERROR) { - invoke_read_callback(inspector, UV_EPROTO, nullptr); - close_connection(inspector); - bytes_consumed = 0; - } else if (r == FRAME_CLOSE) { - close_frame_received(inspector); - bytes_consumed = 0; - } else if (r == FRAME_OK && inspector->ws_state->alloc_cb - && inspector->ws_state->read_cb) { - uv_buf_t buffer; - size_t len = output.size(); - inspector->ws_state->alloc_cb( - reinterpret_cast(&inspector->tcp), - len, &buffer); - CHECK_GE(buffer.len, len); - memcpy(buffer.base, &output[0], len); - invoke_read_callback(inspector, len, &buffer); - } - return bytes_consumed; -} + private: + using Callback = void (WsHandler::*)(void); -static void prepare_buffer(uv_handle_t* stream, size_t len, uv_buf_t* buf) { - *buf = uv_buf_init(new char[len], len); -} + static void OnCloseFrameWritten(uv_write_t* req, int status) { + WriteRequest* wr = WriteRequest::from_write_req(req); + WsHandler* handler = static_cast(wr->handler); + delete wr; + Callback cb = handler->OnCloseSent; + (handler->*cb)(); + } -static void reclaim_uv_buf(InspectorSocket* inspector, const uv_buf_t* buf, - ssize_t read) { - if (read > 0) { - std::vector& buffer = inspector->buffer; - buffer.insert(buffer.end(), buf->base, buf->base + read); + void WaitForCloseReply() { + OnCloseRecieved = &WsHandler::OnEof; } - delete[] buf->base; -} -static void websockets_data_cb(uv_stream_t* stream, ssize_t nread, - const uv_buf_t* buf) { - InspectorSocket* inspector = inspector_from_stream(stream); - reclaim_uv_buf(inspector, buf, nread); - if (nread < 0 || nread == UV_EOF) { - inspector->connection_eof = true; - if (!inspector->shutting_down && inspector->ws_state->read_cb) { - inspector->ws_state->read_cb(stream, nread, nullptr); + void SendClose() { + WriteRaw(std::vector(CLOSE_FRAME, CLOSE_FRAME + sizeof(CLOSE_FRAME)), + OnCloseFrameWritten); + } + + void CloseFrameReceived() { + OnCloseSent = &WsHandler::OnEof; + SendClose(); + } + + int ParseWsFrames(const std::vector& buffer) { + int bytes_consumed = 0; + std::vector output; + bool compressed = false; + + ws_decode_result r = decode_frame_hybi17(buffer, + true /* client_frame */, + &bytes_consumed, &output, + &compressed); + // Compressed frame means client is ignoring the headers and misbehaves + if (compressed || r == FRAME_ERROR) { + OnEof(); + bytes_consumed = 0; + } else if (r == FRAME_CLOSE) { + (this->*OnCloseRecieved)(); + bytes_consumed = 0; + } else if (r == FRAME_OK) { + delegate()->OnWsFrame(output); } - if (inspector->ws_state->close_sent && - !inspector->ws_state->received_close) { - shutdown_complete(inspector); // invoke callback + return bytes_consumed; + } + + + Callback OnCloseSent; + Callback OnCloseRecieved; + bool dispose_; +}; + +// HTTP protocol +class HttpEvent { + public: + HttpEvent(const std::string& path, bool upgrade, + bool isGET, const std::string& ws_key) : path(path), + upgrade(upgrade), + isGET(isGET), + ws_key(ws_key) { } + + std::string path; + bool upgrade; + bool isGET; + std::string ws_key; + std::string current_header_; +}; + +class HttpHandler : public ProtocolHandler { + public: + explicit HttpHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp) + : ProtocolHandler(inspector, std::move(tcp)), + parsing_value_(false) { + http_parser_init(&parser_, HTTP_REQUEST); + http_parser_settings_init(&parser_settings); + parser_settings.on_header_field = OnHeaderField; + parser_settings.on_header_value = OnHeaderValue; + parser_settings.on_message_complete = OnMessageComplete; + parser_settings.on_url = OnPath; + } + + void AcceptUpgrade(const std::string& accept_key) override { + char accept_string[ACCEPT_KEY_LENGTH]; + generate_accept_string(accept_key, &accept_string); + const char accept_ws_prefix[] = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "; + const char accept_ws_suffix[] = "\r\n\r\n"; + std::vector reply(accept_ws_prefix, + accept_ws_prefix + sizeof(accept_ws_prefix) - 1); + reply.insert(reply.end(), accept_string, + accept_string + sizeof(accept_string)); + reply.insert(reply.end(), accept_ws_suffix, + accept_ws_suffix + sizeof(accept_ws_suffix) - 1); + if (WriteRaw(reply, WriteRequest::Cleanup) >= 0) { + inspector_->SwitchProtocol(new WsHandler(inspector_, std::move(tcp_))); + } else { + tcp_.reset(); } - } else { - #if DUMP_READS - printf("%s read %ld bytes\n", __FUNCTION__, nread); - if (nread > 0) { - dump_hex(inspector->buffer.data() + inspector->buffer.size() - nread, - nread); - } - #endif - // 2. Parse. - int processed = 0; - do { - processed = parse_ws_frames(inspector); - // 3. Fix the buffer size & length - if (processed > 0) { - remove_from_beginning(&inspector->buffer, processed); + } + + void CancelHandshake() { + const char HANDSHAKE_FAILED_RESPONSE[] = + "HTTP/1.0 400 Bad Request\r\n" + "Content-Type: text/html; charset=UTF-8\r\n\r\n" + "WebSockets request was expected\r\n"; + WriteRaw(std::vector(HANDSHAKE_FAILED_RESPONSE, + HANDSHAKE_FAILED_RESPONSE + sizeof(HANDSHAKE_FAILED_RESPONSE) - 1), + ThenCloseAndReportFailure); + } + + + void OnEof() override { + tcp_.reset(); + } + + void OnData(std::vector* data) override { + http_parser_execute(&parser_, &parser_settings, data->data(), data->size()); + data->clear(); + if (parser_.http_errno != HPE_OK) { + CancelHandshake(); + } + // Event handling may delete *this + std::vector events; + std::swap(events, events_); + for (const HttpEvent& event : events) { + bool shouldContinue = event.isGET && !event.upgrade; + if (!event.isGET) { + CancelHandshake(); + } else if (!event.upgrade) { + delegate()->OnHttpGet(event.path); + } else if (event.ws_key.empty()) { + CancelHandshake(); + } else { + delegate()->OnSocketUpgrade(event.path, event.ws_key); } - } while (processed > 0 && !inspector->buffer.empty()); + if (!shouldContinue) + return; + } } -} -int inspector_read_start(InspectorSocket* inspector, - uv_alloc_cb alloc_cb, uv_read_cb read_cb) { - CHECK(inspector->ws_mode); - CHECK(!inspector->shutting_down || read_cb == nullptr); - inspector->ws_state->close_sent = false; - inspector->ws_state->alloc_cb = alloc_cb; - inspector->ws_state->read_cb = read_cb; - int err = - uv_read_start(reinterpret_cast(&inspector->tcp), - prepare_buffer, - websockets_data_cb); - if (err < 0) { - close_connection(inspector); - } - return err; -} + void Write(const std::vector data) override { + WriteRaw(data, WriteRequest::Cleanup); + } -void inspector_read_stop(InspectorSocket* inspector) { - uv_read_stop(reinterpret_cast(&inspector->tcp)); - inspector->ws_state->alloc_cb = nullptr; - inspector->ws_state->read_cb = nullptr; -} + protected: + void Shutdown() override { + delete this; + } -static void generate_accept_string(const std::string& client_key, - char (*buffer)[ACCEPT_KEY_LENGTH]) { - // Magic string from websockets spec. - static const char ws_magic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - std::string input(client_key + ws_magic); - char hash[SHA_DIGEST_LENGTH]; - SHA1(reinterpret_cast(&input[0]), input.size(), - reinterpret_cast(hash)); - node::base64_encode(hash, sizeof(hash), *buffer, sizeof(*buffer)); -} + private: + static void ThenCloseAndReportFailure(uv_write_t* req, int status) { + ProtocolHandler* handler = WriteRequest::from_write_req(req)->handler; + WriteRequest::Cleanup(req, status); + handler->inspector()->SwitchProtocol(nullptr); + } -static int header_value_cb(http_parser* parser, const char* at, size_t length) { - static const char SEC_WEBSOCKET_KEY_HEADER[] = "Sec-WebSocket-Key"; - auto inspector = static_cast(parser->data); - auto state = inspector->http_parsing_state; - state->parsing_value = true; - if (state->current_header.size() == sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1 && - node::StringEqualNoCaseN(state->current_header.data(), - SEC_WEBSOCKET_KEY_HEADER, - sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1)) { - state->ws_key.append(at, length); - } - return 0; -} + static int OnHeaderValue(http_parser* parser, const char* at, size_t length) { + static const char SEC_WEBSOCKET_KEY_HEADER[] = "Sec-WebSocket-Key"; + HttpHandler* handler = From(parser); + handler->parsing_value_ = true; + if (handler->current_header_.size() == + sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1 && + node::StringEqualNoCaseN(handler->current_header_.data(), + SEC_WEBSOCKET_KEY_HEADER, + sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1)) { + handler->ws_key_.append(at, length); + } + return 0; + } -static int header_field_cb(http_parser* parser, const char* at, size_t length) { - auto inspector = static_cast(parser->data); - auto state = inspector->http_parsing_state; - if (state->parsing_value) { - state->parsing_value = false; - state->current_header.clear(); + static int OnHeaderField(http_parser* parser, const char* at, size_t length) { + HttpHandler* handler = From(parser); + if (handler->parsing_value_) { + handler->parsing_value_ = false; + handler->current_header_.clear(); + } + handler->current_header_.append(at, length); + return 0; } - state->current_header.append(at, length); - return 0; -} -static int path_cb(http_parser* parser, const char* at, size_t length) { - auto inspector = static_cast(parser->data); - auto state = inspector->http_parsing_state; - state->path.append(at, length); - return 0; -} + static int OnPath(http_parser* parser, const char* at, size_t length) { + HttpHandler* handler = From(parser); + handler->path_.append(at, length); + return 0; + } + + static HttpHandler* From(http_parser* parser) { + return node::ContainerOf(&HttpHandler::parser_, parser); + } + + static int OnMessageComplete(http_parser* parser) { + // Event needs to be fired after the parser is done. + HttpHandler* handler = From(parser); + handler->events_.push_back(HttpEvent(handler->path_, parser->upgrade, + parser->method == HTTP_GET, + handler->ws_key_)); + handler->path_ = ""; + handler->ws_key_ = ""; + handler->parsing_value_ = false; + handler->current_header_ = ""; + + return 0; + } + + bool parsing_value_; + http_parser parser_; + http_parser_settings parser_settings; + std::vector events_; + std::string current_header_; + std::string ws_key_; + std::string path_; +}; + +} // namespace -static void handshake_complete(InspectorSocket* inspector) { - uv_read_stop(reinterpret_cast(&inspector->tcp)); - handshake_cb callback = inspector->http_parsing_state->callback; - inspector->ws_state = new ws_state_s(); - inspector->ws_mode = true; - callback(inspector, kInspectorHandshakeUpgraded, - inspector->http_parsing_state->path); +// Any protocol +ProtocolHandler::ProtocolHandler(InspectorSocket* inspector, + TcpHolder::Pointer tcp) + : inspector_(inspector), tcp_(std::move(tcp)) { + CHECK_NE(nullptr, tcp_); + tcp_->SetHandler(this); } -static void cleanup_http_parsing_state(InspectorSocket* inspector) { - delete inspector->http_parsing_state; - inspector->http_parsing_state = nullptr; +int ProtocolHandler::WriteRaw(const std::vector& buffer, + uv_write_cb write_cb) { + return tcp_->WriteRaw(buffer, write_cb); } -static void report_handshake_failure_cb(uv_handle_t* handle) { - dispose_inspector(handle); - InspectorSocket* inspector = inspector_from_stream(handle); - handshake_cb cb = inspector->http_parsing_state->callback; - cleanup_http_parsing_state(inspector); - cb(inspector, kInspectorHandshakeFailed, std::string()); +InspectorSocket::Delegate* ProtocolHandler::delegate() { + return tcp_->delegate(); } -static void close_and_report_handshake_failure(InspectorSocket* inspector) { - uv_handle_t* socket = reinterpret_cast(&inspector->tcp); - if (uv_is_closing(socket)) { - report_handshake_failure_cb(socket); +std::string ProtocolHandler::GetHost() { + char ip[INET6_ADDRSTRLEN]; + sockaddr_storage addr; + int len = sizeof(addr); + int err = uv_tcp_getsockname(tcp_->tcp(), + reinterpret_cast(&addr), + &len); + if (err != 0) + return ""; + if (addr.ss_family == AF_INET6) { + const sockaddr_in6* v6 = reinterpret_cast(&addr); + err = uv_ip6_name(v6, ip, sizeof(ip)); } else { - uv_read_stop(reinterpret_cast(socket)); - uv_close(socket, report_handshake_failure_cb); + const sockaddr_in* v4 = reinterpret_cast(&addr); + err = uv_ip4_name(v4, ip, sizeof(ip)); + } + if (err != 0) + return ""; + return ip; +} + +// RAII uv_tcp_t wrapper +TcpHolder::TcpHolder(InspectorSocket::DelegatePointer delegate) + : tcp_(), + delegate_(std::move(delegate)), + handler_(nullptr) { } + +// static +TcpHolder::Pointer TcpHolder::Accept( + uv_stream_t* server, + InspectorSocket::DelegatePointer delegate) { + TcpHolder* result = new TcpHolder(std::move(delegate)); + uv_stream_t* tcp = reinterpret_cast(&result->tcp_); + int err = uv_tcp_init(server->loop, &result->tcp_); + if (err == 0) { + err = uv_accept(server, tcp); + } + if (err == 0) { + err = uv_read_start(tcp, allocate_buffer, OnDataReceivedCb); + } + if (err == 0) { + return { result, DisconnectAndDispose }; + } else { + fprintf(stderr, "[%s:%d@%s]\n", __FILE__, __LINE__, __FUNCTION__); + + delete result; + return { nullptr, nullptr }; } } -static void then_close_and_report_failure(uv_write_t* req, int status) { - InspectorSocket* inspector = WriteRequest::from_write_req(req)->inspector; - write_request_cleanup(req, status); - close_and_report_handshake_failure(inspector); +void TcpHolder::SetHandler(ProtocolHandler* handler) { + handler_ = handler; } -static void handshake_failed(InspectorSocket* inspector) { - const char HANDSHAKE_FAILED_RESPONSE[] = - "HTTP/1.0 400 Bad Request\r\n" - "Content-Type: text/html; charset=UTF-8\r\n\r\n" - "WebSockets request was expected\r\n"; - write_to_client(inspector, HANDSHAKE_FAILED_RESPONSE, - sizeof(HANDSHAKE_FAILED_RESPONSE) - 1, - then_close_and_report_failure); +int TcpHolder::WriteRaw(const std::vector& buffer, uv_write_cb write_cb) { +#if DUMP_WRITES + printf("%s (%ld bytes):\n", __FUNCTION__, buffer.size()); + dump_hex(buffer.data(), buffer.size()); + printf("\n"); +#endif + + // Freed in write_request_cleanup + WriteRequest* wr = new WriteRequest(handler_, buffer); + uv_stream_t* stream = reinterpret_cast(&tcp_); + int err = uv_write(&wr->req, stream, &wr->buf, 1, write_cb); + if (err < 0) + delete wr; + return err < 0; } -// init_handshake references message_complete_cb -static void init_handshake(InspectorSocket* socket); - -static int message_complete_cb(http_parser* parser) { - InspectorSocket* inspector = static_cast(parser->data); - struct http_parsing_state_s* state = inspector->http_parsing_state; - if (parser->method != HTTP_GET) { - handshake_failed(inspector); - } else if (!parser->upgrade) { - if (state->callback(inspector, kInspectorHandshakeHttpGet, state->path)) { - init_handshake(inspector); - } else { - handshake_failed(inspector); - } - } else if (state->ws_key.empty()) { - handshake_failed(inspector); - } else if (state->callback(inspector, kInspectorHandshakeUpgrading, - state->path)) { - char accept_string[ACCEPT_KEY_LENGTH]; - generate_accept_string(state->ws_key, &accept_string); - const char accept_ws_prefix[] = "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: "; - const char accept_ws_suffix[] = "\r\n\r\n"; - std::string reply(accept_ws_prefix, sizeof(accept_ws_prefix) - 1); - reply.append(accept_string, sizeof(accept_string)); - reply.append(accept_ws_suffix, sizeof(accept_ws_suffix) - 1); - if (write_to_client(inspector, &reply[0], reply.size()) >= 0) { - handshake_complete(inspector); - inspector->http_parsing_state->done = true; - } else { - close_and_report_handshake_failure(inspector); - } - } else { - handshake_failed(inspector); - } - return 0; +InspectorSocket::Delegate* TcpHolder::delegate() { + return delegate_.get(); +} + +// static +void TcpHolder::OnClosed(uv_handle_t* handle) { + delete From(handle); } -static void data_received_cb(uv_stream_t* tcp, ssize_t nread, - const uv_buf_t* buf) { +void TcpHolder::OnDataReceivedCb(uv_stream_t* tcp, ssize_t nread, + const uv_buf_t* buf) { #if DUMP_READS if (nread >= 0) { printf("%s (%ld bytes)\n", __FUNCTION__, nread); @@ -522,107 +667,65 @@ static void data_received_cb(uv_stream_t* tcp, ssize_t nread, printf("[%s:%d] %s\n", __FUNCTION__, __LINE__, uv_err_name(nread)); } #endif - InspectorSocket* inspector = inspector_from_stream(tcp); - reclaim_uv_buf(inspector, buf, nread); + TcpHolder* holder = From(tcp); + holder->ReclaimUvBuf(buf, nread); if (nread < 0 || nread == UV_EOF) { - close_and_report_handshake_failure(inspector); + holder->handler_->OnEof(); } else { - http_parsing_state_s* state = inspector->http_parsing_state; - http_parser* parser = &state->parser; - http_parser_execute(parser, &state->parser_settings, - inspector->buffer.data(), nread); - remove_from_beginning(&inspector->buffer, nread); - if (parser->http_errno != HPE_OK) { - handshake_failed(inspector); - } - if (inspector->http_parsing_state->done) { - cleanup_http_parsing_state(inspector); - } + holder->handler_->OnData(&holder->buffer); } } -static void init_handshake(InspectorSocket* socket) { - http_parsing_state_s* state = socket->http_parsing_state; - CHECK_NE(state, nullptr); - state->current_header.clear(); - state->ws_key.clear(); - state->path.clear(); - state->done = false; - http_parser_init(&state->parser, HTTP_REQUEST); - state->parser.data = socket; - http_parser_settings* settings = &state->parser_settings; - http_parser_settings_init(settings); - settings->on_header_field = header_field_cb; - settings->on_header_value = header_value_cb; - settings->on_message_complete = message_complete_cb; - settings->on_url = path_cb; +// static +void TcpHolder::DisconnectAndDispose(TcpHolder* holder) { + uv_handle_t* handle = reinterpret_cast(&holder->tcp_); + uv_close(handle, OnClosed); } -int inspector_accept(uv_stream_t* server, InspectorSocket* socket, - handshake_cb callback) { - CHECK_NE(callback, nullptr); - CHECK_EQ(socket->http_parsing_state, nullptr); - - socket->http_parsing_state = new http_parsing_state_s(); - uv_stream_t* tcp = reinterpret_cast(&socket->tcp); - int err = uv_tcp_init(server->loop, &socket->tcp); - - if (err == 0) { - err = uv_accept(server, tcp); - } - if (err == 0) { - init_handshake(socket); - socket->http_parsing_state->callback = callback; - err = uv_read_start(tcp, prepare_buffer, - data_received_cb); - } - if (err != 0) { - uv_close(reinterpret_cast(tcp), nullptr); +void TcpHolder::ReclaimUvBuf(const uv_buf_t* buf, ssize_t read) { + if (read > 0) { + buffer.insert(buffer.end(), buf->base, buf->base + read); } - return err; + delete[] buf->base; } -void inspector_write(InspectorSocket* inspector, const char* data, - size_t len) { - if (inspector->ws_mode) { - std::vector output = encode_frame_hybi17(data, len); - write_to_client(inspector, &output[0], output.size()); +// Public interface +InspectorSocket::InspectorSocket() + : protocol_handler_(nullptr, ProtocolHandler::Shutdown) { } + +InspectorSocket::~InspectorSocket() = default; + +// static +InspectorSocket::Pointer InspectorSocket::Accept(uv_stream_t* server, + DelegatePointer delegate) { + auto tcp = TcpHolder::Accept(server, std::move(delegate)); + if (tcp) { + InspectorSocket* inspector = new InspectorSocket(); + inspector->SwitchProtocol(new HttpHandler(inspector, std::move(tcp))); + return InspectorSocket::Pointer(inspector); } else { - write_to_client(inspector, data, len); + return InspectorSocket::Pointer(nullptr); } } -void inspector_close(InspectorSocket* inspector, - inspector_cb callback) { - // libuv throws assertions when closing stream that's already closed - we - // need to do the same. - CHECK(!uv_is_closing(reinterpret_cast(&inspector->tcp))); - CHECK(!inspector->shutting_down); - inspector->shutting_down = true; - inspector->ws_state->close_cb = callback; - if (inspector->connection_eof) { - close_connection(inspector); - } else { - inspector_read_stop(inspector); - write_to_client(inspector, CLOSE_FRAME, sizeof(CLOSE_FRAME), - on_close_frame_written); - inspector_read_start(inspector, nullptr, nullptr); - } +void InspectorSocket::AcceptUpgrade(const std::string& ws_key) { + protocol_handler_->AcceptUpgrade(ws_key); +} + +void InspectorSocket::CancelHandshake() { + protocol_handler_->CancelHandshake(); +} + +std::string InspectorSocket::GetHost() { + return protocol_handler_->GetHost(); } -bool inspector_is_active(const InspectorSocket* inspector) { - const uv_handle_t* tcp = - reinterpret_cast(&inspector->tcp); - return !inspector->shutting_down && !uv_is_closing(tcp); +void InspectorSocket::SwitchProtocol(ProtocolHandler* handler) { + protocol_handler_.reset(std::move(handler)); } -void InspectorSocket::reinit() { - http_parsing_state = nullptr; - ws_state = nullptr; - buffer.clear(); - ws_mode = false; - shutting_down = false; - connection_eof = false; +void InspectorSocket::Write(const char* data, size_t len) { + protocol_handler_->Write(std::vector(data, data + len)); } } // namespace inspector diff --git a/src/inspector_socket.h b/src/inspector_socket.h index f93150d6f9a1cf..1a3411435ee2c5 100644 --- a/src/inspector_socket.h +++ b/src/inspector_socket.h @@ -1,7 +1,6 @@ #ifndef SRC_INSPECTOR_SOCKET_H_ #define SRC_INSPECTOR_SOCKET_H_ -#include "http_parser.h" #include "util-inl.h" #include "uv.h" @@ -11,88 +10,41 @@ namespace node { namespace inspector { -enum inspector_handshake_event { - kInspectorHandshakeUpgrading, - kInspectorHandshakeUpgraded, - kInspectorHandshakeHttpGet, - kInspectorHandshakeFailed -}; - -class InspectorSocket; - -typedef void (*inspector_cb)(InspectorSocket*, int); -// Notifies as handshake is progressing. Returning false as a response to -// kInspectorHandshakeUpgrading or kInspectorHandshakeHttpGet event will abort -// the connection. inspector_write can be used from the callback. -typedef bool (*handshake_cb)(InspectorSocket*, - enum inspector_handshake_event state, - const std::string& path); - -struct http_parsing_state_s { - http_parser parser; - http_parser_settings parser_settings; - handshake_cb callback; - bool done; - bool parsing_value; - std::string ws_key; - std::string path; - std::string current_header; -}; - -struct ws_state_s { - uv_alloc_cb alloc_cb; - uv_read_cb read_cb; - inspector_cb close_cb; - bool close_sent; - bool received_close; -}; +class ProtocolHandler; // HTTP Wrapper around a uv_tcp_t class InspectorSocket { public: - InspectorSocket() : data(nullptr), http_parsing_state(nullptr), - ws_state(nullptr), buffer(0), ws_mode(false), - shutting_down(false), connection_eof(false) { } - void reinit(); - void* data; - struct http_parsing_state_s* http_parsing_state; - struct ws_state_s* ws_state; - std::vector buffer; - uv_tcp_t tcp; - bool ws_mode; - bool shutting_down; - bool connection_eof; + class Delegate { + public: + virtual void OnHttpGet(const std::string& path) = 0; + virtual void OnSocketUpgrade(const std::string& path, + const std::string& accept_key) = 0; + virtual void OnWsFrame(const std::vector& frame) = 0; + virtual ~Delegate() {} + }; - private: - DISALLOW_COPY_AND_ASSIGN(InspectorSocket); -}; + using DelegatePointer = std::unique_ptr; + using Pointer = std::unique_ptr; + + static Pointer Accept(uv_stream_t* server, DelegatePointer delegate); -int inspector_accept(uv_stream_t* server, InspectorSocket* inspector, - handshake_cb callback); + ~InspectorSocket(); -void inspector_close(InspectorSocket* inspector, - inspector_cb callback); + void AcceptUpgrade(const std::string& accept_key); + void CancelHandshake(); + void Write(const char* data, size_t len); + void SwitchProtocol(ProtocolHandler* handler); + std::string GetHost(); -// Callbacks will receive stream handles. Use inspector_from_stream to get -// InspectorSocket* from the stream handle. -int inspector_read_start(InspectorSocket* inspector, uv_alloc_cb, - uv_read_cb); -void inspector_read_stop(InspectorSocket* inspector); -void inspector_write(InspectorSocket* inspector, - const char* data, size_t len); -bool inspector_is_active(const InspectorSocket* inspector); + private: + InspectorSocket(); -inline InspectorSocket* inspector_from_stream(uv_tcp_t* stream) { - return node::ContainerOf(&InspectorSocket::tcp, stream); -} + std::unique_ptr protocol_handler_; -inline InspectorSocket* inspector_from_stream(uv_stream_t* stream) { - return inspector_from_stream(reinterpret_cast(stream)); -} + DISALLOW_COPY_AND_ASSIGN(InspectorSocket); +}; -inline InspectorSocket* inspector_from_stream(uv_handle_t* stream) { - return inspector_from_stream(reinterpret_cast(stream)); -} } // namespace inspector } // namespace node diff --git a/src/inspector_socket_server.cc b/src/inspector_socket_server.cc index 958c41a654adff..bf114251cb134b 100644 --- a/src/inspector_socket_server.cc +++ b/src/inspector_socket_server.cc @@ -33,7 +33,6 @@ std::string FormatWsAddress(const std::string& host, int port, return url.str(); } - namespace { static const uint8_t PROTOCOL_JSON[] = { @@ -114,8 +113,8 @@ void SendHttpResponse(InspectorSocket* socket, const std::string& response) { "\r\n"; char header[sizeof(HEADERS) + 20]; int header_len = snprintf(header, sizeof(header), HEADERS, response.size()); - inspector_write(socket, header, header_len); - inspector_write(socket, response.data(), response.size()); + socket->Write(header, header_len); + socket->Write(response.data(), response.size()); } void SendVersionResponse(InspectorSocket* socket) { @@ -145,28 +144,6 @@ void SendProtocolJson(InspectorSocket* socket) { CHECK_EQ(Z_OK, inflateEnd(&strm)); SendHttpResponse(socket, data); } - -int GetSocketHost(uv_tcp_t* socket, std::string* out_host) { - char ip[INET6_ADDRSTRLEN]; - sockaddr_storage addr; - int len = sizeof(addr); - int err = uv_tcp_getsockname(socket, - reinterpret_cast(&addr), - &len); - if (err != 0) - return err; - if (addr.ss_family == AF_INET6) { - const sockaddr_in6* v6 = reinterpret_cast(&addr); - err = uv_ip6_name(v6, ip, sizeof(ip)); - } else { - const sockaddr_in* v4 = reinterpret_cast(&addr); - err = uv_ip4_name(v4, ip, sizeof(ip)); - } - if (err != 0) - return err; - *out_host = ip; - return err; -} } // namespace @@ -209,46 +186,58 @@ class Closer { class SocketSession { public: - static int Accept(InspectorSocketServer* server, int server_port, - uv_stream_t* server_socket); + SocketSession(InspectorSocketServer* server, int id, int server_port); + void Close() { + ws_socket_.reset(); + } void Send(const std::string& message); - void Close(); - + void Own(InspectorSocket::Pointer ws_socket) { + ws_socket_ = std::move(ws_socket); + } int id() const { return id_; } - bool IsForTarget(const std::string& target_id) const { - return target_id_ == target_id; + int server_port() { + return server_port_; } - static int ServerPortForClient(InspectorSocket* client) { - return From(client)->server_port_; + InspectorSocket* ws_socket() { + return ws_socket_.get(); } - - private: - SocketSession(InspectorSocketServer* server, int server_port); - static SocketSession* From(InspectorSocket* socket) { - return node::ContainerOf(&SocketSession::socket_, socket); + void set_ws_key(const std::string& ws_key) { + ws_key_ = ws_key; + } + void Accept() { + ws_socket_->AcceptUpgrade(ws_key_); + } + void Decline() { + ws_socket_->CancelHandshake(); } - enum class State { kHttp, kWebSocket, kClosing, kEOF, kDeclined }; - static bool HandshakeCallback(InspectorSocket* socket, - enum inspector_handshake_event state, - const std::string& path); - static void ReadCallback(uv_stream_t* stream, ssize_t read, - const uv_buf_t* buf); - static void CloseCallback(InspectorSocket* socket, int code); + class Delegate : public InspectorSocket::Delegate { + public: + Delegate(InspectorSocketServer* server, int session_id) + : server_(server), session_id_(session_id) { } + ~Delegate() { + server_->SessionTerminated(session_id_); + } + void OnHttpGet(const std::string& path) override; + void OnSocketUpgrade(const std::string& path, + const std::string& ws_key) override; + void OnWsFrame(const std::vector& data) override; + + private: + SocketSession* Session() { + return server_->Session(session_id_); + } - void FrontendConnected(); - void SetDeclined() { state_ = State::kDeclined; } - void SetTargetId(const std::string& target_id) { - CHECK(target_id_.empty()); - target_id_ = target_id; - } + InspectorSocketServer* server_; + int session_id_; + }; + private: const int id_; - InspectorSocket socket_; + InspectorSocket::Pointer ws_socket_; InspectorSocketServer* server_; - std::string target_id_; - State state_; const int server_port_; + std::string ws_key_; }; class ServerSocket { @@ -269,7 +258,6 @@ class ServerSocket { return node::ContainerOf(&ServerSocket::tcp_socket_, reinterpret_cast(socket)); } - static void SocketConnectedCallback(uv_stream_t* tcp_socket, int status); static void SocketClosedCallback(uv_handle_t* tcp_socket); static void FreeOnCloseCallback(uv_handle_t* tcp_socket_) { @@ -296,41 +284,57 @@ InspectorSocketServer::InspectorSocketServer(SocketServerDelegate* delegate, state_ = ServerState::kNew; } -bool InspectorSocketServer::SessionStarted(SocketSession* session, - const std::string& id) { - if (TargetExists(id) && delegate_->StartSession(session->id(), id)) { - connected_sessions_[session->id()] = session; - return true; - } else { - return false; +InspectorSocketServer::~InspectorSocketServer() = default; + +SocketSession* InspectorSocketServer::Session(int session_id) { + auto it = connected_sessions_.find(session_id); + return it == connected_sessions_.end() ? nullptr : it->second.second.get(); +} + +void InspectorSocketServer::SessionStarted(int session_id, + const std::string& id, + const std::string& ws_key) { + SocketSession* session = Session(session_id); + if (!TargetExists(id)) { + Session(session_id)->Decline(); + return; } + connected_sessions_[session_id].first = id; + session->set_ws_key(ws_key); + delegate_->StartSession(session_id, id); } -void InspectorSocketServer::SessionTerminated(SocketSession* session) { - int id = session->id(); - if (connected_sessions_.erase(id) != 0) { - delegate_->EndSession(id); - if (connected_sessions_.empty()) { - if (state_ == ServerState::kRunning && !server_sockets_.empty()) { - PrintDebuggerReadyMessage(host_, server_sockets_[0]->port(), - delegate_->GetTargetIds(), out_); - } - if (state_ == ServerState::kStopped) { - delegate_->ServerDone(); - } +void InspectorSocketServer::SessionTerminated(int session_id) { + if (Session(session_id) == nullptr) { + return; + } + bool was_attached = connected_sessions_[session_id].first != ""; + if (was_attached) { + delegate_->EndSession(session_id); + } + connected_sessions_.erase(session_id); + if (connected_sessions_.empty()) { + if (was_attached && state_ == ServerState::kRunning + && !server_sockets_.empty()) { + PrintDebuggerReadyMessage(host_, server_sockets_[0]->port(), + delegate_->GetTargetIds(), out_); + } + if (state_ == ServerState::kStopped) { + delegate_->ServerDone(); } } - delete session; } -bool InspectorSocketServer::HandleGetRequest(InspectorSocket* socket, +bool InspectorSocketServer::HandleGetRequest(int session_id, const std::string& path) { + SocketSession* session = Session(session_id); + InspectorSocket* socket = session->ws_socket(); const char* command = MatchPathSegment(path.c_str(), "/json"); if (command == nullptr) return false; if (MatchPathSegment(command, "list") || command[0] == '\0') { - SendListResponse(socket); + SendListResponse(socket, session); return true; } else if (MatchPathSegment(command, "protocol")) { SendProtocolJson(socket); @@ -348,7 +352,8 @@ bool InspectorSocketServer::HandleGetRequest(InspectorSocket* socket, return false; } -void InspectorSocketServer::SendListResponse(InspectorSocket* socket) { +void InspectorSocketServer::SendListResponse(InspectorSocket* socket, + SocketSession* session) { std::vector> response; for (const std::string& id : delegate_->GetTargetIds()) { response.push_back(std::map()); @@ -366,15 +371,14 @@ void InspectorSocketServer::SendListResponse(InspectorSocket* socket) { bool connected = false; for (const auto& session : connected_sessions_) { - if (session.second->IsForTarget(id)) { + if (session.second.first == id) { connected = true; break; } } if (!connected) { - std::string host; - int port = SocketSession::ServerPortForClient(socket); - GetSocketHost(&socket->tcp, &host); + std::string host = socket->GetHost(); + int port = session->server_port(); std::ostringstream frontend_url; frontend_url << "chrome-devtools://devtools/bundled"; frontend_url << "/inspector.html?experiments=true&v8only=true&ws="; @@ -444,9 +448,8 @@ void InspectorSocketServer::Stop(ServerCallback cb) { } void InspectorSocketServer::TerminateConnections() { - for (const auto& session : connected_sessions_) { - session.second->Close(); - } + for (const auto& key_value : connected_sessions_) + key_value.second.second->Close(); } bool InspectorSocketServer::TargetExists(const std::string& id) { @@ -455,13 +458,6 @@ bool InspectorSocketServer::TargetExists(const std::string& id) { return found != target_ids.end(); } -void InspectorSocketServer::Send(int session_id, const std::string& message) { - auto session_iterator = connected_sessions_.find(session_id); - if (session_iterator != connected_sessions_.end()) { - session_iterator->second->Send(message); - } -} - void InspectorSocketServer::ServerSocketListening(ServerSocket* server_socket) { server_sockets_.push_back(server_socket); } @@ -491,92 +487,73 @@ int InspectorSocketServer::Port() const { return port_; } -// InspectorSession tracking -SocketSession::SocketSession(InspectorSocketServer* server, int server_port) - : id_(server->GenerateSessionId()), - server_(server), - state_(State::kHttp), - server_port_(server_port) { } +void InspectorSocketServer::Accept(int server_port, + uv_stream_t* server_socket) { + std::unique_ptr session( + new SocketSession(this, next_session_id_++, server_port)); + + InspectorSocket::DelegatePointer delegate = + InspectorSocket::DelegatePointer( + new SocketSession::Delegate(this, session->id())); -void SocketSession::Close() { - CHECK_NE(state_, State::kClosing); - state_ = State::kClosing; - inspector_close(&socket_, CloseCallback); + InspectorSocket::Pointer inspector = + InspectorSocket::Accept(server_socket, std::move(delegate)); + if (inspector) { + session->Own(std::move(inspector)); + connected_sessions_[session->id()].second = std::move(session); + } } -// static -int SocketSession::Accept(InspectorSocketServer* server, int server_port, - uv_stream_t* server_socket) { - // Memory is freed when the socket closes. - SocketSession* session = new SocketSession(server, server_port); - int err = inspector_accept(server_socket, &session->socket_, - HandshakeCallback); - if (err != 0) { - delete session; +void InspectorSocketServer::AcceptSession(int session_id) { + SocketSession* session = Session(session_id); + if (session == nullptr) { + delegate_->EndSession(session_id); + } else { + session->Accept(); } - return err; } -// static -bool SocketSession::HandshakeCallback(InspectorSocket* socket, - inspector_handshake_event event, - const std::string& path) { - SocketSession* session = SocketSession::From(socket); - InspectorSocketServer* server = session->server_; - const std::string& id = path.empty() ? path : path.substr(1); - switch (event) { - case kInspectorHandshakeHttpGet: - return server->HandleGetRequest(socket, path); - case kInspectorHandshakeUpgrading: - if (server->SessionStarted(session, id)) { - session->SetTargetId(id); - return true; - } else { - session->SetDeclined(); - return false; - } - case kInspectorHandshakeUpgraded: - session->FrontendConnected(); - return true; - case kInspectorHandshakeFailed: - server->SessionTerminated(session); - return false; - default: - UNREACHABLE(); - return false; +void InspectorSocketServer::DeclineSession(int session_id) { + auto it = connected_sessions_.find(session_id); + if (it != connected_sessions_.end()) { + it->second.first.clear(); + it->second.second->Decline(); } } -// static -void SocketSession::CloseCallback(InspectorSocket* socket, int code) { - SocketSession* session = SocketSession::From(socket); - CHECK_EQ(State::kClosing, session->state_); - session->server_->SessionTerminated(session); +void InspectorSocketServer::Send(int session_id, const std::string& message) { + SocketSession* session = Session(session_id); + if (session != nullptr) { + session->Send(message); + } +} + +// InspectorSession tracking +SocketSession::SocketSession(InspectorSocketServer* server, int id, + int server_port) + : id_(id), + server_(server), + server_port_(server_port) { } + + +void SocketSession::Send(const std::string& message) { + ws_socket_->Write(message.data(), message.length()); } -void SocketSession::FrontendConnected() { - CHECK_EQ(State::kHttp, state_); - state_ = State::kWebSocket; - inspector_read_start(&socket_, OnBufferAlloc, ReadCallback); +void SocketSession::Delegate::OnHttpGet(const std::string& path) { + if (!server_->HandleGetRequest(session_id_, path)) + Session()->ws_socket()->CancelHandshake(); } -// static -void SocketSession::ReadCallback(uv_stream_t* stream, ssize_t read, - const uv_buf_t* buf) { - InspectorSocket* socket = inspector_from_stream(stream); - SocketSession* session = SocketSession::From(socket); - if (read > 0) { - session->server_->MessageReceived(session->id_, - std::string(buf->base, read)); - } else { - session->Close(); - } - if (buf != nullptr && buf->base != nullptr) - delete[] buf->base; +void SocketSession::Delegate::OnSocketUpgrade(const std::string& path, + const std::string& ws_key) { + std::string id = path.empty() ? path : path.substr(1); + server_->SessionStarted(session_id_, id, ws_key); } -void SocketSession::Send(const std::string& message) { - inspector_write(&socket_, message.data(), message.length()); +void SocketSession::Delegate::OnWsFrame(const std::vector& data) { + server_->MessageReceived(session_id_, + std::string(data.data(), data.size())); } // ServerSocket implementation @@ -624,8 +601,7 @@ void ServerSocket::SocketConnectedCallback(uv_stream_t* tcp_socket, if (status == 0) { ServerSocket* server_socket = ServerSocket::FromTcpSocket(tcp_socket); // Memory is freed when the socket closes. - SocketSession::Accept(server_socket->server_, server_socket->port_, - tcp_socket); + server_socket->server_->Accept(server_socket->port_, tcp_socket); } } diff --git a/src/inspector_socket_server.h b/src/inspector_socket_server.h index 16b047da333f68..b193e33a46d6d3 100644 --- a/src/inspector_socket_server.h +++ b/src/inspector_socket_server.h @@ -22,7 +22,7 @@ class ServerSocket; class SocketServerDelegate { public: - virtual bool StartSession(int session_id, const std::string& target_id) = 0; + virtual void StartSession(int session_id, const std::string& target_id) = 0; virtual void EndSession(int session_id) = 0; virtual void MessageReceived(int session_id, const std::string& message) = 0; virtual std::vector GetTargetIds() = 0; @@ -34,8 +34,6 @@ class SocketServerDelegate { // HTTP Server, writes messages requested as TransportActions, and responds // to HTTP requests and WS upgrades. - - class InspectorSocketServer { public: using ServerCallback = void (*)(InspectorSocketServer*); @@ -44,6 +42,8 @@ class InspectorSocketServer { const std::string& host, int port, FILE* out = stderr); + ~InspectorSocketServer(); + // Start listening on host/port bool Start(); @@ -54,6 +54,10 @@ class InspectorSocketServer { void Send(int session_id, const std::string& message); // kKill void TerminateConnections(); + // kAcceptSession + void AcceptSession(int session_id); + // kDeclineSession + void DeclineSession(int session_id); int Port() const; @@ -62,19 +66,18 @@ class InspectorSocketServer { void ServerSocketClosed(ServerSocket* server_socket); // Session connection lifecycle - bool HandleGetRequest(InspectorSocket* socket, const std::string& path); - bool SessionStarted(SocketSession* session, const std::string& id); - void SessionTerminated(SocketSession* session); + void Accept(int server_port, uv_stream_t* server_socket); + bool HandleGetRequest(int session_id, const std::string& path); + void SessionStarted(int session_id, const std::string& target_id, + const std::string& ws_id); + void SessionTerminated(int session_id); void MessageReceived(int session_id, const std::string& message) { delegate_->MessageReceived(session_id, message); } - - int GenerateSessionId() { - return next_session_id_++; - } + SocketSession* Session(int session_id); private: - void SendListResponse(InspectorSocket* socket); + void SendListResponse(InspectorSocket* socket, SocketSession* session); bool TargetExists(const std::string& id); enum class ServerState {kNew, kRunning, kStopping, kStopped}; @@ -85,7 +88,8 @@ class InspectorSocketServer { std::string path_; std::vector server_sockets_; Closer* closer_; - std::map connected_sessions_; + std::map>> + connected_sessions_; int next_session_id_; FILE* out_; ServerState state_; diff --git a/src/node.cc b/src/node.cc index f9b67dd675dc13..fed5fe8681db5b 100644 --- a/src/node.cc +++ b/src/node.cc @@ -2341,7 +2341,7 @@ static void InitGroups(const FunctionCallbackInfo& args) { static void WaitForInspectorDisconnect(Environment* env) { #if HAVE_INSPECTOR - if (env->inspector_agent()->IsConnected()) { + if (env->inspector_agent()->delegate() != nullptr) { // Restore signal dispositions, the app is done and is no longer // capable of handling signals. #if defined(__POSIX__) && !defined(NODE_SHARED_MODE) diff --git a/test/cctest/test_inspector_socket.cc b/test/cctest/test_inspector_socket.cc index d1f5b4a98ac37a..943109b8a594d2 100644 --- a/test/cctest/test_inspector_socket.cc +++ b/test/cctest/test_inspector_socket.cc @@ -1,57 +1,17 @@ #include "inspector_socket.h" #include "gtest/gtest.h" +#include + #define PORT 9444 namespace { using node::inspector::InspectorSocket; -using node::inspector::inspector_from_stream; -using node::inspector::inspector_handshake_event; -using node::inspector::kInspectorHandshakeFailed; -using node::inspector::kInspectorHandshakeHttpGet; -using node::inspector::kInspectorHandshakeUpgraded; -using node::inspector::kInspectorHandshakeUpgrading; static const int MAX_LOOP_ITERATIONS = 10000; -#define SPIN_WHILE(condition) \ - { \ - Timeout timeout(&loop); \ - while ((condition) && !timeout.timed_out) { \ - uv_run(&loop, UV_RUN_NOWAIT); \ - } \ - ASSERT_FALSE((condition)); \ - } - -static bool connected = false; -static bool inspector_ready = false; -static int handshake_events = 0; -static enum inspector_handshake_event last_event = kInspectorHandshakeHttpGet; static uv_loop_t loop; -static uv_tcp_t server, client_socket; -static InspectorSocket inspector; -static std::string last_path; // NOLINT(runtime/string) -static void (*handshake_delegate)(enum inspector_handshake_event state, - const std::string& path, - bool* should_continue); -static const char SERVER_CLOSE_FRAME[] = {'\x88', '\x00'}; - - -struct read_expects { - const char* expected; - size_t expected_len; - size_t pos; - bool read_expected; - bool callback_called; -}; - -static const char HANDSHAKE_REQ[] = "GET /ws/path HTTP/1.1\r\n" - "Host: localhost:9222\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Key: aaa==\r\n" - "Sec-WebSocket-Version: 13\r\n\r\n"; class Timeout { public: @@ -86,16 +46,176 @@ class Timeout { uv_timer_t timer_; }; -static void stop_if_stop_path(enum inspector_handshake_event state, - const std::string& path, bool* cont) { - *cont = path.empty() || path != "/close"; +#define SPIN_WHILE(condition) \ + { \ + Timeout timeout(&loop); \ + while ((condition) && !timeout.timed_out) { \ + uv_run(&loop, UV_RUN_NOWAIT); \ + } \ + ASSERT_FALSE((condition)); \ + } + +enum inspector_handshake_event { + kInspectorHandshakeHttpGet, + kInspectorHandshakeUpgraded, + kInspectorHandshakeNoEvents +}; + +struct expectations { + std::string actual_data; + size_t actual_offset; + size_t actual_end; + int err_code; +}; + +static bool waiting_to_close = true; + +void handle_closed(uv_handle_t* handle) { + waiting_to_close = false; } -static bool connected_cb(InspectorSocket* socket, - enum inspector_handshake_event state, - const std::string& path) { - inspector_ready = state == kInspectorHandshakeUpgraded; - last_event = state; +static void really_close(uv_handle_t* handle) { + waiting_to_close = true; + if (!uv_is_closing(handle)) { + uv_close(handle, handle_closed); + SPIN_WHILE(waiting_to_close); + } +} + +static void buffer_alloc_cb(uv_handle_t* stream, size_t len, uv_buf_t* buf) { + buf->base = new char[len]; + buf->len = len; +} + +class TestInspectorDelegate; + +static TestInspectorDelegate* delegate = nullptr; + +// Gtest asserts can't be used in dtor directly. +static void assert_is_delegate(TestInspectorDelegate* d) { + GTEST_ASSERT_EQ(delegate, d); +} + +class TestInspectorDelegate : public InspectorSocket::Delegate { + public: + using delegate_fn = void(*)(inspector_handshake_event, const std::string&, + bool* should_continue); + + TestInspectorDelegate() : inspector_ready(false), + last_event(kInspectorHandshakeNoEvents), + handshake_events(0), + handshake_delegate_(stop_if_stop_path), + fail_on_ws_frame_(false) { } + + ~TestInspectorDelegate() { + assert_is_delegate(this); + delegate = nullptr; + } + + void OnHttpGet(const std::string& path) override { + process(kInspectorHandshakeHttpGet, path); + } + + void OnSocketUpgrade(const std::string& path, + const std::string& ws_key) override { + ws_key_ = ws_key; + process(kInspectorHandshakeUpgraded, path); + } + + void OnWsFrame(const std::vector& buffer) override { + ASSERT_FALSE(fail_on_ws_frame_); + frames.push(buffer); + } + + void SetDelegate(delegate_fn d) { + handshake_delegate_ = d; + } + + void SetInspector(InspectorSocket::Pointer inspector) { + socket_ = std::move(inspector); + } + + void Write(const char* buf, size_t len) { + socket_->Write(buf, len); + } + + void ExpectReadError() { + SPIN_WHILE(frames.empty() || !frames.back().empty()); + } + + void ExpectData(const char* data, size_t len) { + const char* cur = data; + const char* end = data + len; + while (cur < end) { + SPIN_WHILE(frames.empty()); + const std::vector& frame = frames.front(); + EXPECT_FALSE(frame.empty()); + auto c = frame.begin(); + for (; c < frame.end() && cur < end; c++) { + GTEST_ASSERT_EQ(*cur, *c) << "Character #" << cur - data; + cur = cur + 1; + } + EXPECT_EQ(c, frame.end()); + frames.pop(); + } + } + + void FailOnWsFrame() { + fail_on_ws_frame_ = true; + } + + void WaitForDispose() { + SPIN_WHILE(delegate != nullptr); + } + + void Close() { + socket_.reset(); + } + + bool inspector_ready; + std::string last_path; // NOLINT(runtime/string) + inspector_handshake_event last_event; + int handshake_events; + std::queue> frames; + + private: + static void stop_if_stop_path(enum inspector_handshake_event state, + const std::string& path, bool* cont) { + *cont = path.empty() || path != "/close"; + } + + void process(inspector_handshake_event event, const std::string& path); + + bool disposed_ = false; + delegate_fn handshake_delegate_; + InspectorSocket::Pointer socket_; + std::string ws_key_; + bool fail_on_ws_frame_; +}; + +static bool connected = false; +static uv_tcp_t server, client_socket; +static const char SERVER_CLOSE_FRAME[] = {'\x88', '\x00'}; + +struct read_expects { + const char* expected; + size_t expected_len; + size_t pos; + bool read_expected; + bool callback_called; +}; + +static const char HANDSHAKE_REQ[] = "GET /ws/path HTTP/1.1\r\n" + "Host: localhost:9222\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: aaa==\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n"; + +void TestInspectorDelegate::process(inspector_handshake_event event, + const std::string& path) { + inspector_ready = event == kInspectorHandshakeUpgraded; + last_event = event; if (path.empty()) { last_path = "@@@ Nothing received @@@"; } else { @@ -103,15 +223,23 @@ static bool connected_cb(InspectorSocket* socket, } handshake_events++; bool should_continue = true; - handshake_delegate(state, path, &should_continue); - return should_continue; + handshake_delegate_(event, path, &should_continue); + if (should_continue) { + if (inspector_ready) + socket_->AcceptUpgrade(ws_key_); + } else { + socket_->CancelHandshake(); + } } static void on_new_connection(uv_stream_t* server, int status) { GTEST_ASSERT_EQ(0, status); connected = true; - inspector_accept(server, static_cast(server->data), - connected_cb); + delegate = new TestInspectorDelegate(); + delegate->SetInspector( + InspectorSocket::Accept(server, + InspectorSocket::DelegatePointer(delegate))); + GTEST_ASSERT_NE(nullptr, delegate); } void write_done(uv_write_t* req, int status) { req->data = nullptr; } @@ -129,11 +257,6 @@ static void do_write(const char* data, int len) { SPIN_WHILE(req.data); } -static void buffer_alloc_cb(uv_handle_t* stream, size_t len, uv_buf_t* buf) { - buf->base = new char[len]; - buf->len = len; -} - static void check_data_cb(read_expects* expectation, ssize_t nread, const uv_buf_t* buf, bool* retval) { *retval = false; @@ -207,102 +330,6 @@ static void expect_on_client(const char* data, size_t len) { SPIN_WHILE(!expectation.read_expected); } -struct expectations { - std::string actual_data; - size_t actual_offset; - size_t actual_end; - int err_code; -}; - -static void grow_expects_buffer(uv_handle_t* stream, size_t size, uv_buf_t* b) { - expectations* expects = static_cast( - inspector_from_stream(stream)->data); - size_t end = expects->actual_end; - // Grow the buffer in chunks of 64k. - size_t new_length = (end + size + 65535) & ~((size_t) 0xFFFF); - expects->actual_data.resize(new_length); - *b = uv_buf_init(&expects->actual_data[end], new_length - end); -} - -// static void dump_hex(const char* buf, size_t len) { -// const char* ptr = buf; -// const char* end = ptr + len; -// const char* cptr; -// char c; -// int i; - -// while (ptr < end) { -// cptr = ptr; -// for (i = 0; i < 16 && ptr < end; i++) { -// printf("%2.2X ", *(ptr++)); -// } -// for (i = 72 - (i * 4); i > 0; i--) { -// printf(" "); -// } -// for (i = 0; i < 16 && cptr < end; i++) { -// c = *(cptr++); -// printf("%c", (c > 0x19) ? c : '.'); -// } -// printf("\n"); -// } -// printf("\n\n"); -// } - -static void save_read_data(uv_stream_t* stream, ssize_t nread, - const uv_buf_t* buf) { - expectations* expects = static_cast( - inspector_from_stream(stream)->data); - expects->err_code = nread < 0 ? nread : 0; - if (nread > 0) { - expects->actual_end += nread; - } -} - -static void setup_inspector_expecting() { - if (inspector.data) { - return; - } - expectations* expects = new expectations(); - inspector.data = expects; - inspector_read_start(&inspector, grow_expects_buffer, save_read_data); -} - -static void expect_on_server(const char* data, size_t len) { - setup_inspector_expecting(); - expectations* expects = static_cast(inspector.data); - for (size_t i = 0; i < len;) { - SPIN_WHILE(expects->actual_offset == expects->actual_end); - for (; i < len && expects->actual_offset < expects->actual_end; i++) { - char actual = expects->actual_data[expects->actual_offset++]; - char expected = data[i]; - if (expected != actual) { - fprintf(stderr, "Character %zu:\n", i); - GTEST_ASSERT_EQ(expected, actual); - } - } - } - expects->actual_end -= expects->actual_offset; - if (!expects->actual_end) { - memmove(&expects->actual_data[0], - &expects->actual_data[expects->actual_offset], - expects->actual_end); - } - expects->actual_offset = 0; -} - -static void inspector_record_error_code(uv_stream_t* stream, ssize_t nread, - const uv_buf_t* buf) { - InspectorSocket *inspector = inspector_from_stream(stream); - // Increment instead of assign is to ensure the function is only called once - *(static_cast(inspector->data)) += nread; -} - -static void expect_server_read_error() { - setup_inspector_expecting(); - expectations* expects = static_cast(inspector.data); - SPIN_WHILE(expects->err_code != UV_EPROTO); -} - static void expect_handshake() { const char UPGRADE_RESPONSE[] = "HTTP/1.1 101 Switching Protocols\r\n" @@ -320,35 +347,6 @@ static void expect_handshake_failure() { expect_on_client(UPGRADE_RESPONSE, sizeof(UPGRADE_RESPONSE) - 1); } -static bool waiting_to_close = true; - -void handle_closed(uv_handle_t* handle) { waiting_to_close = false; } - -static void really_close(uv_handle_t* handle) { - waiting_to_close = true; - if (!uv_is_closing(handle)) { - uv_close(handle, handle_closed); - SPIN_WHILE(waiting_to_close); - } -} - -// Called when the test leaves inspector socket in active state -static void manual_inspector_socket_cleanup() { - EXPECT_EQ(0, uv_is_active( - reinterpret_cast(&inspector.tcp))); - really_close(reinterpret_cast(&inspector.tcp)); - delete inspector.ws_state; - inspector.ws_state = nullptr; - delete inspector.http_parsing_state; - inspector.http_parsing_state = nullptr; - inspector.buffer.clear(); -} - -static void assert_both_sockets_closed() { - SPIN_WHILE(uv_is_active(reinterpret_cast(&client_socket))); - SPIN_WHILE(uv_is_active(reinterpret_cast(&inspector.tcp))); -} - static void on_connection(uv_connect_t* connect, int status) { GTEST_ASSERT_EQ(0, status); connect->data = connect; @@ -357,16 +355,10 @@ static void on_connection(uv_connect_t* connect, int status) { class InspectorSocketTest : public ::testing::Test { protected: virtual void SetUp() { - inspector.reinit(); - handshake_delegate = stop_if_stop_path; - handshake_events = 0; connected = false; - inspector_ready = false; - last_event = kInspectorHandshakeHttpGet; GTEST_ASSERT_EQ(0, uv_loop_init(&loop)); server = uv_tcp_t(); client_socket = uv_tcp_t(); - server.data = &inspector; sockaddr_in addr; uv_tcp_init(&loop, &server); uv_tcp_init(&loop, &client_socket); @@ -386,13 +378,7 @@ class InspectorSocketTest : public ::testing::Test { virtual void TearDown() { really_close(reinterpret_cast(&client_socket)); - EXPECT_TRUE(inspector.buffer.empty()); - expectations* expects = static_cast(inspector.data); - if (expects != nullptr) { - GTEST_ASSERT_EQ(expects->actual_end, expects->actual_offset); - delete expects; - inspector.data = nullptr; - } + SPIN_WHILE(delegate != nullptr); const int err = uv_loop_close(&loop); if (err != 0) { uv_print_all_handles(&loop, stderr); @@ -403,22 +389,22 @@ class InspectorSocketTest : public ::testing::Test { TEST_F(InspectorSocketTest, ReadsAndWritesInspectorMessage) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(!inspector_ready); + SPIN_WHILE(!delegate->inspector_ready); expect_handshake(); // 2. Brief exchange const char SERVER_MESSAGE[] = "abcd"; const char CLIENT_FRAME[] = {'\x81', '\x04', 'a', 'b', 'c', 'd'}; - inspector_write(&inspector, SERVER_MESSAGE, sizeof(SERVER_MESSAGE) - 1); + delegate->Write(SERVER_MESSAGE, sizeof(SERVER_MESSAGE) - 1); expect_on_client(CLIENT_FRAME, sizeof(CLIENT_FRAME)); const char SERVER_FRAME[] = {'\x81', '\x84', '\x7F', '\xC2', '\x66', '\x31', '\x4E', '\xF0', '\x55', '\x05'}; const char CLIENT_MESSAGE[] = "1234"; do_write(SERVER_FRAME, sizeof(SERVER_FRAME)); - expect_on_server(CLIENT_MESSAGE, sizeof(CLIENT_MESSAGE) - 1); + delegate->ExpectData(CLIENT_MESSAGE, sizeof(CLIENT_MESSAGE) - 1); // 3. Close const char CLIENT_CLOSE_FRAME[] = {'\x88', '\x80', '\x2D', @@ -487,53 +473,34 @@ TEST_F(InspectorSocketTest, BufferEdgeCases) { "{\"id\":17,\"method\":\"Network.canEmulateNetworkConditions\"}"}; do_write(MULTIPLE_REQUESTS, sizeof(MULTIPLE_REQUESTS)); - expect_on_server(EXPECT, sizeof(EXPECT) - 1); - inspector_read_stop(&inspector); - manual_inspector_socket_cleanup(); + delegate->ExpectData(EXPECT, sizeof(EXPECT) - 1); } TEST_F(InspectorSocketTest, AcceptsRequestInSeveralWrites) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); // Specifically, break up the request in the "Sec-WebSocket-Key" header name // and value const int write1 = 95; const int write2 = 5; const int write3 = sizeof(HANDSHAKE_REQ) - write1 - write2 - 1; do_write(const_cast(HANDSHAKE_REQ), write1); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ) + write1, write2); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ) + write1 + write2, write3); - SPIN_WHILE(!inspector_ready); + SPIN_WHILE(!delegate->inspector_ready); expect_handshake(); - inspector_read_stop(&inspector); GTEST_ASSERT_EQ(uv_is_active(reinterpret_cast(&client_socket)), 0); - manual_inspector_socket_cleanup(); } TEST_F(InspectorSocketTest, ExtraTextBeforeRequest) { - last_event = kInspectorHandshakeUpgraded; - char UNCOOL_BRO[] = "Uncool, bro: Text before the first req\r\n"; - do_write(const_cast(UNCOOL_BRO), sizeof(UNCOOL_BRO) - 1); - - ASSERT_FALSE(inspector_ready); - do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(last_event != kInspectorHandshakeFailed); - expect_handshake_failure(); - assert_both_sockets_closed(); -} - -TEST_F(InspectorSocketTest, ExtraLettersBeforeRequest) { - char UNCOOL_BRO[] = "Uncool!!"; + delegate->last_event = kInspectorHandshakeUpgraded; + char UNCOOL_BRO[] = "Text before the first req, shouldn't be her\r\n"; do_write(const_cast(UNCOOL_BRO), sizeof(UNCOOL_BRO) - 1); - - ASSERT_FALSE(inspector_ready); - do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(last_event != kInspectorHandshakeFailed); expect_handshake_failure(); - assert_both_sockets_closed(); + GTEST_ASSERT_EQ(nullptr, delegate); } TEST_F(InspectorSocketTest, RequestWithoutKey) { @@ -544,87 +511,65 @@ TEST_F(InspectorSocketTest, RequestWithoutKey) { "Sec-WebSocket-Version: 13\r\n\r\n"; do_write(const_cast(BROKEN_REQUEST), sizeof(BROKEN_REQUEST) - 1); - SPIN_WHILE(last_event != kInspectorHandshakeFailed); expect_handshake_failure(); - assert_both_sockets_closed(); } TEST_F(InspectorSocketTest, KillsConnectionOnProtocolViolation) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(!inspector_ready); - ASSERT_TRUE(inspector_ready); + SPIN_WHILE(!delegate->inspector_ready); + ASSERT_TRUE(delegate->inspector_ready); expect_handshake(); const char SERVER_FRAME[] = "I'm not a good WS frame. Nope!"; do_write(SERVER_FRAME, sizeof(SERVER_FRAME)); - expect_server_read_error(); + SPIN_WHILE(delegate != nullptr); GTEST_ASSERT_EQ(uv_is_active(reinterpret_cast(&client_socket)), 0); } TEST_F(InspectorSocketTest, CanStopReadingFromInspector) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); - ASSERT_TRUE(inspector_ready); + ASSERT_TRUE(delegate->inspector_ready); // 2. Brief exchange const char SERVER_FRAME[] = {'\x81', '\x84', '\x7F', '\xC2', '\x66', '\x31', '\x4E', '\xF0', '\x55', '\x05'}; const char CLIENT_MESSAGE[] = "1234"; do_write(SERVER_FRAME, sizeof(SERVER_FRAME)); - expect_on_server(CLIENT_MESSAGE, sizeof(CLIENT_MESSAGE) - 1); + delegate->ExpectData(CLIENT_MESSAGE, sizeof(CLIENT_MESSAGE) - 1); - inspector_read_stop(&inspector); do_write(SERVER_FRAME, sizeof(SERVER_FRAME)); GTEST_ASSERT_EQ(uv_is_active( reinterpret_cast(&client_socket)), 0); - manual_inspector_socket_cleanup(); -} - -static int inspector_closed = 0; - -void inspector_closed_cb(InspectorSocket *inspector, int code) { - inspector_closed++; } TEST_F(InspectorSocketTest, CloseDoesNotNotifyReadCallback) { - inspector_closed = 0; do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); - int error_code = 0; - inspector.data = &error_code; - inspector_read_start(&inspector, buffer_alloc_cb, - inspector_record_error_code); - inspector_close(&inspector, inspector_closed_cb); + delegate->Close(); char CLOSE_FRAME[] = {'\x88', '\x00'}; expect_on_client(CLOSE_FRAME, sizeof(CLOSE_FRAME)); - EXPECT_EQ(0, inspector_closed); const char CLIENT_CLOSE_FRAME[] = {'\x88', '\x80', '\x2D', '\x0E', '\x1E', '\xFA'}; + delegate->FailOnWsFrame(); do_write(CLIENT_CLOSE_FRAME, sizeof(CLIENT_CLOSE_FRAME)); - EXPECT_NE(UV_EOF, error_code); - SPIN_WHILE(inspector_closed == 0); - EXPECT_EQ(1, inspector_closed); - inspector.data = nullptr; + SPIN_WHILE(delegate != nullptr); } TEST_F(InspectorSocketTest, CloseWorksWithoutReadEnabled) { - inspector_closed = 0; do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); - inspector_close(&inspector, inspector_closed_cb); + delegate->Close(); char CLOSE_FRAME[] = {'\x88', '\x00'}; expect_on_client(CLOSE_FRAME, sizeof(CLOSE_FRAME)); - EXPECT_EQ(0, inspector_closed); const char CLIENT_CLOSE_FRAME[] = {'\x88', '\x80', '\x2D', '\x0E', '\x1E', '\xFA'}; do_write(CLIENT_CLOSE_FRAME, sizeof(CLIENT_CLOSE_FRAME)); - SPIN_WHILE(inspector_closed == 0); - EXPECT_EQ(1, inspector_closed); } // Make sure buffering works @@ -641,26 +586,24 @@ static void send_in_chunks(const char* data, size_t len) { } static const char TEST_SUCCESS[] = "Test Success\n\n"; +static int ReportsHttpGet_eventsCount = 0; static void ReportsHttpGet_handshake(enum inspector_handshake_event state, const std::string& path, bool* cont) { *cont = true; enum inspector_handshake_event expected_state = kInspectorHandshakeHttpGet; std::string expected_path; - switch (handshake_events) { + switch (delegate->handshake_events) { case 1: expected_path = "/some/path"; break; case 2: expected_path = "/respond/withtext"; - inspector_write(&inspector, TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); + delegate->Write(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); break; case 3: expected_path = "/some/path2"; break; - case 5: - expected_state = kInspectorHandshakeFailed; - break; case 4: expected_path = "/close"; *cont = false; @@ -670,10 +613,11 @@ static void ReportsHttpGet_handshake(enum inspector_handshake_event state, } EXPECT_EQ(expected_state, state); EXPECT_EQ(expected_path, path); + ReportsHttpGet_eventsCount = delegate->handshake_events; } TEST_F(InspectorSocketTest, ReportsHttpGet) { - handshake_delegate = ReportsHttpGet_handshake; + delegate->SetDelegate(ReportsHttpGet_handshake); const char GET_REQ[] = "GET /some/path HTTP/1.1\r\n" "Host: localhost:9222\r\n" @@ -688,7 +632,6 @@ TEST_F(InspectorSocketTest, ReportsHttpGet) { send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1); expect_on_client(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); - const char GET_REQS[] = "GET /some/path2 HTTP/1.1\r\n" "Host: localhost:9222\r\n" "Sec-WebSocket-Key: aaa==\r\n" @@ -698,53 +641,50 @@ TEST_F(InspectorSocketTest, ReportsHttpGet) { "Sec-WebSocket-Key: aaa==\r\n" "Sec-WebSocket-Version: 13\r\n\r\n"; send_in_chunks(GET_REQS, sizeof(GET_REQS) - 1); - expect_handshake_failure(); - EXPECT_EQ(5, handshake_events); + EXPECT_EQ(4, ReportsHttpGet_eventsCount); + EXPECT_EQ(nullptr, delegate); } -static void -HandshakeCanBeCanceled_handshake(enum inspector_handshake_event state, - const std::string& path, bool* cont) { - switch (handshake_events - 1) { +static int HandshakeCanBeCanceled_eventCount = 0; + +static +void HandshakeCanBeCanceled_handshake(enum inspector_handshake_event state, + const std::string& path, bool* cont) { + switch (delegate->handshake_events - 1) { case 0: - EXPECT_EQ(kInspectorHandshakeUpgrading, state); + EXPECT_EQ(kInspectorHandshakeUpgraded, state); EXPECT_EQ("/ws/path", path); break; - case 1: - EXPECT_EQ(kInspectorHandshakeFailed, state); - EXPECT_TRUE(path.empty()); - break; default: EXPECT_TRUE(false); break; } *cont = false; + HandshakeCanBeCanceled_eventCount = delegate->handshake_events; } TEST_F(InspectorSocketTest, HandshakeCanBeCanceled) { - handshake_delegate = HandshakeCanBeCanceled_handshake; + delegate->SetDelegate(HandshakeCanBeCanceled_handshake); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake_failure(); - EXPECT_EQ(2, handshake_events); + EXPECT_EQ(1, HandshakeCanBeCanceled_eventCount); + EXPECT_EQ(nullptr, delegate); } static void GetThenHandshake_handshake(enum inspector_handshake_event state, const std::string& path, bool* cont) { *cont = true; std::string expected_path = "/ws/path"; - switch (handshake_events - 1) { + switch (delegate->handshake_events - 1) { case 0: EXPECT_EQ(kInspectorHandshakeHttpGet, state); expected_path = "/respond/withtext"; - inspector_write(&inspector, TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); + delegate->Write(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1); break; case 1: - EXPECT_EQ(kInspectorHandshakeUpgrading, state); - break; - case 2: EXPECT_EQ(kInspectorHandshakeUpgraded, state); break; default: @@ -755,7 +695,7 @@ static void GetThenHandshake_handshake(enum inspector_handshake_event state, } TEST_F(InspectorSocketTest, GetThenHandshake) { - handshake_delegate = GetThenHandshake_handshake; + delegate->SetDelegate(GetThenHandshake_handshake); const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n" "Host: localhost:9222\r\n\r\n"; send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1); @@ -764,15 +704,7 @@ TEST_F(InspectorSocketTest, GetThenHandshake) { do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); - EXPECT_EQ(3, handshake_events); - manual_inspector_socket_cleanup(); -} - -static void WriteBeforeHandshake_inspector_delegate(inspector_handshake_event e, - const std::string& path, - bool* cont) { - if (e == kInspectorHandshakeFailed) - inspector_closed = 1; + EXPECT_EQ(2, delegate->handshake_events); } TEST_F(InspectorSocketTest, WriteBeforeHandshake) { @@ -780,51 +712,31 @@ TEST_F(InspectorSocketTest, WriteBeforeHandshake) { const char MESSAGE2[] = "Message 2"; const char EXPECTED[] = "Message 1Message 2"; - inspector_write(&inspector, MESSAGE1, sizeof(MESSAGE1) - 1); - inspector_write(&inspector, MESSAGE2, sizeof(MESSAGE2) - 1); + delegate->Write(MESSAGE1, sizeof(MESSAGE1) - 1); + delegate->Write(MESSAGE2, sizeof(MESSAGE2) - 1); expect_on_client(EXPECTED, sizeof(EXPECTED) - 1); - inspector_closed = 0; - handshake_delegate = WriteBeforeHandshake_inspector_delegate; really_close(reinterpret_cast(&client_socket)); - SPIN_WHILE(inspector_closed == 0); -} - -static void CleanupSocketAfterEOF_close_cb(InspectorSocket* inspector, - int status) { - *(static_cast(inspector->data)) = true; -} - -static void CleanupSocketAfterEOF_read_cb(uv_stream_t* stream, ssize_t nread, - const uv_buf_t* buf) { - EXPECT_EQ(UV_EOF, nread); - InspectorSocket* insp = inspector_from_stream(stream); - inspector_close(insp, CleanupSocketAfterEOF_close_cb); + SPIN_WHILE(delegate != nullptr); } TEST_F(InspectorSocketTest, CleanupSocketAfterEOF) { do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); - inspector_read_start(&inspector, buffer_alloc_cb, - CleanupSocketAfterEOF_read_cb); - for (int i = 0; i < MAX_LOOP_ITERATIONS; ++i) { uv_run(&loop, UV_RUN_NOWAIT); } uv_close(reinterpret_cast(&client_socket), nullptr); - bool flag = false; - inspector.data = &flag; - SPIN_WHILE(!flag); - inspector.data = nullptr; + SPIN_WHILE(delegate != nullptr); } TEST_F(InspectorSocketTest, EOFBeforeHandshake) { const char MESSAGE[] = "We'll send EOF afterwards"; - inspector_write(&inspector, MESSAGE, sizeof(MESSAGE) - 1); + delegate->Write(MESSAGE, sizeof(MESSAGE) - 1); expect_on_client(MESSAGE, sizeof(MESSAGE) - 1); uv_close(reinterpret_cast(&client_socket), nullptr); - SPIN_WHILE(last_event != kInspectorHandshakeFailed); + SPIN_WHILE(delegate != nullptr); } static void fill_message(std::string* buffer) { @@ -843,9 +755,9 @@ static void mask_message(const std::string& message, TEST_F(InspectorSocketTest, Send1Mb) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(!inspector_ready); + SPIN_WHILE(!delegate->inspector_ready); expect_handshake(); // 2. Brief exchange @@ -860,7 +772,7 @@ TEST_F(InspectorSocketTest, Send1Mb) { std::string expected(EXPECTED_FRAME_HEADER, sizeof(EXPECTED_FRAME_HEADER)); expected.append(message); - inspector_write(&inspector, &message[0], message.size()); + delegate->Write(&message[0], message.size()); expect_on_client(&expected[0], expected.size()); char MASK[4] = {'W', 'h', 'O', 'a'}; @@ -874,9 +786,8 @@ TEST_F(InspectorSocketTest, Send1Mb) { outgoing.resize(outgoing.size() + message.size()); mask_message(message, &outgoing[sizeof(FRAME_TO_SERVER_HEADER)], MASK); - setup_inspector_expecting(); // Buffer on the client side. do_write(&outgoing[0], outgoing.size()); - expect_on_server(&message[0], message.size()); + delegate->ExpectData(&message[0], message.size()); // 3. Close const char CLIENT_CLOSE_FRAME[] = {'\x88', '\x80', '\x2D', @@ -887,53 +798,33 @@ TEST_F(InspectorSocketTest, Send1Mb) { reinterpret_cast(&client_socket))); } -static ssize_t err; - -void ErrorCleansUpTheSocket_cb(uv_stream_t* stream, ssize_t read, - const uv_buf_t* buf) { - err = read; -} - TEST_F(InspectorSocketTest, ErrorCleansUpTheSocket) { - inspector_closed = 0; do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); expect_handshake(); const char NOT_A_GOOD_FRAME[] = {'H', 'e', 'l', 'l', 'o'}; - err = 42; - inspector_read_start(&inspector, buffer_alloc_cb, - ErrorCleansUpTheSocket_cb); do_write(NOT_A_GOOD_FRAME, sizeof(NOT_A_GOOD_FRAME)); - SPIN_WHILE(err > 0); - EXPECT_EQ(UV_EPROTO, err); -} - -static void ServerClosedByClient_cb(InspectorSocket* socket, int code) { - *static_cast(socket->data) = true; + SPIN_WHILE(delegate != nullptr); } -TEST_F(InspectorSocketTest, NoCloseResponseFromClinet) { +TEST_F(InspectorSocketTest, NoCloseResponseFromClient) { ASSERT_TRUE(connected); - ASSERT_FALSE(inspector_ready); + ASSERT_FALSE(delegate->inspector_ready); do_write(const_cast(HANDSHAKE_REQ), sizeof(HANDSHAKE_REQ) - 1); - SPIN_WHILE(!inspector_ready); + SPIN_WHILE(!delegate->inspector_ready); expect_handshake(); // 2. Brief exchange const char SERVER_MESSAGE[] = "abcd"; const char CLIENT_FRAME[] = {'\x81', '\x04', 'a', 'b', 'c', 'd'}; - inspector_write(&inspector, SERVER_MESSAGE, sizeof(SERVER_MESSAGE) - 1); + delegate->Write(SERVER_MESSAGE, sizeof(SERVER_MESSAGE) - 1); expect_on_client(CLIENT_FRAME, sizeof(CLIENT_FRAME)); - bool closed = false; - - inspector.data = &closed; - inspector_close(&inspector, ServerClosedByClient_cb); + delegate->Close(); expect_on_client(SERVER_CLOSE_FRAME, sizeof(SERVER_CLOSE_FRAME)); uv_close(reinterpret_cast(&client_socket), nullptr); - SPIN_WHILE(!closed); - inspector.data = nullptr; GTEST_ASSERT_EQ(0, uv_is_active( - reinterpret_cast(&client_socket))); + reinterpret_cast(&client_socket))); + delegate->WaitForDispose(); } } // anonymous namespace diff --git a/test/cctest/test_inspector_socket_server.cc b/test/cctest/test_inspector_socket_server.cc index ab74917234eefb..49a3ca9f95857d 100644 --- a/test/cctest/test_inspector_socket_server.cc +++ b/test/cctest/test_inspector_socket_server.cc @@ -95,16 +95,17 @@ class TestInspectorServerDelegate : public SocketServerDelegate { server_ = server; } - bool StartSession(int session_id, const std::string& target_id) override { + void StartSession(int session_id, const std::string& target_id) override { buffer_.clear(); CHECK_NE(targets_.end(), std::find(targets_.begin(), targets_.end(), target_id)); if (target_id == UNCONNECTABLE_TARGET_ID) { - return false; + server_->DeclineSession(session_id); + return; } connected++; session_id_ = session_id; - return true; + server_->AcceptSession(session_id); } void MessageReceived(int session_id, const std::string& message) override { @@ -350,12 +351,13 @@ class ServerHolder { class ServerDelegateNoTargets : public SocketServerDelegate { public: + ServerDelegateNoTargets() : server_(nullptr) { } void Connect(InspectorSocketServer* server) { } void MessageReceived(int session_id, const std::string& message) override { } void EndSession(int session_id) override { } - bool StartSession(int session_id, const std::string& target_id) override { - return false; + void StartSession(int session_id, const std::string& target_id) override { + server_->DeclineSession(session_id); } std::vector GetTargetIds() override { @@ -375,6 +377,9 @@ class ServerDelegateNoTargets : public SocketServerDelegate { } bool done = false; + + private: + InspectorSocketServer* server_; }; static void TestHttpRequest(int port, const std::string& path, @@ -407,7 +412,6 @@ TEST_F(InspectorSocketServerTest, InspectorSessions) { well_behaved_socket.Write(WsHandshakeRequest(MAIN_TARGET_ID)); well_behaved_socket.Expect(WS_HANDSHAKE_RESPONSE); - EXPECT_EQ(1, delegate.connected); well_behaved_socket.Write("\x81\x84\x7F\xC2\x66\x31\x4E\xF0\x55\x05"); @@ -416,7 +420,6 @@ TEST_F(InspectorSocketServerTest, InspectorSessions) { delegate.Write("5678"); well_behaved_socket.Expect("\x81\x4" "5678"); - well_behaved_socket.Write(CLIENT_CLOSE_FRAME); well_behaved_socket.Expect(SERVER_CLOSE_FRAME); diff --git a/test/common/inspector-helper.js b/test/common/inspector-helper.js index 454eef4c5e26da..0d010a8ca70617 100644 --- a/test/common/inspector-helper.js +++ b/test/common/inspector-helper.js @@ -369,28 +369,43 @@ class NodeInstance { }); } - wsHandshake(devtoolsUrl) { - return this.portPromise.then((port) => new Promise((resolve) => { - http.get({ - port, - path: url.parse(devtoolsUrl).path, - headers: { - 'Connection': 'Upgrade', - 'Upgrade': 'websocket', - 'Sec-WebSocket-Version': 13, - 'Sec-WebSocket-Key': 'key==' - } - }).on('upgrade', (message, socket) => { - resolve(new InspectorSession(socket, this)); - }).on('response', common.mustNotCall('Upgrade was not received')); - })); + async sendUpgradeRequest() { + const response = await this.httpGet(null, '/json/list'); + const devtoolsUrl = response[0]['webSocketDebuggerUrl']; + const port = await this.portPromise; + return http.get({ + port, + path: url.parse(devtoolsUrl).path, + headers: { + 'Connection': 'Upgrade', + 'Upgrade': 'websocket', + 'Sec-WebSocket-Version': 13, + 'Sec-WebSocket-Key': 'key==' + } + }); } async connectInspectorSession() { console.log('[test]', 'Connecting to a child Node process'); - const response = await this.httpGet(null, '/json/list'); - const url = response[0]['webSocketDebuggerUrl']; - return this.wsHandshake(url); + const upgradeRequest = await this.sendUpgradeRequest(); + return new Promise((resolve, reject) => { + upgradeRequest + .on('upgrade', + (message, socket) => resolve(new InspectorSession(socket, this))) + .on('response', common.mustNotCall('Upgrade was not received')); + }); + } + + async expectConnectionDeclined() { + console.log('[test]', 'Checking upgrade is not possible'); + const upgradeRequest = await this.sendUpgradeRequest(); + return new Promise((resolve, reject) => { + upgradeRequest + .on('upgrade', common.mustNotCall('Upgrade was received')) + .on('response', (response) => + response.on('data', () => {}) + .on('end', () => resolve(response.statusCode))); + }); } expectShutdown() { @@ -403,6 +418,10 @@ class NodeInstance { return new Promise((resolve) => this._stderrLineCallback = resolve); } + write(message) { + this._process.stdin.write(message); + } + kill() { this._process.kill(); } diff --git a/test/parallel/test-inspector-no-crash-ws-after-bindings.js b/test/parallel/test-inspector-no-crash-ws-after-bindings.js new file mode 100644 index 00000000000000..286373068e8e9b --- /dev/null +++ b/test/parallel/test-inspector-no-crash-ws-after-bindings.js @@ -0,0 +1,30 @@ +'use strict'; +const common = require('../common'); +common.skipIfInspectorDisabled(); +common.crashOnUnhandledRejection(); +const { NodeInstance } = require('../common/inspector-helper.js'); +const assert = require('assert'); + +const expected = 'Can connect now!'; + +const script = ` + 'use strict'; + const { Session } = require('inspector'); + + const s = new Session(); + s.connect(); + console.error('${expected}'); + process.stdin.on('data', () => process.exit(0)); +`; + +async function runTests() { + const instance = new NodeInstance(['--inspect=0', '--expose-internals'], + script); + while (await instance.nextStderrString() !== expected); + assert.strictEqual(400, await instance.expectConnectionDeclined()); + instance.write('Stop!\n'); + assert.deepStrictEqual({ exitCode: 0, signal: null }, + await instance.expectShutdown()); +} + +runTests();