diff --git a/lib/inc/drogon/WebSocketController.h b/lib/inc/drogon/WebSocketController.h index 70fb77f5b4..003d800a5b 100644 --- a/lib/inc/drogon/WebSocketController.h +++ b/lib/inc/drogon/WebSocketController.h @@ -23,6 +23,9 @@ #include #include #include +#ifdef __cpp_impl_coroutine +#include +#endif #define WS_PATH_LIST_BEGIN \ static void initPathRouting() \ @@ -31,6 +34,11 @@ #define WS_ADD_PATH_VIA_REGEX(regExp, ...) \ registerSelfRegex__(regExp, {__VA_ARGS__}) #define WS_PATH_LIST_END } +#define WS_CORO_PATH_LIST_BEGIN WS_PATH_LIST_BEGIN +#define WS_CORO_PATH_ADD(path, ...) WS_PATH_ADD(path, __VA_ARGS__) +#define WS_CORO_ADD_PATH_VIA_REGEX(regExp, ...) \ + WS_ADD_PATH_VIA_REGEX(regExp, __VA_ARGS__) +#define WS_CORO_PATH_LIST_END WS_PATH_LIST_END namespace drogon { @@ -135,4 +143,129 @@ template typename WebSocketController::pathRegistrator WebSocketController::registrator_; +#ifdef __cpp_impl_coroutine +/** + * @brief The abstract base class for coroutine WebSocket controllers. + */ +class WebSocketCoroControllerBase : public WebSocketControllerBase +{ + public: + ~WebSocketCoroControllerBase() override = default; + + virtual Task<> handleNewMessageCoro(const WebSocketConnectionPtr &, + std::string &&, + const WebSocketMessageType &) = 0; + + virtual Task<> handleNewConnectionCoro(const HttpRequestPtr &, + const WebSocketConnectionPtr &) = 0; + + virtual Task<> handleConnectionClosedCoro( + const WebSocketConnectionPtr &) = 0; +}; + +/** + * @brief Reflection base class template for coroutine WebSocket + * controllers. + */ +template +class WebSocketCoroController : public DrObject, + public WebSocketCoroControllerBase +{ + public: + static const bool isAutoCreation = AutoCreation; + + virtual ~WebSocketCoroController() + { + } + + void handleNewMessage(const WebSocketConnectionPtr &conn, + std::string &&message, + const WebSocketMessageType &type) final + { + auto objPtr = DrClassMap::getSingleInstance(); + drogon::async_run([objPtr, + conn, + message = std::move(message), + type]() mutable -> Task<> { + co_await objPtr->handleNewMessageCoro(conn, + std::move(message), + type); + }); + } + + void handleNewConnection(const HttpRequestPtr &req, + const WebSocketConnectionPtr &conn) final + { + auto objPtr = DrClassMap::getSingleInstance(); + drogon::async_run([objPtr, req, conn]() mutable -> Task<> { + co_await objPtr->handleNewConnectionCoro(req, conn); + }); + } + + void handleConnectionClosed(const WebSocketConnectionPtr &conn) final + { + auto objPtr = DrClassMap::getSingleInstance(); + drogon::async_run([objPtr, conn]() mutable -> Task<> { + co_await objPtr->handleConnectionClosedCoro(conn); + }); + } + + protected: + WebSocketCoroController() + { + } + + static void registerSelf__( + const std::string &path, + const std::vector &constraints) + { + LOG_TRACE << "register websocket coro controller(" + << WebSocketCoroController::classTypeName() + << ") on path:" << path; + app().registerWebSocketController( + path, + WebSocketCoroController::classTypeName(), + constraints); + } + + static void registerSelfRegex__( + const std::string ®Exp, + const std::vector &constraints) + { + LOG_TRACE << "register websocket coro controller(" + << WebSocketCoroController::classTypeName() + << ") on regExp:" << regExp; + app().registerWebSocketControllerRegex( + regExp, + WebSocketCoroController::classTypeName(), + constraints); + } + + private: + class pathRegistrator + { + public: + pathRegistrator() + { + if (AutoCreation) + { + T::initPathRouting(); + } + } + }; + + friend pathRegistrator; + static pathRegistrator registrator_; + + virtual void *touch() + { + return ®istrator_; + } +}; + +template +typename WebSocketCoroController::pathRegistrator + WebSocketCoroController::registrator_; +#endif + } // namespace drogon diff --git a/lib/src/HttpServer.cc b/lib/src/HttpServer.cc index f6e4bff1be..1daef46809 100644 --- a/lib/src/HttpServer.cc +++ b/lib/src/HttpServer.cc @@ -758,16 +758,20 @@ void HttpServer::websocketRequestHandling( std::function &&callback, WebSocketConnectionImplPtr &&wsConnPtr) { - binderPtr->handleRequest( - req, - [req, callback = std::move(callback)](const HttpResponsePtr &resp) { - AopAdvice::instance().passPostHandlingAdvices(req, resp); + auto request = req; + auto binder = std::move(binderPtr); + auto wsConn = std::move(wsConnPtr); + + binder->handleRequest( + request, + [request, callback = std::move(callback)](const HttpResponsePtr &resp) { + AopAdvice::instance().passPostHandlingAdvices(request, resp); callback(resp); }); // TODO: more elegant? - static_cast(binderPtr.get()) - ->handleNewConnection(req, wsConnPtr); + static_cast(binder.get()) + ->handleNewConnection(request, wsConn); } void HttpServer::handleResponse( diff --git a/lib/tests/CMakeLists.txt b/lib/tests/CMakeLists.txt index c8249b41ab..f8af556a88 100644 --- a/lib/tests/CMakeLists.txt +++ b/lib/tests/CMakeLists.txt @@ -85,7 +85,11 @@ if (BUILD_CTL) set(INTEGRATION_TEST_SERVER_SOURCES ${INTEGRATION_TEST_SERVER_SOURCES} integration_test/server/CoroFilter.cpp - integration_test/server/api_v1_CoroTest.cc) + integration_test/server/api_v1_CoroTest.cc + integration_test/server/WebSocketCoroTest.cc) + set(INTEGRATION_TEST_CLIENT_SOURCES + ${INTEGRATION_TEST_CLIENT_SOURCES} + integration_test/client/WebSocketCoroTest.cc) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED TRUE) endif(DROGON_CXX_STANDARD GREATER_EQUAL 20 AND HAS_COROUTINE) diff --git a/lib/tests/integration_test/client/WebSocketCoroTest.cc b/lib/tests/integration_test/client/WebSocketCoroTest.cc new file mode 100644 index 0000000000..f0ee9b4237 --- /dev/null +++ b/lib/tests/integration_test/client/WebSocketCoroTest.cc @@ -0,0 +1,93 @@ +#include +#include + +#include +#include +#include + +using namespace drogon; +using namespace std::chrono_literals; + +DROGON_TEST(WebSocketCoroControllerTest) +{ +#if defined(__cpp_impl_coroutine) + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + auto first = WebSocketClient::newWebSocketClient("127.0.0.1", 8848); + auto second = WebSocketClient::newWebSocketClient("127.0.0.1", 8848); + + first->setMessageHandler( + [first, second, promise, TEST_CTX](const std::string &message, + const WebSocketClientPtr &, + const WebSocketMessageType &type) { + CHECK(type == WebSocketMessageType::Text); + if (message == "opened") + { + first->getConnection()->send("hello-coro"); + } + else if (message == "coro:hello-coro") + { + first->getConnection()->shutdown(); + } + }); + + first->setConnectionClosedHandler([second, promise, TEST_CTX]( + const WebSocketClientPtr &) { + auto req2 = HttpRequest::newHttpRequest(); + req2->setPath("/coro-chat"); + + second->setMessageHandler( + [second, promise, TEST_CTX](const std::string &message, + const WebSocketClientPtr &, + const WebSocketMessageType &type) { + CHECK(type == WebSocketMessageType::Text); + if (message == "opened") + { + second->getConnection()->send("stats"); + return; + } + + if (message.rfind("stats:", 0) == 0) + { + int opened = 0; + int msg = 0; + int closed = 0; + auto parsed = sscanf(message.c_str(), + "stats:%d:%d:%d", + &opened, + &msg, + &closed); + REQUIRE(parsed == 3); + CHECK(opened >= 2); + CHECK(msg >= 2); + CHECK(closed >= 1); + second->stop(); + promise->set_value(); + } + }); + + second->connectToServer(req2, + [second, TEST_CTX](ReqResult r, + const HttpResponsePtr &resp, + const WebSocketClientPtr &) { + REQUIRE(r == ReqResult::Ok); + REQUIRE(resp != nullptr); + CHECK(second->getConnection()->connected()); + }); + }); + + auto req = HttpRequest::newHttpRequest(); + req->setPath("/coro-chat"); + first->connectToServer(req, + [first, TEST_CTX](ReqResult r, + const HttpResponsePtr &resp, + const WebSocketClientPtr &) { + REQUIRE(r == ReqResult::Ok); + REQUIRE(resp != nullptr); + CHECK(first->getConnection()->connected()); + }); + + REQUIRE(future.wait_for(5s) == std::future_status::ready); +#endif +} diff --git a/lib/tests/integration_test/server/WebSocketCoroTest.cc b/lib/tests/integration_test/server/WebSocketCoroTest.cc new file mode 100644 index 0000000000..48b9dcf835 --- /dev/null +++ b/lib/tests/integration_test/server/WebSocketCoroTest.cc @@ -0,0 +1,53 @@ +#include "WebSocketCoroTest.h" + +#include + +using namespace example; + +std::atomic WebSocketCoroTest::openedCount_{0}; +std::atomic WebSocketCoroTest::messageCount_{0}; +std::atomic WebSocketCoroTest::closedCount_{0}; + +drogon::Task<> WebSocketCoroTest::handleNewConnectionCoro( + const drogon::HttpRequestPtr &, + const drogon::WebSocketConnectionPtr &conn) +{ + ++openedCount_; + co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001); + conn->send("opened"); + co_return; +} + +drogon::Task<> WebSocketCoroTest::handleNewMessageCoro( + const drogon::WebSocketConnectionPtr &conn, + std::string &&message, + const drogon::WebSocketMessageType &type) +{ + if (type != drogon::WebSocketMessageType::Text) + { + co_return; + } + + ++messageCount_; + co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001); + + if (message == "stats") + { + conn->send("stats:" + std::to_string(openedCount_.load()) + ":" + + std::to_string(messageCount_.load()) + ":" + + std::to_string(closedCount_.load())); + } + else + { + conn->send("coro:" + message); + } + co_return; +} + +drogon::Task<> WebSocketCoroTest::handleConnectionClosedCoro( + const drogon::WebSocketConnectionPtr &) +{ + ++closedCount_; + co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001); + co_return; +} diff --git a/lib/tests/integration_test/server/WebSocketCoroTest.h b/lib/tests/integration_test/server/WebSocketCoroTest.h new file mode 100644 index 0000000000..5b376e3d41 --- /dev/null +++ b/lib/tests/integration_test/server/WebSocketCoroTest.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +namespace example +{ +class WebSocketCoroTest + : public drogon::WebSocketCoroController +{ + public: + drogon::Task<> handleNewMessageCoro( + const drogon::WebSocketConnectionPtr &, + std::string &&, + const drogon::WebSocketMessageType &) override; + drogon::Task<> handleConnectionClosedCoro( + const drogon::WebSocketConnectionPtr &) override; + drogon::Task<> handleNewConnectionCoro( + const drogon::HttpRequestPtr &, + const drogon::WebSocketConnectionPtr &) override; + + WS_CORO_PATH_LIST_BEGIN + WS_CORO_PATH_ADD("/coro-chat", drogon::Get); + WS_CORO_PATH_LIST_END + + private: + static std::atomic openedCount_; + static std::atomic messageCount_; + static std::atomic closedCount_; +}; +} // namespace example