diff --git a/library.cpp b/library.cpp index 4f95939..8626a31 100644 --- a/library.cpp +++ b/library.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace net = boost::asio; // from namespace beast = boost::beast; // from @@ -14,46 +15,75 @@ using tcp = boost::asio::ip::tcp; // from using namespace std::chrono_literals; #define TRACE(stream, msg) do { stream << L" " << msg << std::endl; } while(0) -#define VERBOSE(msg) do { if (s_enable_verbose) { TRACE(std::wcout, msg); } } while(0) +#define VERBOSE(msg) do { if (g_enable_verbose) { TRACE(std::wcout, msg); } } while(0) #define COUT(msg) TRACE(std::wcout, msg) #define CERR(msg) TRACE(std::wcerr, msg) namespace /*anon*/ { - static on_connect_t s_on_connect_cb{nullptr}; - static on_fail_t s_on_fail_cb{nullptr}; - static on_disconnect_t s_on_disconnect_cb{nullptr}; - static on_data_t s_on_data_cb{nullptr}; - // Global variables - static std::atomic_bool s_enable_verbose{false}; + static std::atomic_bool g_enable_verbose{false}; + static net::thread_pool g_thread_pool; class Session; using SessionPtr = std::shared_ptr; + using Handle = size_t; struct Manager { - // TODO maybe allow multiple sessions and use weak pointers instead? - static inline SessionPtr Install(SessionPtr sess) + + inline Handle Register(SessionPtr const& sess) { - std::shared_ptr no_session{}; - if (!std::atomic_compare_exchange_strong(&s_instance_unsafe, &no_session, sess)) { - return nullptr; - } - return sess; + std::lock_guard lk(mx_); + + garbage_collect(); + assert(sess); + + Handle handle = next_handle_++; + auto [it, ok] = sessions_.emplace(handle, sess); + assert(ok); + + return handle; } - static inline SessionPtr Active() { - return std::atomic_load(&s_instance_unsafe); + inline bool Forget(Handle h) + { + std::lock_guard lk(mx_); + + garbage_collect(); + + if (auto it = sessions_.find(h); it != end(sessions_)) { + sessions_.erase(it); + return true; + } + + return false; } - static inline bool Clear(SessionPtr sess) + inline SessionPtr Active(Handle h) { - return std::atomic_compare_exchange_strong(&s_instance_unsafe, &sess, {}); + std::lock_guard lk(mx_); + + if (auto it = sessions_.find(h); it != end(sessions_)) + return it->second.lock(); + + return nullptr; } private: - static SessionPtr s_instance_unsafe; // use atomic_ operations to safely access - }; - /*static*/ SessionPtr Manager::s_instance_unsafe; + using Registry = std::map>; + std::mutex mx_; + size_t next_handle_ = 1; + Registry sessions_; + + void garbage_collect() // lock must be held + { + for (auto it = begin(sessions_); it != end(sessions_);) { + if (it->second.expired()) + it = sessions_.erase(it); + else + ++it; + } + } + } g_session_manager; std::string utf8_encode(std::wstring const& wstr) { @@ -111,12 +141,11 @@ namespace /*anon*/ { } class Session : public std::enable_shared_from_this { - net::thread_pool ioc_{1}; - websocket::stream ws_{make_strand(ioc_.get_executor())}; + websocket::stream ws_{make_strand(g_thread_pool)}; tcp::resolver resolver_{ws_.get_executor()}; beast::flat_buffer buffer_; - std::wstring host_, path_; // path part in url. For example: /v2/ws + std::wstring host_, port_, path_; // path part in url. For example: /v2/ws /// Print error related information in stderr /// \param ec instance that contains error related information @@ -131,150 +160,170 @@ namespace /*anon*/ { VERBOSE(msg); } else { - if (s_on_fail_cb) - s_on_fail_cb(msg.c_str()); + if (on_fail_cb) + on_fail_cb(handle_, msg.c_str()); CERR(msg); } } - public: - Session() = default; +#include + struct connect_op { + using EC = beast::error_code; + using EP = tcp::endpoint; + using EPS = tcp::resolver::results_type; + template void operator()(Self& self, EC ec, EPS r) { return call(self, ec, r); } + template void operator()(Self& self, EC ec, EP) { return call(self, ec); } + template void operator()(Self& self, EC ec) { return call(self, ec); } + template void operator()(Self& self) { return call(self); } + + SessionPtr s; + net::coroutine coro; + + private: + template + void call(Self& self, beast::error_code ec = {}, tcp::resolver::results_type results = {}) + { + auto& ws_ = s->ws_; + reenter(coro) + { + yield s->resolver_.async_resolve(utf8_encode(s->host_), utf8_encode(s->port_), + std::move(self)); + if (ec) goto complete; + + beast::get_lowest_layer(ws_).expires_after(30s); + yield beast::get_lowest_layer(ws_).async_connect(results, std::move(self)); + if (ec) goto complete; + + beast::get_lowest_layer(ws_).expires_never(); + ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { + req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " WsDll"); + })); + + // Host HTTP header includes the port. See https://tools.ietf.org/html/rfc7230#section-5.4 + yield ws_.async_handshake(utf8_encode(s->host_) + ":" + utf8_encode(s->port_), + utf8_encode(s->path_), std::move(self)); + +complete: + s.reset(); // deallocate before completion + return self.complete(ec); + } + } + }; +#include - /// Send message to remote websocket server - /// \param data to be sent - void send_message(std::wstring const& data) - { - post(ws_.get_executor(), - std::bind(&Session::do_send_message, shared_from_this(), utf8_encode(data))); + template + auto async_connect(Token&& token) { + return net::async_compose(connect_op{shared_from_this(), {}}, + token); } - /// Close the connect between websocket client and server. It call - /// async_close to call a callback function which also calls user - /// registered callback function to deal with close event. - void disconnect() + public: + Handle handle_ = 0; // TODO make friend/private? + on_fail_t on_fail_cb = nullptr; + on_disconnect_t on_disconnect_cb = nullptr; + on_data_t on_data_cb = nullptr; + + Session() = default; + ~Session() { - post(ws_.get_executor(), std::bind(&Session::do_disconnect, shared_from_this())); + try { + if (on_disconnect_cb) + std::exchange(on_disconnect_cb, nullptr)(handle_); + } + catch (std::exception const& e) { + // swallow + } } - /// Start the asynchronous operation - /// \param host host to be connected - /// \param port tcp port to be connected - void run(std::wstring host, std::wstring port, std::wstring path) + /// Start asynchronous operation + /// + /// Only returns when connection (attempt) completed + int run(std::wstring host, std::wstring port, std::wstring path) { // Save these for later host_ = std::move(host); + port_ = std::move(port); path_ = std::move(path); - VERBOSE(L"Run host_: " << host_ << L", port: " << port << L", path_: " << path_); - - // Look up the domain name - resolver_.async_resolve(utf8_encode(host_), utf8_encode(port), - beast::bind_front_handler(&Session::on_resolve, shared_from_this())); - } - - private: // all private (do_*/on_*) assumed on strand - std::deque _outbox; // NOTE: reference stability of elements - - void do_send_message(std::string data) - { - VERBOSE(L"Queueing message: " << quoted(utf8_decode(data))); - _outbox.push_back(std::move(data)); // extend lifetime to completion of async write - - if (_outbox.size()==1) // need to start write chain? - do_write_loop(); - } - - void do_disconnect() - { - VERBOSE(L"Disconnecting"); - ws_.async_close(websocket::close_code::normal, - beast::bind_front_handler(&Session::on_close, shared_from_this())); - } - - /// Callback function registered by async_resolve method. It is - /// called after resolve operation is done. It will call - /// async_connect to issue async connecting operation with - /// callback function - /// \param ec - /// \param results - void on_resolve(beast::error_code ec, tcp::resolver::results_type const& results) - { - VERBOSE(L"In on_resolve"); - if (ec) - return fail(ec, L"resolve"); - - // Set the timeout for the operation - beast::get_lowest_layer(ws_).expires_after(30s); - - // Make the connection on the IP address we get from a lookup - beast::get_lowest_layer(ws_).async_connect( - results, beast::bind_front_handler(&Session::on_connect, shared_from_this())); - } - - void on_connect(beast::error_code ec, tcp::resolver::results_type::endpoint_type ep) - { - VERBOSE(L"In on_connect"); - if (ec) - return fail(ec, L"connect"); - - // Turn off the timeout on the tcp_stream, because - // the websocket stream has its own timeout system. - beast::get_lowest_layer(ws_).expires_never(); + VERBOSE(L"Run host_: " << host_ << L", port: " << port_ << L", path_: " << path_); + try { + assert(shared_from_this()); + async_connect(net::use_future).get(); - // Set suggested timeout settings for the websocket - ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); - - // Set a decorator to change the User-Agent of the handshake - ws_.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { - req.set(http::field::user_agent, - std::string(BOOST_BEAST_VERSION_STRING) + " WsDll"); - })); - - // Perform the websocket handshake - - // Host HTTP header includes the port. See https://tools.ietf.org/html/rfc7230#section-5.4 - ws_.async_handshake(utf8_encode(host_) + ":" + std::to_string(ep.port()), utf8_encode(path_), - beast::bind_front_handler(&Session::on_handshake, shared_from_this())); + VERBOSE(L"Issue async_read after connect"); + ws_.async_read(buffer_, beast::bind_front_handler(&Session::on_read, shared_from_this())); + return 1; + } + catch (beast::system_error const& se) { + fail(se.code(), L"Connection operation"); + return 0; + } } - void on_handshake(beast::error_code ec) + /// Send message to remote websocket server + /// \param data to be sent + bool send_message(std::wstring const& data) { - VERBOSE(L"In on_handshake"); - if (ec) - return fail(ec, L"handshake"); - - if (s_on_connect_cb) - s_on_connect_cb(); - - // Send the message - VERBOSE(L"Issue async_read in on_handshake"); - ws_.async_read(buffer_, beast::bind_front_handler(&Session::on_read, shared_from_this())); + std::promise p; + auto fut = p.get_future(); + + try { + VERBOSE(L"Writing message: " << data); + post(ws_.get_executor(), [this, &p, udata = utf8_encode(data)]() mutable { + ws_.async_write( // + net::buffer(udata), [&p](beast::error_code ec, size_t) mutable { + if (!ec) + p.set_value(); + else { + p.set_exception(std::make_exception_ptr(beast::system_error(ec))); + } + }); + }); + fut.get(); + return true; + } + catch (beast::system_error const& se) { + fail(se.code(), L"send_message"); + return false; + } } - void do_write_loop() + /// Close the connect between websocket client and server. It call + /// async_close to call a callback function which also calls user + /// registered callback function to deal with close event. + bool disconnect() { - if (_outbox.empty()) { - VERBOSE(L"Output queue empty"); - return; + std::promise p; + auto fut = p.get_future(); + + try { + post(ws_.get_executor(), [this, &p]() mutable { + VERBOSE(L"Disconnecting"); + ws_.async_close( // + websocket::close_code::normal, [&p](beast::error_code ec) mutable { + if (!ec) + p.set_value(); + else { + p.set_exception(std::make_exception_ptr(beast::system_error(ec))); + } + }); + }); + fut.get(); + + if (on_disconnect_cb) + std::exchange(on_disconnect_cb, nullptr)(handle_); + + return g_session_manager.Forget(handle_); + } + catch (beast::system_error const& se) { + fail(se.code(), L"disconnect"); + return false; } - - VERBOSE(L"Writing message: " << quoted(utf8_decode(_outbox.front()))); - ws_.async_write(net::buffer(_outbox.front()), - beast::bind_front_handler(&Session::on_write, shared_from_this())); } - void on_write(beast::error_code ec, std::size_t bytes_transferred) - { - VERBOSE(L"In on_write"); - boost::ignore_unused(bytes_transferred); - - if (ec) - return fail(ec, L"write"); - - _outbox.pop_front(); - do_write_loop(); // drain _outbox - } + private: // all private (do_*/on_*) assumed on strand void on_read(beast::error_code ec, std::size_t bytes_transferred) { VERBOSE(L"In on_read"); @@ -286,50 +335,37 @@ namespace /*anon*/ { const std::wstring wdata = utf8_decode(beast::buffers_to_string(buffer_.data())); VERBOSE(L"Received[" << bytes_transferred << L"] " << std::quoted(wdata)); - if (s_on_data_cb) - s_on_data_cb(wdata.c_str(), wdata.length()); + if (on_data_cb) + on_data_cb(handle_, wdata.c_str(), wdata.length()); buffer_.consume(bytes_transferred); // some forms of async_read can read extra data VERBOSE(L"Issue new async_read in on_read"); ws_.async_read(buffer_, beast::bind_front_handler(&Session::on_read, shared_from_this())); } - - /// Only called when client proactively closes connection by calling - /// websocket_disconnect. - /// \param ec instance of error code - void on_close(beast::error_code ec) - { - VERBOSE(L"In on_close"); - if (ec) - fail(ec, L"close"); - - if (s_on_disconnect_cb) - s_on_disconnect_cb(); - - get_lowest_layer(ws_).cancel(); // cause all async operations to abort - - if (!Manager::Clear(shared_from_this())) { - // CERR(L"Could not remove active session"); // redundant message when Sessions::Install fails - } - } }; } WSDLLAPI void enable_verbose(intptr_t enabled) { COUT(L"Verbose output " << (enabled ? L"enabled" : L"disabled")); - s_enable_verbose = enabled; + g_enable_verbose = enabled; } -WSDLLAPI size_t websocket_connect(wchar_t const* szServer) + +WSDLLAPI websocket_handle_t websocket_connect(wchar_t const* szServer, size_t dwOnFail, size_t dwOnDisconnect, size_t dwOnData) { - auto new_session = Manager::Install(std::make_shared()); - if (!new_session) { - COUT(L"A session is already active."); + auto session = std::make_shared(); + session->handle_ = g_session_manager.Register(session); + + session->on_fail_cb = reinterpret_cast(dwOnFail); + session->on_disconnect_cb = reinterpret_cast(dwOnDisconnect); + session->on_data_cb = reinterpret_cast(dwOnData); + + if (!g_session_manager.Active(session->handle_)) { + COUT(L"Session rejected"); // shouldn't happen currently return 0; } - assert(new_session == Manager::Active()); VERBOSE(L"Connecting to the server: " << szServer); @@ -345,64 +381,28 @@ WSDLLAPI size_t websocket_connect(wchar_t const* szServer) if (path.empty()) path = L"/"; - new_session->run(matches[1], matches[2], std::move(path)); - - return 1; + return session->run(matches[1], matches[2], std::move(path)) ? session->handle_ : Handle{}; } -WSDLLAPI size_t websocket_disconnect() +WSDLLAPI size_t websocket_disconnect(websocket_handle_t h) { - if (SessionPtr sess = Manager::Active()) { - sess->disconnect(); - return 1; + if (SessionPtr sess = g_session_manager.Active(h)) { + return sess->disconnect(); } CERR(L"Session not active. Can't disconnect."); return 0; } -WSDLLAPI size_t websocket_send(wchar_t const* szMessage, size_t dwLen, bool /*TODO: isBinary*/) +WSDLLAPI size_t websocket_send(websocket_handle_t h, wchar_t const* szMessage, size_t dwLen) { - if (SessionPtr sess = Manager::Active()) { - sess->send_message(std::wstring(szMessage, dwLen)); - return 1; + if (SessionPtr sess = g_session_manager.Active(h)) { + return sess->send_message(std::wstring(szMessage, dwLen)); } CERR(L"Session not active. Can't send data."); return 0; } -WSDLLAPI size_t websocket_isconnected() -{ - return Manager::Active() != nullptr; -} - -WSDLLAPI size_t websocket_register_on_connect_cb(size_t dwAddress) +WSDLLAPI size_t websocket_isconnected(websocket_handle_t h) { - VERBOSE(L"Registering on_connect callback"); - s_on_connect_cb = reinterpret_cast(dwAddress); - - return 1; -} - -WSDLLAPI size_t websocket_register_on_fail_cb(size_t dwAddress) -{ - VERBOSE(L"Registering on_fail callback"); - s_on_fail_cb = reinterpret_cast(dwAddress); - - return 1; -} - -WSDLLAPI size_t websocket_register_on_disconnect_cb(size_t dwAddress) -{ - VERBOSE(L"Registering on_disconnect callback"); - s_on_disconnect_cb = reinterpret_cast(dwAddress); - - return 1; -} - -WSDLLAPI size_t websocket_register_on_data_cb(size_t dwAddress) -{ - VERBOSE(L"Registering on_data callback"); - s_on_data_cb = reinterpret_cast(dwAddress); - - return 1; + return g_session_manager.Active(h) != nullptr; } diff --git a/library.h b/library.h index b4da75d..3fe1c43 100644 --- a/library.h +++ b/library.h @@ -24,21 +24,18 @@ #include extern "C" { - typedef void (*on_connect_t)(); - typedef void (*on_fail_t)(wchar_t const* from); - typedef void (*on_disconnect_t)(); - typedef void (*on_data_t)(wchar_t const*, size_t); + typedef intptr_t websocket_handle_t; + typedef void (*on_fail_t)(websocket_handle_t, wchar_t const* from); + typedef void (*on_disconnect_t)(websocket_handle_t); + typedef void (*on_data_t)(websocket_handle_t, wchar_t const*, size_t); - WSDLLAPI void enable_verbose(intptr_t enabled); - WSDLLAPI size_t websocket_connect(wchar_t const* szServer); - WSDLLAPI size_t websocket_disconnect(); - WSDLLAPI size_t websocket_send(wchar_t const* szMessage, size_t dwLen, bool isBinary); - WSDLLAPI size_t websocket_isconnected(); + WSDLLAPI websocket_handle_t websocket_connect(wchar_t const* szServer, size_t dwOnFail, + size_t dwOnDisconnect, size_t dwOnData); - WSDLLAPI size_t websocket_register_on_connect_cb(size_t dwAddress); - WSDLLAPI size_t websocket_register_on_fail_cb(size_t dwAddress); - WSDLLAPI size_t websocket_register_on_disconnect_cb(size_t dwAddress); - WSDLLAPI size_t websocket_register_on_data_cb(size_t dwAddress); + WSDLLAPI void enable_verbose(intptr_t enabled); + WSDLLAPI size_t websocket_disconnect(websocket_handle_t); + WSDLLAPI size_t websocket_send(websocket_handle_t, wchar_t const* szMessage, size_t dwLen); + WSDLLAPI size_t websocket_isconnected(websocket_handle_t); } #endif // WebSocketAsio_LIBRARY_H diff --git a/test/console.cpp b/test/console.cpp index 29a940e..14313fc 100644 --- a/test/console.cpp +++ b/test/console.cpp @@ -13,13 +13,6 @@ using namespace std::chrono_literals; namespace { // diagnostics tracing helpers using std::this_thread::sleep_for; - static auto timestamp() { - auto now = std::chrono::high_resolution_clock::now; - static auto start = now(); - auto t = now(); - - return (t - start) / 1.ms; - } static std::atomic_int tid_gen{0}; thread_local int tid = tid_gen++; @@ -28,7 +21,7 @@ namespace { // diagnostics tracing helpers template void trace(Args const&... args) { std::lock_guard lk(console_mx); - std::wcout << L"\tThread:" << std::setw(2) << tid << std::right << std::setw(10) << timestamp() << L"ms "; + std::wcout << L"\tThread:" << std::setw(2) << tid << L" "; (std::wcout << ... << args) << std::endl; } @@ -51,52 +44,41 @@ int main() { std::wcout << std::fixed << std::setprecision(2); - struct { - on_fail_t on_fail; // allow type safe assignments - on_connect_t on_connect; - on_disconnect_t on_disconnect; - on_data_t on_data; - - void do_register() { - ::websocket_register_on_fail_cb(unsafe_cb(on_fail)); - ::websocket_register_on_data_cb(unsafe_cb(on_data)); - ::websocket_register_on_connect_cb(unsafe_cb(on_connect)); - ::websocket_register_on_disconnect_cb(unsafe_cb(on_disconnect)); - }; - - } callbacks{ - [](wchar_t const* wsz) { trace(L"ON_FAIL: ", std::quoted(wsz)); } , - []() { trace(L"ON_CONNECT"); } , - []() { trace(L"ON_DISCONNECT"); } , - [](wchar_t const* wsz, size_t n) { - trace(L"ON_DATA: ", std::quoted(std::wstring_view(wsz).substr(0, n))); - }, - }; + on_fail_t on_fail = [](websocket_handle_t h, wchar_t const* wsz) { trace(L"ON_FAIL handle#", h, ": ", std::quoted(wsz)); }; + on_disconnect_t on_disconnect = [](websocket_handle_t h) { trace(L"ON_DISCONNECT handle#", h); }; + on_data_t on_data = [](websocket_handle_t h, wchar_t const* wsz, size_t n) { trace(L"ON_DATA handle#", h, ": ", std::quoted(std::wstring_view(wsz).substr(0, n))); }; ::enable_verbose(1); - callbacks.do_register(); - for (auto delay : {0ms, 200ms}) { - std::wcout << L"\n=========================================== Start (with " << delay / 1.s - << L"s delay) ======\n" << std::endl; + auto url = L"ws://localhost:8080/something"; + + { + std::wcout << L"\n======================= First ==============\n" << std::endl; - TRACED(::websocket_isconnected()); + websocket_handle_t h {}; + TRACED(::websocket_isconnected(h)); + h = TRACED(::websocket_connect(url, unsafe_cb(on_fail), unsafe_cb(on_disconnect), unsafe_cb(on_data))); + TRACED(::websocket_isconnected(h)); - TRACED(::websocket_connect(L"ws://localhost:8080/something")); - TRACED(::websocket_isconnected()); + TRACED(::websocket_send(h, L"First message\n", 14)); - if (delay > 0s) - sleep_for(delay); + TRACED(::websocket_disconnect(h)); + TRACED(::websocket_isconnected(h)); + } + + { + std::wcout << L"\n======================= Second ==============\n" << std::endl; - TRACED(::websocket_send(L"First message\n", 14, false)); - if (delay > 0s) - sleep_for(delay); + websocket_handle_t h {}; + TRACED(::websocket_isconnected(h)); + h = TRACED(::websocket_connect(url, unsafe_cb(on_fail), unsafe_cb(on_disconnect), unsafe_cb(on_data))); + TRACED(::websocket_isconnected(h)); - TRACED(::websocket_disconnect()); - TRACED(::websocket_isconnected()); - sleep_for(100ms); - TRACED(::websocket_isconnected()); + sleep_for(2s); + TRACED(::websocket_send(h, L"Second message\n", 15)); - sleep_for(1s); + sleep_for(2s); + TRACED(::websocket_disconnect(h)); + TRACED(::websocket_isconnected(h)); } }