Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom agent strings #63

Merged
merged 17 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions lib/malloy/client/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ auto controller::run() -> bool
return true;
}

auto controller::init(config cfg) -> bool {
if (!malloy::controller::init(cfg)) {
return false;
}
m_cfg = std::move(cfg);
return true;
}
auto controller::start() -> bool {
return root_start(m_cfg);
}

#if MALLOY_FEATURE_TLS
bool controller::init_tls()
{
Expand Down
98 changes: 57 additions & 41 deletions lib/malloy/client/controller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "../core/http/response.hpp"
#include "../core/http/type_traits.hpp"
#include "../core/error.hpp"
#include "malloy/core/http/utils.hpp"


#if MALLOY_FEATURE_TLS
#include "http/connection_tls.hpp"
Expand Down Expand Up @@ -66,11 +68,20 @@ namespace malloy::client
public:
struct config :
malloy::controller::config {

/**
* @brief Agent string used for connections
* @details Set as the User-Agent in http headers
*/
std::string user_agent{"malloy-client"};
};

controller() = default;
~controller() override = default;

[[nodiscard("init may fail")]]
auto init(config cfg) -> bool;

#if MALLOY_FEATURE_TLS
/**
* Initialize the TLS context.
Expand All @@ -97,25 +108,7 @@ namespace malloy::client
requires concepts::http_callback<Callback, Filter>
[[nodiscard]] auto http_request(malloy::http::request<ReqBody> req, Callback&& done, Filter filter = {}) -> std::future<malloy::error_code>
{

// Create connection
auto conn = std::make_shared<http::connection_plain<ReqBody, Filter, std::decay_t<Callback>>>(
m_cfg.logger->clone(m_cfg.logger->name() + " | HTTP connection"),
io_ctx()
);

// Run
std::promise<malloy::error_code> prom;
auto err_channel = prom.get_future();
conn->run(
std::to_string(req.port()).c_str(),
req,
std::move(prom),
std::move(done),
std::move(filter)
);

return err_channel;
return make_http_connection<false>(std::move(req), std::forward<Callback>(done), std::move(filter));
}

#if MALLOY_FEATURE_TLS
Expand All @@ -128,27 +121,7 @@ namespace malloy::client
requires concepts::http_callback<Callback, Filter>
[[nodiscard]] auto https_request(malloy::http::request<ReqBody> req, Callback&& done, Filter filter = {}) -> std::future<malloy::error_code>
{
check_tls();

// Create connection
auto conn = std::make_shared<http::connection_tls<ReqBody, Filter, std::decay_t<Callback>>>(
m_cfg.logger->clone(m_cfg.logger->name() + " | HTTP connection"),
io_ctx(),
*m_tls_ctx
);

// Run
std::promise<malloy::error_code> prom;
auto err_channel = prom.get_future();
conn->run(
std::to_string(req.port()).c_str(),
req,
std::move(prom),
std::move(done),
std::move(filter)
);

return err_channel;
return make_http_connection<true>(std::move(req), std::forward<Callback>(done), std::move(filter));
}

/**
Expand Down Expand Up @@ -219,8 +192,13 @@ namespace malloy::client
*/
auto run() -> bool;

auto start() -> bool;

protected:

private:
std::shared_ptr<boost::asio::ssl::context> m_tls_ctx;
config m_cfg;

/**
* Checks whether the TLS context was initialized.
Expand All @@ -233,6 +211,44 @@ namespace malloy::client
throw std::logic_error("TLS context not initialized.");
}

template<bool isHttps, malloy::http::concepts::body Body, typename Callback, typename Filter>
auto make_http_connection(malloy::http::request<Body>&& req, Callback&& cb, Filter&& filter) -> std::future<malloy::error_code>
{

std::promise<malloy::error_code> prom;
auto err_channel = prom.get_future();
[this](auto&& cb) {
#if MALLOY_FEATURE_TLS
if constexpr (isHttps) {
init_tls();
cb(std::make_shared<http::connection_tls<Body, Filter, std::decay_t<Callback>>>(
m_cfg.logger->clone(m_cfg.logger->name() + " | HTTP connection"),
io_ctx(),
*m_tls_ctx));
return;
}
#endif
cb(std::make_shared<http::connection_plain<Body, Filter, std::decay_t<Callback>>>(
m_cfg.logger->clone(m_cfg.logger->name() + " | HTTP connection"),
io_ctx()));
}([this, prom = std::move(prom), req = std::move(req), filter = std::forward<Filter>(filter), cb = std::forward<Callback>(cb)](auto&& conn) mutable {
if (!malloy::http::has_field(req, malloy::http::field::user_agent)) {
req.set(malloy::http::field::user_agent, m_cfg.user_agent);
}

// Run
conn->run(
std::to_string(req.port()).c_str(),
req,
std::move(prom),
std::forward<Callback>(cb),
std::forward<Filter>(filter));
});

return err_channel;

}

template<bool isSecure>
void make_ws_connection(
const std::string& host,
Expand All @@ -258,7 +274,7 @@ namespace malloy::client
} else
#endif
return malloy::websocket::stream{boost::beast::tcp_stream{boost::asio::make_strand(io_ctx())}};
}());
}(), m_cfg.user_agent);

conn->connect(results, resource, [conn, done = std::forward<decltype(done)>(done)](auto ec) mutable {
if (ec) {
Expand Down
16 changes: 6 additions & 10 deletions lib/malloy/core/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ controller::~controller()
stop().wait();
}

bool controller::init(config cfg)
bool controller::init(const config& cfg)
{
// Don't initialize if not stopped
if (m_state != state::stopped)
Expand All @@ -32,8 +32,6 @@ bool controller::init(config cfg)
return false;
}

// Grab the config
m_cfg = std::move(cfg);

// Create the I/O context
m_io_ctx = std::make_shared<boost::asio::io_context>();
Expand All @@ -45,17 +43,17 @@ bool controller::init(config cfg)
return true;
}

bool controller::start()
bool controller::root_start(const config& cfg)
{
// Sanity check
if (!m_io_ctx) {
m_cfg.logger->critical("no I/O context present. Make sure that init() was called and succeeded.");
cfg.logger->critical("no I/O context present. Make sure that init() was called and succeeded.");
return false;
}

// Create the I/O context threads
m_io_threads.reserve(m_cfg.num_threads - 1);
for (std::size_t i = 0; i < m_cfg.num_threads; i++) {
m_io_threads.reserve(cfg.num_threads - 1);
for (std::size_t i = 0; i < cfg.num_threads; i++) {
m_io_threads.emplace_back(
[this]
{
Expand All @@ -65,7 +63,7 @@ bool controller::start()
}

// Log
m_cfg.logger->debug("starting i/o context.");
cfg.logger->debug("starting i/o context.");

// Update state
m_state = state::running;
Expand Down Expand Up @@ -97,14 +95,12 @@ std::future<void> controller::stop()
// Tell the workguard that we no longer need it's service
m_workguard->reset();

m_cfg.logger->debug("waiting for I/O threads to stop...");

for (auto& thread : m_io_threads)
thread.join();

m_state = state::stopped;

m_cfg.logger->debug("all I/O threads stopped.");
}
);
}
17 changes: 5 additions & 12 deletions lib/malloy/core/controller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,7 @@ namespace malloy
controller() = default;
virtual ~controller();

[[nodiscard("init may fail")]]
virtual
bool init(config cfg);

/**
* Start the server. This function will not return until the server is stopped.
*
* @return Whether starting the server was successful.
*/
[[nodiscard("start may fail")]]
virtual
bool start();

/**
* Stop the server.
Expand All @@ -64,14 +53,18 @@ namespace malloy
std::future<void> stop();

protected:
config m_cfg;
[[nodiscard("init may fail")]]
bool init(const config& cfg);

[[nodiscard]]
boost::asio::io_context&
io_ctx() const noexcept
{
return *m_io_ctx;
}
[[nodiscard("start may fail")]]
auto root_start(const config& cfg) -> bool;


void remove_workguard() const;

Expand Down
6 changes: 6 additions & 0 deletions lib/malloy/core/http/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ namespace malloy::http
{
head.target(head.target().substr(resource.size()));
}

template<bool isReq, typename Fields>
auto has_field(const boost::beast::http::header<isReq, Fields>& head, const malloy::http::field check) -> bool
{
return head.find(check) != head.end();
}
}


50 changes: 12 additions & 38 deletions lib/malloy/core/websocket/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,6 @@

namespace malloy::websocket
{
namespace detail
{
constexpr std::string_view beast_version = BOOST_BEAST_VERSION_STRING;
template<bool isClient>
constexpr auto ws_agent_string() -> std::string_view
{
if (isClient) {
return {BOOST_BEAST_VERSION_STRING " websocket-client-async"};
} else {
return {BOOST_BEAST_VERSION_STRING " malloy"};
}
}
} // namespace detail

/**
* @class connection
* @tparam isClient: Whether it is the client end of a websocket connection
Expand All @@ -54,11 +40,6 @@ namespace malloy::websocket
closed
};

/**
* The agent string.
*/
constexpr static std::string_view agent_string = detail::ws_agent_string<isClient>();

/**
* See stream::set_binary(bool)
*/
Expand All @@ -76,12 +57,12 @@ namespace malloy::websocket
* `connect` must be called before this connection can be used
*/
static auto
make(const std::shared_ptr<spdlog::logger> logger, stream&& ws) -> std::shared_ptr<connection>
make(const std::shared_ptr<spdlog::logger> logger, stream&& ws, const std::string& agent_string) -> std::shared_ptr<connection>
{
// We have to emulate make_shared here because the ctor is private
connection* me = nullptr;
try {
me = new connection{logger, std::move(ws)};
me = new connection{logger, std::move(ws), agent_string};
return std::shared_ptr<connection>{me};
} catch (...) {
delete me;
Expand Down Expand Up @@ -254,13 +235,15 @@ namespace malloy::websocket
std::vector<std::function<void()>> msg_queue_;
std::shared_ptr<spdlog::logger> m_logger;
stream m_ws;
std::string m_agent_string;

enum state m_state = state::closed;

connection(
std::shared_ptr<spdlog::logger> logger, stream&& ws) :
std::shared_ptr<spdlog::logger> logger, stream&& ws, std::string agent_str) :
m_logger(std::move(logger)),
m_ws{std::move(ws)}
m_ws{std::move(ws)},
m_agent_string{std::move(agent_str)}
{
// Sanity check logger
if (!m_logger)
Expand All @@ -275,21 +258,12 @@ namespace malloy::websocket
boost::beast::websocket::stream_base::timeout::suggested(
isClient ? boost::beast::role_type::client : boost::beast::role_type::server));

if constexpr (isClient) {
// Set a decorator to change the User-Agent of the handshake
m_ws.set_option(
boost::beast::websocket::stream_base::decorator(
[](boost::beast::websocket::request_type& req) {
req.set(boost::beast::http::field::user_agent, agent_string);
}));
} else {
// Set a decorator to change the Server of the handshake
m_ws.set_option(
boost::beast::websocket::stream_base::decorator(
[](boost::beast::websocket::response_type& res) {
res.set(boost::beast::http::field::server, agent_string);
}));
}
const auto agent_field = isClient ? malloy::http::field::user_agent : malloy::http::field::server;
m_ws.set_option(
boost::beast::websocket::stream_base::decorator(
[this, agent_field](boost::beast::websocket::request_type& req) {
req.set(agent_field, m_agent_string);
}));
}

void
Expand Down
Loading