diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f2f7b6ce31..85aa5996879 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -777,6 +777,8 @@ add_library (seastar src/util/read_first_line.cc src/util/tmp_file.cc src/util/short_streams.cc + src/websocket/client.cc + src/websocket/common.cc src/websocket/server.cc ) diff --git a/demos/CMakeLists.txt b/demos/CMakeLists.txt index 0f26401b1b8..11a82f35537 100644 --- a/demos/CMakeLists.txt +++ b/demos/CMakeLists.txt @@ -62,8 +62,11 @@ endif () seastar_add_demo (hello-world SOURCES hello-world.cc) -seastar_add_demo (websocket - SOURCES websocket_demo.cc) +seastar_add_demo (websocket_server + SOURCES websocket_server_demo.cc) + +seastar_add_demo (websocket_client + SOURCES websocket_client_demo.cc) seastar_add_demo (echo SOURCES echo_demo.cc) diff --git a/demos/websocket_client_demo.cc b/demos/websocket_client_demo.cc new file mode 100644 index 00000000000..24576dd1aa3 --- /dev/null +++ b/demos/websocket_client_demo.cc @@ -0,0 +1,96 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace seastar; +using namespace seastar::experimental; + +namespace bpo = boost::program_options; + +int main(int argc, char** argv) { + seastar::app_template app; + app.add_options() + ("host", bpo::value(), "Host to connect") + ("port", bpo::value(), "Port to connect") + ("path", bpo::value(), "Path to query upon") + ("subprotocol", bpo::value()->default_value(""), "Sub-protocol") + ; + app.run(argc, argv, [&app]() -> seastar::future<> { + auto&& config = app.configuration(); + auto host = config["host"].as(); + auto port = config["port"].as(); + auto path = config["path"].as(); + auto subprotocol = config["subprotocol"].as(); + + return async([=] { + net::hostent e = net::dns::get_host_by_name(host, net::inet_address::family::INET).get(); + auto ws = std::make_unique(socket_address(e.addr_list.front(), port)); + + if (!subprotocol.empty()) { + ws->set_subprotocol(subprotocol); + } + + auto req = http::request::make("GET", host, path); + + auto handler = [](input_stream& in, + output_stream& out) { + return repeat([&in, &out]() { + return in.read().then([&out](temporary_buffer f) { + auto value = std::stol(std::string(f.get(), f.size())); + std::cout << "got " << value << "\n"; + auto new_str = std::to_string(value + 1); + return out.write(temporary_buffer(new_str.data(), new_str.size())) + .then([&out] { return out.flush(); }) + .then([] { + return make_ready_future(stop_iteration::no); + }); + }); + }); + }; + + std::cout << "Sending messages to " << host << ":" << port + << " for 1 hour (interruptible, hit Ctrl-C to stop)..." << std::endl; + + seastar::shared_ptr client_con; + + ws->make_request(std::move(req), handler).then( + [&ws, &client_con](auto con) -> future<> { + client_con = con; + return when_all_succeed( + [con]{ return con->process(); }, + [con]{ return con->send_message(temporary_buffer("1", 1), true); }, + [&ws]{ + return sleep_abortable(std::chrono::hours(1)) + .handle_exception([&ws](auto ignored) { + std::cout << "Stopping the client" << std::endl; + return ws->stop(); + }); + } + ).discard_result(); + }).get(); + }); + }); +} diff --git a/demos/websocket_demo.cc b/demos/websocket_server_demo.cc similarity index 82% rename from demos/websocket_demo.cc rename to demos/websocket_server_demo.cc index c816d97bbf9..4d3af3086c1 100644 --- a/demos/websocket_demo.cc +++ b/demos/websocket_server_demo.cc @@ -32,10 +32,17 @@ using namespace seastar; using namespace seastar::experimental; +namespace bpo = boost::program_options; + int main(int argc, char** argv) { seastar::app_template app; - app.run(argc, argv, [] () -> seastar::future<> { - return async([] { + app.add_options() + ("port", bpo::value()->default_value(10000), "WebSocket server port") ; + app.run(argc, argv, [&app]() -> seastar::future<> { + auto&& config = app.configuration(); + uint16_t port = config["port"].as(); + + return async([port] { websocket::server ws; ws.register_handler("echo", [] (input_stream& in, output_stream& out) { @@ -57,8 +64,8 @@ int main(int argc, char** argv) { auto d = defer([&ws] () noexcept { ws.stop().get(); }); - ws.listen(socket_address(ipv4_addr("127.0.0.1", 8123))); - std::cout << "Listening on 127.0.0.1:8123 for 1 hour (interruptible, hit Ctrl-C to stop)..." << std::endl; + ws.listen(socket_address(ipv4_addr("127.0.0.1", port))); + std::cout << "Listening on 127.0.0.1:" << port << " for 1 hour (interruptible, hit Ctrl-C to stop)..." << std::endl; seastar::sleep_abortable(std::chrono::hours(1)).handle_exception([](auto ignored) {}).get(); std::cout << "Stopping the server, deepest thanks to all clients, hope we meet again" << std::endl; }); diff --git a/include/seastar/http/client.hh b/include/seastar/http/client.hh index 853d466ab9e..660dd95d14c 100644 --- a/include/seastar/http/client.hh +++ b/include/seastar/http/client.hh @@ -25,6 +25,7 @@ #include #endif #include +#include #include #include #include @@ -136,26 +137,6 @@ private: void shutdown() noexcept; }; -/** - * \brief Factory that provides transport for \ref client - * - * This customization point allows callers provide its own transport for client. The - * client code calls factory when it needs more connections to the server and maintains - * the pool of re-usable sockets internally - */ - -class connection_factory { -public: - /** - * \brief Make a \ref connected_socket - * - * The implementations of this method should return ready-to-use socket that will - * be used by \ref client as transport for its http connections - */ - virtual future make(abort_source*) = 0; - virtual ~connection_factory() {} -}; - /** * \brief Class client wraps communications using HTTP protocol * diff --git a/include/seastar/http/connection_factory.hh b/include/seastar/http/connection_factory.hh new file mode 100644 index 00000000000..7b1b1a2af03 --- /dev/null +++ b/include/seastar/http/connection_factory.hh @@ -0,0 +1,74 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include + +namespace seastar::http::experimental { + +/** + * \brief Factory that provides transport for \ref client + * + * This customization point allows callers provide its own transport for client. The + * client code calls factory when it needs more connections to the server. + */ + +class connection_factory { +public: + /** + * \brief Make a \ref connected_socket + * + * The implementations of this method should return ready-to-use socket that will + * be used by \ref client as transport for its http connections + */ + virtual future make(abort_source*) = 0; + virtual ~connection_factory() {} +}; + +class basic_connection_factory : public connection_factory { + socket_address _addr; +public: + explicit basic_connection_factory(socket_address addr) + : _addr(std::move(addr)) + { + } + virtual future make(abort_source* as) override { + return seastar::connect(_addr, {}, transport::TCP); + } +}; + +class tls_connection_factory : public connection_factory { + socket_address _addr; + shared_ptr _creds; + sstring _host; +public: + tls_connection_factory(socket_address addr, shared_ptr creds, sstring host) + : _addr(std::move(addr)) + , _creds(std::move(creds)) + , _host(std::move(host)) + { + } + virtual future make(abort_source* as) override { + return tls::connect(_creds, _addr, tls::tls_options{.server_name = _host}); + } +}; + +} diff --git a/include/seastar/http/request.hh b/include/seastar/http/request.hh index d222fb74726..d27604b6a06 100644 --- a/include/seastar/http/request.hh +++ b/include/seastar/http/request.hh @@ -313,10 +313,10 @@ struct request { */ static request make(httpd::operation_type type, sstring host, sstring path); -private: - void add_query_param(std::string_view param); sstring request_line() const; future<> write_request_headers(output_stream& out) const; +private: + void add_query_param(std::string_view param); friend class experimental::connection; }; diff --git a/include/seastar/websocket/client.hh b/include/seastar/websocket/client.hh new file mode 100644 index 00000000000..f1905f4cdf3 --- /dev/null +++ b/include/seastar/websocket/client.hh @@ -0,0 +1,140 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +class client; + +/// \addtogroup websocket +/// @{ + +/*! + * \brief a client WebSocket connection + */ +class client_connection : public connection { + client& _client; + sstring _ws_key; +public: + /*! + * \param server owning \ref server + * \param fd established socket used for communication + */ + client_connection(client& client, connected_socket&& fd, std::string_view ws_key, + handler_t handler); + ~client_connection(); + + /*! + * \brief serve WebSocket protocol on a client_connection + */ + future<> process(); + + /** + * @brief Send a websocket message to the server + */ + future<> send_message(temporary_buffer buf, bool flush); + +protected: + friend class client; + future<> perform_handshake(const http::request& req); + future<> send_request_head(const http::request& req); + future<> read_reply(); +}; + +/*! + * \brief a WebSocket client + * + * A client capable of establishing and processing a single concurrent connection + * on a WebSocket protocol. + */ +class client { + boost::intrusive::list _connections; + std::string _subprotocol; + std::unique_ptr _new_connections; + + std::random_device _rd_device; + std::mt19937_64 _random_gen; + + using connection_ptr = seastar::shared_ptr; + +public: + /** + * \brief Construct a plaintext client + * + * This creates a plaintext client that connects to provided address via non-TLS socket. + * + * \param addr -- host address to connect to + */ + explicit client(socket_address addr); + + /** + * \brief Construct a secure client + * + * This creates a secure client that connects to provided address via TLS socket with + * given credentials. + * + * \param addr -- host address to connect to + * \param creds -- credentials + * \param host -- optional host name + */ + client(socket_address addr, shared_ptr creds, sstring host = {}); + + /** + * \brief Construct a client with connection factory + * + * This creates a client that uses factory to get \ref connected_socket that is then + * used as transport. + * + * \param f -- the factory pointer + */ + explicit client(std::unique_ptr f); + + /** + * Starts the process of establishing a Websocket connection + */ + future> + make_request(http::request rq, const handler_t& handler); + + /*! + * Stops the client and shuts down active connection, if any + */ + future<> stop(); + + void set_subprotocol(std::string const& subprotocol) { _subprotocol = subprotocol; } + + /** + * Sets the seed for WebSocket key generation + */ + void set_seed(std::size_t seed); + + friend class client_connection; +}; + +/// }@ + +} diff --git a/include/seastar/websocket/common.hh b/include/seastar/websocket/common.hh new file mode 100644 index 00000000000..34cdd32fe53 --- /dev/null +++ b/include/seastar/websocket/common.hh @@ -0,0 +1,292 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +extern sstring magic_key_suffix; + +using handler_t = std::function(input_stream&, output_stream&)>; + +class server; + + +/// \defgroup websocket WebSocket +/// \addtogroup websocket +/// @{ + +/*! + * \brief an error in handling a WebSocket connection + */ +class exception : public std::exception { + std::string _msg; +public: + exception(std::string_view msg) : _msg(msg) {} + virtual const char* what() const noexcept { + return _msg.c_str(); + } +}; + +/*! + * \brief Possible type of a websocket frame. + */ +enum opcodes { + CONTINUATION = 0x0, + TEXT = 0x1, + BINARY = 0x2, + CLOSE = 0x8, + PING = 0x9, + PONG = 0xA, + INVALID = 0xFF, +}; + +struct frame_header { + static constexpr uint8_t FIN = 7; + static constexpr uint8_t RSV1 = 6; + static constexpr uint8_t RSV2 = 5; + static constexpr uint8_t RSV3 = 4; + static constexpr uint8_t MASKED = 7; + + uint8_t fin : 1; + uint8_t rsv1 : 1; + uint8_t rsv2 : 1; + uint8_t rsv3 : 1; + uint8_t opcode : 4; + uint8_t masked : 1; + uint8_t length : 7; + frame_header(const char* input) { + this->fin = (input[0] >> FIN) & 1; + this->rsv1 = (input[0] >> RSV1) & 1; + this->rsv2 = (input[0] >> RSV2) & 1; + this->rsv3 = (input[0] >> RSV3) & 1; + this->opcode = input[0] & 0b1111; + this->masked = (input[1] >> MASKED) & 1; + this->length = (input[1] & 0b1111111); + } + // Returns length of the rest of the header. + uint64_t get_rest_of_header_length() { + size_t next_read_length = get_masked() ? sizeof(uint32_t) : 0; + if (length == 126) { + next_read_length += sizeof(uint16_t); + } else if (length == 127) { + next_read_length += sizeof(uint64_t); + } + return next_read_length; + } + uint8_t get_fin() {return fin;} + uint8_t get_rsv1() {return rsv1;} + uint8_t get_rsv2() {return rsv2;} + uint8_t get_rsv3() {return rsv3;} + uint8_t get_opcode() {return opcode;} + uint8_t get_masked() {return masked;} + uint8_t get_length() {return length;} + + bool is_opcode_known() { + //https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 + return opcode < 0xA && !(opcode < 0x8 && opcode > 0x2); + } +}; + +class websocket_parser { + enum class parsing_state : uint8_t { + flags_and_payload_data, + payload_length_and_mask, + payload + }; + enum class connection_state : uint8_t { + valid, + closed, + error + }; + using consumption_result_t = consumption_result; + using buff_t = temporary_buffer; + // What parser is currently doing. + parsing_state _state; + // Whether parser is parsing messages coming to the client. + bool _is_client = false; + // State of connection - can be valid, closed or should be closed + // due to error. + connection_state _cstate; + sstring _buffer; + std::unique_ptr _header; + uint64_t _payload_length = 0; + uint64_t _consumed_payload_length = 0; + uint32_t _masking_key; + buff_t _result; + + static future dont_stop() { + return make_ready_future(continue_consuming{}); + } + static future stop(buff_t data) { + return make_ready_future(stop_consuming(std::move(data))); + } + uint64_t remaining_payload_length() const { + return _payload_length - _consumed_payload_length; + } + + // Removes mask from payload given in p. + void remove_mask(buff_t& p, size_t n) { + char *payload = p.get_write(); + for (uint64_t i = 0, j = 0; i < n; ++i, j = (j + 1) % 4) { + payload[i] ^= static_cast(((_masking_key << (j * 8)) >> 24)); + } + } +public: + websocket_parser(bool is_client) : + _state(parsing_state::flags_and_payload_data), + _is_client{is_client}, + _cstate(connection_state::valid), + _masking_key(0) {} + future operator()(temporary_buffer data); + bool is_valid() { return _cstate == connection_state::valid; } + bool eof() { return _cstate == connection_state::closed; } + opcodes opcode() const; + buff_t result(); +}; + + +/*! + * \brief a server WebSocket connection + */ +class connection : public boost::intrusive::list_base_hook<> { +protected: + using buff_t = temporary_buffer; + + /*! + * \brief Implementation of connection's data source. + */ + class connection_source_impl final : public data_source_impl { + queue* data; + + public: + connection_source_impl(queue* data) : data(data) {} + + virtual future get() override { + return data->pop_eventually().then_wrapped([](future f){ + try { + return make_ready_future(std::move(f.get())); + } catch(...) { + return current_exception_as_future(); + } + }); + } + + virtual future<> close() override { + data->push(buff_t(0)); + return make_ready_future<>(); + } + }; + + /*! + * \brief Implementation of connection's data sink. + */ + class connection_sink_impl final : public data_sink_impl { + queue* data; + public: + connection_sink_impl(queue* data) : data(data) {} + + virtual future<> put(net::packet d) override { + net::fragment f = d.frag(0); + return data->push_eventually(temporary_buffer{std::move(f.base), f.size}); + } + + size_t buffer_size() const noexcept override { + return data->max_size(); + } + + virtual future<> close() override { + data->push(buff_t(0)); + return make_ready_future<>(); + } + }; + + /*! + * \brief This function processess received PING frame. + * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + */ + future<> handle_ping(); + /*! + * \brief This function processess received PONG frame. + * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + */ + future<> handle_pong(); + + static const size_t PIPE_SIZE = 512; + connected_socket _fd; + input_stream _read_buf; + output_stream _write_buf; + bool _done = false; + + websocket_parser _websocket_parser; + queue > _input_buffer; + input_stream _input; + queue > _output_buffer; + output_stream _output; + + sstring _subprotocol; + handler_t _handler; + bool _is_client; +public: + /*! + * \param fd established socket used for communication + * \param is_client Whether the connection is for the client. + */ + connection(connected_socket&& fd, bool is_client) + : _fd(std::move(fd)) + , _read_buf(_fd.input()) + , _write_buf(_fd.output()) + , _websocket_parser{is_client} + , _input_buffer{PIPE_SIZE} + , _output_buffer{PIPE_SIZE} + , _is_client{is_client} + { + _input = input_stream{data_source{ + std::make_unique(&_input_buffer)}}; + _output = output_stream{data_sink{ + std::make_unique(&_output_buffer)}}; + } + + /*! + * \brief close the socket + */ + void shutdown_input(); + future<> close(bool send_close = true); + +protected: + future<> read_one(); + future<> response_loop(); + /*! + * \brief Packs buff in websocket frame and sends it to the client. + */ + future<> send_data(opcodes opcode, temporary_buffer&& buff); +}; + +std::string sha1_base64(std::string_view source); +std::string encode_base64(std::string_view source); + +extern logger websocket_logger; + +/// @} +} diff --git a/include/seastar/websocket/server.hh b/include/seastar/websocket/server.hh index e8bfef99ce4..c8963c35890 100644 --- a/include/seastar/websocket/server.hh +++ b/include/seastar/websocket/server.hh @@ -25,272 +25,45 @@ #include #include -#include #include -#include #include -#include #include +#include namespace seastar::experimental::websocket { -using handler_t = std::function(input_stream&, output_stream&)>; - -class server; - -/// \defgroup websocket WebSocket /// \addtogroup websocket /// @{ /*! - * \brief an error in handling a WebSocket connection + * \brief a server WebSocket connection */ -class exception : public std::exception { - std::string _msg; -public: - exception(std::string_view msg) : _msg(msg) {} - virtual const char* what() const noexcept { - return _msg.c_str(); - } -}; - -/*! - * \brief Possible type of a websocket frame. - */ -enum opcodes { - CONTINUATION = 0x0, - TEXT = 0x1, - BINARY = 0x2, - CLOSE = 0x8, - PING = 0x9, - PONG = 0xA, - INVALID = 0xFF, -}; - -struct frame_header { - static constexpr uint8_t FIN = 7; - static constexpr uint8_t RSV1 = 6; - static constexpr uint8_t RSV2 = 5; - static constexpr uint8_t RSV3 = 4; - static constexpr uint8_t MASKED = 7; - - uint8_t fin : 1; - uint8_t rsv1 : 1; - uint8_t rsv2 : 1; - uint8_t rsv3 : 1; - uint8_t opcode : 4; - uint8_t masked : 1; - uint8_t length : 7; - frame_header(const char* input) { - this->fin = (input[0] >> FIN) & 1; - this->rsv1 = (input[0] >> RSV1) & 1; - this->rsv2 = (input[0] >> RSV2) & 1; - this->rsv3 = (input[0] >> RSV3) & 1; - this->opcode = input[0] & 0b1111; - this->masked = (input[1] >> MASKED) & 1; - this->length = (input[1] & 0b1111111); - } - // Returns length of the rest of the header. - uint64_t get_rest_of_header_length() { - size_t next_read_length = sizeof(uint32_t); // Masking key - if (length == 126) { - next_read_length += sizeof(uint16_t); - } else if (length == 127) { - next_read_length += sizeof(uint64_t); - } - return next_read_length; - } - uint8_t get_fin() {return fin;} - uint8_t get_rsv1() {return rsv1;} - uint8_t get_rsv2() {return rsv2;} - uint8_t get_rsv3() {return rsv3;} - uint8_t get_opcode() {return opcode;} - uint8_t get_masked() {return masked;} - uint8_t get_length() {return length;} +class server_connection : public connection { - bool is_opcode_known() { - //https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 - return opcode < 0xA && !(opcode < 0x8 && opcode > 0x2); - } -}; - -class websocket_parser { - enum class parsing_state : uint8_t { - flags_and_payload_data, - payload_length_and_mask, - payload - }; - enum class connection_state : uint8_t { - valid, - closed, - error - }; - using consumption_result_t = consumption_result; - using buff_t = temporary_buffer; - // What parser is currently doing. - parsing_state _state; - // State of connection - can be valid, closed or should be closed - // due to error. - connection_state _cstate; - sstring _buffer; - std::unique_ptr _header; - uint64_t _payload_length; - uint64_t _consumed_payload_length = 0; - uint32_t _masking_key; - buff_t _result; - - static future dont_stop() { - return make_ready_future(continue_consuming{}); - } - static future stop(buff_t data) { - return make_ready_future(stop_consuming(std::move(data))); - } - uint64_t remaining_payload_length() const { - return _payload_length - _consumed_payload_length; - } - - // Removes mask from payload given in p. - void remove_mask(buff_t& p, size_t n) { - char *payload = p.get_write(); - for (uint64_t i = 0, j = 0; i < n; ++i, j = (j + 1) % 4) { - payload[i] ^= static_cast(((_masking_key << (j * 8)) >> 24)); - } - } -public: - websocket_parser() : _state(parsing_state::flags_and_payload_data), - _cstate(connection_state::valid), - _payload_length(0), - _masking_key(0) {} - future operator()(temporary_buffer data); - bool is_valid() { return _cstate == connection_state::valid; } - bool eof() { return _cstate == connection_state::closed; } - opcodes opcode() const; - buff_t result(); -}; - -/*! - * \brief a WebSocket connection - */ -class connection : public boost::intrusive::list_base_hook<> { - using buff_t = temporary_buffer; - - /*! - * \brief Implementation of connection's data source. - */ - class connection_source_impl final : public data_source_impl { - queue* data; - - public: - connection_source_impl(queue* data) : data(data) {} - - virtual future get() override { - return data->pop_eventually().then_wrapped([](future f){ - try { - return make_ready_future(std::move(f.get())); - } catch(...) { - return current_exception_as_future(); - } - }); - } - - virtual future<> close() override { - data->push(buff_t(0)); - return make_ready_future<>(); - } - }; - - /*! - * \brief Implementation of connection's data sink. - */ - class connection_sink_impl final : public data_sink_impl { - queue* data; - public: - connection_sink_impl(queue* data) : data(data) {} - - virtual future<> put(net::packet d) override { - net::fragment f = d.frag(0); - return data->push_eventually(temporary_buffer{std::move(f.base), f.size}); - } - - size_t buffer_size() const noexcept override { - return data->max_size(); - } - - virtual future<> close() override { - data->push(buff_t(0)); - return make_ready_future<>(); - } - }; - - /*! - * \brief This function processess received PING frame. - * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 - */ - future<> handle_ping(); - /*! - * \brief This function processess received PONG frame. - * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - */ - future<> handle_pong(); - - static const size_t PIPE_SIZE = 512; server& _server; - connected_socket _fd; - input_stream _read_buf; - output_stream _write_buf; http_request_parser _http_parser; - bool _done = false; - websocket_parser _websocket_parser; - queue > _input_buffer; - input_stream _input; - queue > _output_buffer; - output_stream _output; - - sstring _subprotocol; - handler_t _handler; public: /*! * \param server owning \ref server * \param fd established socket used for communication */ - connection(server& server, connected_socket&& fd) - : _server(server) - , _fd(std::move(fd)) - , _read_buf(_fd.input()) - , _write_buf(_fd.output()) - , _input_buffer{PIPE_SIZE} - , _output_buffer{PIPE_SIZE} - { - _input = input_stream{data_source{ - std::make_unique(&_input_buffer)}}; - _output = output_stream{data_sink{ - std::make_unique(&_output_buffer)}}; + server_connection(server& server, connected_socket&& fd) + : connection(std::move(fd), false) + , _server(server) { on_new_connection(); } - ~connection(); + ~server_connection(); /*! - * \brief serve WebSocket protocol on a connection + * \brief serve WebSocket protocol on a server_connection */ future<> process(); - /*! - * \brief close the socket - */ - void shutdown_input(); - future<> close(bool send_close = true); protected: future<> read_loop(); - future<> read_one(); future<> read_http_upgrade_request(); - future<> response_loop(); void on_new_connection(); - /*! - * \brief Packs buff in websocket frame and sends it to the client. - */ - future<> send_data(opcodes opcode, temporary_buffer&& buff); - }; /*! @@ -301,7 +74,7 @@ protected: */ class server { std::vector _listeners; - boost::intrusive::list _connections; + boost::intrusive::list _connections; std::map _handlers; gate _task_gate; public: @@ -326,7 +99,7 @@ public: void register_handler(std::string&& name, handler_t handler); - friend class connection; + friend class server_connection; protected: void accept(server_socket &listener); future accept_one(server_socket &listener); diff --git a/src/http/client.cc b/src/http/client.cc index 499e45bd625..4053f77c31f 100644 --- a/src/http/client.cc +++ b/src/http/client.cc @@ -35,7 +35,6 @@ module seastar; #include #include #include -#include #include #include #include @@ -203,39 +202,11 @@ future<> connection::close() { }); } -class basic_connection_factory : public connection_factory { - socket_address _addr; -public: - explicit basic_connection_factory(socket_address addr) - : _addr(std::move(addr)) - { - } - virtual future make(abort_source* as) override { - return seastar::connect(_addr, {}, transport::TCP); - } -}; - client::client(socket_address addr) : client(std::make_unique(std::move(addr))) { } -class tls_connection_factory : public connection_factory { - socket_address _addr; - shared_ptr _creds; - sstring _host; -public: - tls_connection_factory(socket_address addr, shared_ptr creds, sstring host) - : _addr(std::move(addr)) - , _creds(std::move(creds)) - , _host(std::move(host)) - { - } - virtual future make(abort_source* as) override { - return tls::connect(_creds, _addr, tls::tls_options{.server_name = _host}); - } -}; - client::client(socket_address addr, shared_ptr creds, sstring host) : client(std::make_unique(std::move(addr), std::move(creds), std::move(host))) { diff --git a/src/websocket/client.cc b/src/websocket/client.cc new file mode 100644 index 00000000000..4e1daf38a93 --- /dev/null +++ b/src/websocket/client.cc @@ -0,0 +1,195 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +client_connection::client_connection(client& client, connected_socket&& fd, std::string_view ws_key, + const handler_t handler) + : connection(std::move(fd), true) + , _client{client} + , _ws_key{ws_key} { + _handler = std::move(handler); + _client._connections.push_back(*this); +} + +client_connection::~client_connection() { + _client._connections.erase(_client._connections.iterator_to(*this)); +} + +future<> client_connection::perform_handshake(const http::request& req) { + return send_request_head(req).then( + [this]{ return read_reply(); } + ).handle_exception([this](auto ep) { + websocket_logger.error("Got error during handshake {}", ep); + return _read_buf.close(); + }); +} + +future<> client_connection::send_request_head(const http::request& req) { + return _write_buf.write(req.request_line()).then([this, &req] { + return req.write_request_headers(_write_buf).then([this] { + return _write_buf.write("\r\n", 2); + }).then([this] { + return _write_buf.flush(); + }); + }); +} + +future<> client_connection::process() { + return when_all_succeed( + _handler(_input, _output).handle_exception([this] (std::exception_ptr e) mutable { + return _read_buf.close().then([e = std::move(e)] () mutable { + return make_exception_future<>(std::move(e)); + }); + }), + response_loop(), + do_until([this] {return _done;}, [this] {return read_one();}) + ).discard_result().finally([this] { + return _read_buf.close(); + }); +} + +future<> client_connection::send_message(temporary_buffer buf, bool flush) { + auto f = _output.write(std::move(buf)); + if (flush) { + f = f.then([this](){ return _output.flush(); }); + } + return f; +} + +future<> client_connection::read_reply() { + http_response_parser parser; + return do_with(std::move(parser), [this] (auto& parser) { + parser.init(); + return _read_buf.consume(parser).then([this, &parser] { + if (parser.eof()) { + websocket_logger.trace("Parsing response EOFed"); + throw std::system_error(ECONNABORTED, std::system_category()); + } + if (parser.failed()) { + websocket_logger.trace("Parsing response failed"); + throw std::runtime_error("Invalid http server response"); + } + + std::unique_ptr resp = parser.get_parsed_response(); + if (resp->_status != http::reply::status_type::switching_protocols) { + websocket_logger.trace("Didn't receive 101 switching protocols response"); + throw std::runtime_error("Invalid http server response"); + } + + if (resp->get_header("Upgrade").find("websocket") == sstring::npos) { + websocket_logger.trace("Bad or non-existing Upgrade header"); + throw std::runtime_error("Bad or non-existing Upgrade header"); + } + if (resp->get_header("Connection").find("Upgrade") == sstring::npos) { + websocket_logger.trace("Bad or non-existing Connection header"); + throw std::runtime_error("Bad or non-existing Connection header"); + } + auto accept = resp->get_header("Sec-WebSocket-Accept"); + if (accept.empty()) { + websocket_logger.trace("Did not receive Sec-WebSocket-Accept header"); + throw std::runtime_error("Did not receive Sec-WebSocket-Accept header"); + } + if (accept != sha1_base64(_ws_key + magic_key_suffix)) { + websocket_logger.trace("Received mismatching Sec-WebSocket-Accept header"); + throw std::runtime_error("Received mismatching Sec-WebSocket-Accept header"); + } + + return make_ready_future<>(); + }); + }); +} + +client::client(socket_address addr) + : client(std::make_unique( + std::move(addr))) +{ +} + +client::client(socket_address addr, shared_ptr creds, sstring host) + : client(std::make_unique( + std::move(addr), std::move(creds), std::move(host))) +{ +} +client::client(std::unique_ptr f) + : _new_connections(std::move(f)) + , _random_gen{_rd_device()} +{ +} + +future> + client::make_request(http::request rq, const handler_t& handler) { + if (rq._version.empty()) { + rq._version = "1.1"; + } + rq._headers["Upgrade"] = "websocket"; + rq._headers["Connection"] = "Upgrade"; + rq._headers["Sec-WebSocket-Version"] = "13"; + if (!_subprotocol.empty()) { + rq._headers["Sec-WebSocket-Protocol"] = _subprotocol; + } + + uint8_t key[16] = {}; + std::uniform_int_distribution dist(0, 255); + for (auto& key_char : key) { + key_char = dist(_random_gen); + } + + std::string ws_key = encode_base64(std::string_view(reinterpret_cast(key), sizeof(key))); + rq._headers["Sec-WebSocket-Key"] = ws_key; + + abort_source* as = nullptr; // TODO + + return do_with(std::move(rq), [this, as, handler, ws_key](auto& rq) { + return _new_connections->make(as).then([this, &rq, as, handler, ws_key] (connected_socket cs) { + websocket_logger.trace("created new http connection {}", cs.local_address()); + + auto con = seastar::make_shared(*this, std::move(cs), ws_key, handler); + + auto sub = as ? as->subscribe([con] () noexcept { con->shutdown_input(); }) : std::nullopt; + return con->perform_handshake(rq).then([con](){ return con; }); + }); + }); +} + +future<> client::stop() { + for (auto&& c : _connections) { + c.shutdown_input(); + } + + return parallel_for_each(_connections, [] (client_connection& conn) { + return conn.close(true).handle_exception([] (auto ignored) {}); + }); +} + +void client::set_seed(std::size_t seed) { + _random_gen.seed(seed); +} + +} diff --git a/src/websocket/common.cc b/src/websocket/common.cc new file mode 100644 index 00000000000..a37ca3497db --- /dev/null +++ b/src/websocket/common.cc @@ -0,0 +1,294 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +sstring magic_key_suffix = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +logger websocket_logger("websocket"); + +opcodes websocket_parser::opcode() const { + if (_header) { + return opcodes(_header->opcode); + } else { + return opcodes::INVALID; + } +} + +websocket_parser::buff_t websocket_parser::result() { + return std::move(_result); +} + +future websocket_parser::operator()( + temporary_buffer data) { + if (data.size() == 0) { + // EOF + _cstate = connection_state::closed; + return websocket_parser::stop(std::move(data)); + } + if (_state == parsing_state::flags_and_payload_data) { + if (_buffer.length() + data.size() >= 2) { + // _buffer.length() is less than 2 when entering this if body due to how + // the rest of code is structured. The else branch will never increase + // _buffer.length() to >=2 and other paths to this condition will always + // have buffer cleared. + assert(_buffer.length() < 2); + + size_t hlen = _buffer.length(); + _buffer.append(data.get(), 2 - hlen); + data.trim_front(2 - hlen); + _header = std::make_unique(_buffer.data()); + _buffer = {}; + + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 + // We must close the connection if data isn't masked. + if (((!_header->masked) && (!_is_client)) || + // RSVX must be 0 + (_header->rsv1 | _header->rsv2 | _header->rsv3) || + // Opcode must be known. + (!_header->is_opcode_known())) { + _cstate = connection_state::error; + return websocket_parser::stop(std::move(data)); + } + _state = parsing_state::payload_length_and_mask; + } else { + _buffer.append(data.get(), data.size()); + return websocket_parser::dont_stop(); + } + } + if (_state == parsing_state::payload_length_and_mask) { + size_t const required_bytes = _header->get_rest_of_header_length(); + if (_buffer.length() + data.size() >= required_bytes) { + if (_buffer.length() < required_bytes) { + size_t hlen = _buffer.length(); + _buffer.append(data.get(), required_bytes - hlen); + data.trim_front(required_bytes - hlen); + } + _payload_length = _header->length; + char const *input = _buffer.data(); + if (_header->length == 126) { + _payload_length = consume_be(input); + } else if (_header->length == 127) { + _payload_length = consume_be(input); + } + + if (_header->get_masked()) { + _masking_key = consume_be(input); + } + _buffer = {}; + _state = parsing_state::payload; + } else { + _buffer.append(data.get(), data.size()); + return websocket_parser::dont_stop(); + } + } + if (_state == parsing_state::payload) { + if (data.size() < remaining_payload_length()) { + // data has insufficient data to complete the frame - consume data.size() bytes + if (_result.empty()) { + _result = temporary_buffer(remaining_payload_length()); + _consumed_payload_length = 0; + } + std::copy(data.begin(), data.end(), _result.get_write() + _consumed_payload_length); + _consumed_payload_length += data.size(); + return websocket_parser::dont_stop(); + } else { + // data has sufficient data to complete the frame - consume remaining_payload_length() + auto consumed_bytes = remaining_payload_length(); + if (_result.empty()) { + // Try to avoid memory copies in case when network packets contain one or more full + // websocket frames. + if (consumed_bytes == data.size()) { + _result = std::move(data); + data = temporary_buffer(0); + } else { + _result = data.share(); + _result.trim(consumed_bytes); + data.trim_front(consumed_bytes); + } + } else { + std::copy(data.begin(), data.begin() + consumed_bytes, + _result.get_write() + _consumed_payload_length); + data.trim_front(consumed_bytes); + } + if (!_is_client) { + remove_mask(_result, _payload_length); + } + _consumed_payload_length = 0; + _state = parsing_state::flags_and_payload_data; + return websocket_parser::stop(std::move(data)); + } + } + _cstate = connection_state::error; + return websocket_parser::stop(std::move(data)); +} + +future<> connection::handle_ping() { + // TODO + return make_ready_future<>(); +} + +future<> connection::handle_pong() { + // TODO + return make_ready_future<>(); +} + +future<> connection::send_data(opcodes opcode, temporary_buffer&& buff) { + // Maximum length of header is 14: + // 2 for static part of the header + // 8 for payload length field at maximum size + // 4 for optional mask + char header[14] = {'\x80', 0}; + size_t header_size = 2; + + header[0] += opcode; + + if ((126 <= buff.size()) && (buff.size() <= std::numeric_limits::max())) { + header[1] = 0x7E; + write_be(header + 2, buff.size()); + header_size += sizeof(uint16_t); + } else if (std::numeric_limits::max() < buff.size()) { + header[1] = 0x7F; + write_be(header + 2, buff.size()); + header_size += sizeof(uint64_t); + } else { + header[1] = uint8_t(buff.size()); + } + + temporary_buffer write_buf; + if (_is_client) { + header[1] |= 0x80; + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 requires that the masking key + // must be unpredictable and derived from a strong source of entropy. This requirement + // arose due to usage of WebSocket protocol in the browsers, where both server and client + // payload generation code may be malicious and the only trusted piece is the browser + // itself. Consequently there was a need to make the bytes on the wire unpredictable for the + // client code, so that it cannot run attacks against intermediate proxies that do not + // understand WebSocket. + // + // In the case of Seastar there is no security boundary between payload generator and + // payload serializer in this class, accordingly in terms of security impact it is + // sufficient to simply use predictable masking key. Zero is chosen because it does not + // change the payload. + uint32_t masking_key = 0; + write_be(header + header_size, masking_key); + header_size += sizeof(uint32_t); + } + + scattered_message msg; + msg.append(sstring(header, header_size)); + msg.append(std::move(buff)); + return _write_buf.write(std::move(msg)).then([this] { + return _write_buf.flush(); + }); +} + +future<> connection::response_loop() { + return do_until([this] {return _done;}, [this] { + // FIXME: implement error handling + return _output_buffer.pop_eventually().then([this] ( + temporary_buffer buf) { + return send_data(opcodes::BINARY, std::move(buf)); + }); + }).finally([this]() { + return _write_buf.close(); + }); +} + +void connection::shutdown_input() { + _fd.shutdown_input(); +} + +future<> connection::close(bool send_close) { + return [this, send_close]() { + if (send_close) { + return send_data(opcodes::CLOSE, temporary_buffer(0)); + } else { + return make_ready_future<>(); + } + }().finally([this] { + _done = true; + return when_all_succeed(_input.close(), _output.close()).discard_result().finally([this] { + _fd.shutdown_output(); + }); + }); +} + +future<> connection::read_one() { + return _read_buf.consume(_websocket_parser).then([this] () mutable { + if (_websocket_parser.is_valid()) { + // FIXME: implement error handling + switch(_websocket_parser.opcode()) { + // We do not distinguish between these 3 types. + case opcodes::CONTINUATION: + case opcodes::TEXT: + case opcodes::BINARY: + return _input_buffer.push_eventually(_websocket_parser.result()); + case opcodes::CLOSE: + websocket_logger.debug("Received close frame."); + // datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + return close(true); + case opcodes::PING: + websocket_logger.debug("Received ping frame."); + return handle_ping(); + case opcodes::PONG: + websocket_logger.debug("Received pong frame."); + return handle_pong(); + default: + // Invalid - do nothing. + ; + } + } else if (_websocket_parser.eof()) { + return close(false); + } + websocket_logger.debug("Reading from socket has failed."); + return close(true); + }); +} + +std::string sha1_base64(std::string_view source) { + unsigned char hash[20]; + assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); + if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); + ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); + } + return encode_base64(std::string_view(reinterpret_cast(hash), sizeof(hash))); +} + +std::string encode_base64(std::string_view source) { + gnutls_datum_t src_data{ + .data = reinterpret_cast(const_cast(source.data())), + .size = static_cast(source.size()) + }; + gnutls_datum_t encoded_data; + if (int ret = gnutls_base64_encode2(&src_data, &encoded_data); ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); + } + auto free_encoded_data = defer([&] () noexcept { gnutls_free(encoded_data.data); }); + // base64_encoded.data is "unsigned char *" + return std::string(reinterpret_cast(encoded_data.data), encoded_data.size); +} + +} diff --git a/src/websocket/server.cc b/src/websocket/server.cc index f9c26bfe293..45b5beb477a 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -25,14 +25,10 @@ #include #include #include -#include #include -#include -#include namespace seastar::experimental::websocket { -static sstring magic_key_suffix = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static sstring http_upgrade_reply_template = "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" @@ -40,20 +36,6 @@ static sstring http_upgrade_reply_template = "Sec-WebSocket-Version: 13\r\n" "Sec-WebSocket-Accept: "; -static logger wlogger("websocket"); - -opcodes websocket_parser::opcode() const { - if (_header) { - return opcodes(_header->opcode); - } else { - return opcodes::INVALID; - } -} - -websocket_parser::buff_t websocket_parser::result() { - return std::move(_result); -} - void server::listen(socket_address addr, listen_options lo) { _listeners.push_back(seastar::listen(addr, lo)); accept(_listeners.back()); @@ -74,10 +56,10 @@ void server::accept(server_socket &listener) { future server::accept_one(server_socket &listener) { return listener.accept().then([this](accept_result ar) { - auto conn = std::make_unique(*this, std::move(ar.connection)); + auto conn = std::make_unique(*this, std::move(ar.connection)); (void)try_with_gate(_task_gate, [conn = std::move(conn)]() mutable { return conn->process().finally([conn = std::move(conn)] { - wlogger.debug("Connection is finished"); + websocket_logger.debug("Connection is finished"); }); }).handle_exception_type([](const gate_closed_exception &e) {}); return make_ready_future(stop_iteration::no); @@ -85,11 +67,11 @@ future server::accept_one(server_socket &listener) { // We expect a ECONNABORTED when server::stop is called, // no point in warning about that. if (e.code().value() != ECONNABORTED) { - wlogger.error("accept failed: {}", e); + websocket_logger.error("accept failed: {}", e); } return make_ready_future(stop_iteration::yes); }).handle_exception([](std::exception_ptr ex) { - wlogger.info("accept failed: {}", ex); + websocket_logger.info("accept failed: {}", ex); return make_ready_future(stop_iteration::yes); }); } @@ -104,48 +86,27 @@ future<> server::stop() { } return _task_gate.close().finally([this] { - return parallel_for_each(_connections, [] (connection& conn) { + return parallel_for_each(_connections, [] (server_connection& conn) { return conn.close(true).handle_exception([] (auto ignored) {}); }); }); } -connection::~connection() { +server_connection::~server_connection() { _server._connections.erase(_server._connections.iterator_to(*this)); } -void connection::on_new_connection() { +void server_connection::on_new_connection() { _server._connections.push_back(*this); } -future<> connection::process() { +future<> server_connection::process() { return when_all_succeed(read_loop(), response_loop()).discard_result().handle_exception([] (const std::exception_ptr& e) { - wlogger.debug("Processing failed: {}", e); + websocket_logger.debug("Processing failed: {}", e); }); } -static std::string sha1_base64(std::string_view source) { - unsigned char hash[20]; - assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); - if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); - } - gnutls_datum_t hash_data{ - .data = hash, - .size = sizeof(hash), - }; - gnutls_datum_t base64_encoded; - if (int ret = gnutls_base64_encode2(&hash_data, &base64_encoded); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); - } - auto free_base64_encoded = defer([&] () noexcept { gnutls_free(base64_encoded.data); }); - // base64_encoded.data is "unsigned char *" - return std::string(reinterpret_cast(base64_encoded.data), base64_encoded.size); -} - -future<> connection::read_http_upgrade_request() { +future<> server_connection::read_http_upgrade_request() { _http_parser.init(); return _read_buf.consume(_http_parser).then([this] () mutable { if (_http_parser.eof()) { @@ -172,17 +133,17 @@ future<> connection::read_http_upgrade_request() { } this->_handler = this->_server._handlers[subprotocol]; this->_subprotocol = subprotocol; - wlogger.debug("Sec-WebSocket-Protocol: {}", subprotocol); + websocket_logger.debug("Sec-WebSocket-Protocol: {}", subprotocol); sstring sec_key = req->get_header("Sec-Websocket-Key"); sstring sec_version = req->get_header("Sec-Websocket-Version"); sstring sha1_input = sec_key + magic_key_suffix; - wlogger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); + websocket_logger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); std::string sha1_output = sha1_base64(sha1_input); - wlogger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); + websocket_logger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); return _write_buf.write(http_upgrade_reply_template).then([this, sha1_output = std::move(sha1_output)] { return _write_buf.write(sha1_output); @@ -198,152 +159,7 @@ future<> connection::read_http_upgrade_request() { }); } -future websocket_parser::operator()( - temporary_buffer data) { - if (data.size() == 0) { - // EOF - _cstate = connection_state::closed; - return websocket_parser::stop(std::move(data)); - } - if (_state == parsing_state::flags_and_payload_data) { - if (_buffer.length() + data.size() >= 2) { - // _buffer.length() is less than 2 when entering this if body due to how - // the rest of code is structured. The else branch will never increase - // _buffer.length() to >=2 and other paths to this condition will always - // have buffer cleared. - assert(_buffer.length() < 2); - - size_t hlen = _buffer.length(); - _buffer.append(data.get(), 2 - hlen); - data.trim_front(2 - hlen); - _header = std::make_unique(_buffer.data()); - _buffer = {}; - - // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 - // We must close the connection if data isn't masked. - if ((!_header->masked) || - // RSVX must be 0 - (_header->rsv1 | _header->rsv2 | _header->rsv3) || - // Opcode must be known. - (!_header->is_opcode_known())) { - _cstate = connection_state::error; - return websocket_parser::stop(std::move(data)); - } - _state = parsing_state::payload_length_and_mask; - } else { - _buffer.append(data.get(), data.size()); - return websocket_parser::dont_stop(); - } - } - if (_state == parsing_state::payload_length_and_mask) { - size_t const required_bytes = _header->get_rest_of_header_length(); - if (_buffer.length() + data.size() >= required_bytes) { - if (_buffer.length() < required_bytes) { - size_t hlen = _buffer.length(); - _buffer.append(data.get(), required_bytes - hlen); - data.trim_front(required_bytes - hlen); - } - _payload_length = _header->length; - char const *input = _buffer.data(); - if (_header->length == 126) { - _payload_length = consume_be(input); - } else if (_header->length == 127) { - _payload_length = consume_be(input); - } - - _masking_key = consume_be(input); - _buffer = {}; - _state = parsing_state::payload; - } else { - _buffer.append(data.get(), data.size()); - return websocket_parser::dont_stop(); - } - } - if (_state == parsing_state::payload) { - if (data.size() < remaining_payload_length()) { - // data has insufficient data to complete the frame - consume data.size() bytes - if (_result.empty()) { - _result = temporary_buffer(remaining_payload_length()); - _consumed_payload_length = 0; - } - std::copy(data.begin(), data.end(), _result.get_write() + _consumed_payload_length); - _consumed_payload_length += data.size(); - return websocket_parser::dont_stop(); - } else { - // data has sufficient data to complete the frame - consume remaining_payload_length() - auto consumed_bytes = remaining_payload_length(); - if (_result.empty()) { - // Try to avoid memory copies in case when network packets contain one or more full - // websocket frames. - if (consumed_bytes == data.size()) { - _result = std::move(data); - data = temporary_buffer(0); - } else { - _result = data.share(); - _result.trim(consumed_bytes); - data.trim_front(consumed_bytes); - } - } else { - std::copy(data.begin(), data.begin() + consumed_bytes, - _result.get_write() + _consumed_payload_length); - data.trim_front(consumed_bytes); - } - remove_mask(_result, _payload_length); - _consumed_payload_length = 0; - _state = parsing_state::flags_and_payload_data; - return websocket_parser::stop(std::move(data)); - } - } - _cstate = connection_state::error; - return websocket_parser::stop(std::move(data)); -} - -future<> connection::handle_ping() { - // TODO - return make_ready_future<>(); -} - -future<> connection::handle_pong() { - // TODO - return make_ready_future<>(); -} - - -future<> connection::read_one() { - return _read_buf.consume(_websocket_parser).then([this] () mutable { - if (_websocket_parser.is_valid()) { - // FIXME: implement error handling - switch(_websocket_parser.opcode()) { - // We do not distinguish between these 3 types. - case opcodes::CONTINUATION: - case opcodes::TEXT: - case opcodes::BINARY: - return _input_buffer.push_eventually(_websocket_parser.result()); - case opcodes::CLOSE: - wlogger.debug("Received close frame."); - /* - * datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 - */ - return close(true); - case opcodes::PING: - wlogger.debug("Received ping frame."); - return handle_ping(); - case opcodes::PONG: - wlogger.debug("Received pong frame."); - return handle_pong(); - default: - // Invalid - do nothing. - ; - } - } else if (_websocket_parser.eof()) { - return close(false); - } - wlogger.debug("Reading from socket has failed."); - return close(true); - }); -} - -future<> connection::read_loop() { +future<> server_connection::read_loop() { return read_http_upgrade_request().then([this] { return when_all_succeed( _handler(_input, _output).handle_exception([this] (std::exception_ptr e) mutable { @@ -358,63 +174,6 @@ future<> connection::read_loop() { }); } -void connection::shutdown_input() { - _fd.shutdown_input(); -} - -future<> connection::close(bool send_close) { - return [this, send_close]() { - if (send_close) { - return send_data(opcodes::CLOSE, temporary_buffer(0)); - } else { - return make_ready_future<>(); - } - }().finally([this] { - _done = true; - return when_all_succeed(_input.close(), _output.close()).discard_result().finally([this] { - _fd.shutdown_output(); - }); - }); -} - -future<> connection::send_data(opcodes opcode, temporary_buffer&& buff) { - char header[10] = {'\x80', 0}; - size_t header_size = 2; - - header[0] += opcode; - - if ((126 <= buff.size()) && (buff.size() <= std::numeric_limits::max())) { - header[1] = 0x7E; - write_be(header + 2, buff.size()); - header_size += sizeof(uint16_t); - } else if (std::numeric_limits::max() < buff.size()) { - header[1] = 0x7F; - write_be(header + 2, buff.size()); - header_size += sizeof(uint64_t); - } else { - header[1] = uint8_t(buff.size()); - } - - scattered_message msg; - msg.append(sstring(header, header_size)); - msg.append(std::move(buff)); - return _write_buf.write(std::move(msg)).then([this] { - return _write_buf.flush(); - }); -} - -future<> connection::response_loop() { - return do_until([this] {return _done;}, [this] { - // FIXME: implement error handling - return _output_buffer.pop_eventually().then([this] ( - temporary_buffer buf) { - return send_data(opcodes::BINARY, std::move(buf)); - }); - }).finally([this]() { - return _write_buf.close(); - }); -} - bool server::is_handler_registered(std::string const& name) { return _handlers.find(name) != _handlers.end(); } diff --git a/tests/unit/websocket_test.cc b/tests/unit/websocket_test.cc index e6d0a446de6..eade98ab9c2 100644 --- a/tests/unit/websocket_test.cc +++ b/tests/unit/websocket_test.cc @@ -49,7 +49,7 @@ SEASTAR_TEST_CASE(test_websocket_handshake) { }); }); }); - websocket::connection conn(dummy, acceptor.get().connection); + websocket::server_connection conn(dummy, acceptor.get().connection); future<> serve = conn.process(); auto close = defer([&conn, &input, &output, &serve] () noexcept { conn.close().get(); @@ -113,7 +113,7 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) { }); }); }); - websocket::connection conn(ws, acceptor.get().connection); + websocket::server_connection conn(ws, acceptor.get().connection); future<> serve = conn.process(); auto close = defer([&conn, &input, &output, &serve] () noexcept {