Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions lib/inc/drogon/WebSocketController.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#include <memory>
#include <string>
#include <vector>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif

#define WS_PATH_LIST_BEGIN \
static void initPathRouting() \
Expand All @@ -31,6 +34,11 @@
#define WS_ADD_PATH_VIA_REGEX(regExp, ...) \
registerSelfRegex__(regExp, {__VA_ARGS__})
#define WS_PATH_LIST_END }
#define WS_CORO_PATH_LIST_BEGIN WS_PATH_LIST_BEGIN
#define WS_CORO_PATH_ADD(path, ...) WS_PATH_ADD(path, __VA_ARGS__)
#define WS_CORO_ADD_PATH_VIA_REGEX(regExp, ...) \
WS_ADD_PATH_VIA_REGEX(regExp, __VA_ARGS__)
#define WS_CORO_PATH_LIST_END WS_PATH_LIST_END

namespace drogon
{
Expand Down Expand Up @@ -135,4 +143,129 @@ template <typename T, bool AutoCreation>
typename WebSocketController<T, AutoCreation>::pathRegistrator
WebSocketController<T, AutoCreation>::registrator_;

#ifdef __cpp_impl_coroutine
/**
* @brief The abstract base class for coroutine WebSocket controllers.
*/
class WebSocketCoroControllerBase : public WebSocketControllerBase
{
public:
~WebSocketCoroControllerBase() override = default;

virtual Task<> handleNewMessageCoro(const WebSocketConnectionPtr &,
std::string &&,
const WebSocketMessageType &) = 0;

virtual Task<> handleNewConnectionCoro(const HttpRequestPtr &,
const WebSocketConnectionPtr &) = 0;

virtual Task<> handleConnectionClosedCoro(
const WebSocketConnectionPtr &) = 0;
};

/**
* @brief Reflection base class template for coroutine WebSocket
* controllers.
*/
template <typename T, bool AutoCreation = true>
class WebSocketCoroController : public DrObject<T>,
public WebSocketCoroControllerBase
{
public:
static const bool isAutoCreation = AutoCreation;

virtual ~WebSocketCoroController()
{
}

void handleNewMessage(const WebSocketConnectionPtr &conn,
std::string &&message,
const WebSocketMessageType &type) final
{
auto objPtr = DrClassMap::getSingleInstance<T>();
drogon::async_run([objPtr,
conn,
message = std::move(message),
type]() mutable -> Task<> {
co_await objPtr->handleNewMessageCoro(conn,
std::move(message),
type);
});
}

void handleNewConnection(const HttpRequestPtr &req,
const WebSocketConnectionPtr &conn) final
{
auto objPtr = DrClassMap::getSingleInstance<T>();
drogon::async_run([objPtr, req, conn]() mutable -> Task<> {
co_await objPtr->handleNewConnectionCoro(req, conn);
});
}

void handleConnectionClosed(const WebSocketConnectionPtr &conn) final
{
auto objPtr = DrClassMap::getSingleInstance<T>();
drogon::async_run([objPtr, conn]() mutable -> Task<> {
co_await objPtr->handleConnectionClosedCoro(conn);
});
}

protected:
WebSocketCoroController()
{
}

static void registerSelf__(
const std::string &path,
const std::vector<internal::HttpConstraint> &constraints)
{
LOG_TRACE << "register websocket coro controller("
<< WebSocketCoroController<T, AutoCreation>::classTypeName()
<< ") on path:" << path;
app().registerWebSocketController(
path,
WebSocketCoroController<T, AutoCreation>::classTypeName(),
constraints);
}

static void registerSelfRegex__(
const std::string &regExp,
const std::vector<internal::HttpConstraint> &constraints)
{
LOG_TRACE << "register websocket coro controller("
<< WebSocketCoroController<T, AutoCreation>::classTypeName()
<< ") on regExp:" << regExp;
app().registerWebSocketControllerRegex(
regExp,
WebSocketCoroController<T, AutoCreation>::classTypeName(),
constraints);
}

private:
class pathRegistrator
{
public:
pathRegistrator()
{
if (AutoCreation)
{
T::initPathRouting();
}
}
};

friend pathRegistrator;
static pathRegistrator registrator_;

virtual void *touch()
{
return &registrator_;
}
};

template <typename T, bool AutoCreation>
typename WebSocketCoroController<T, AutoCreation>::pathRegistrator
WebSocketCoroController<T, AutoCreation>::registrator_;
#endif

} // namespace drogon
16 changes: 10 additions & 6 deletions lib/src/HttpServer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -758,16 +758,20 @@ void HttpServer::websocketRequestHandling(
std::function<void(const HttpResponsePtr &)> &&callback,
WebSocketConnectionImplPtr &&wsConnPtr)
{
binderPtr->handleRequest(
req,
[req, callback = std::move(callback)](const HttpResponsePtr &resp) {
AopAdvice::instance().passPostHandlingAdvices(req, resp);
auto request = req;
auto binder = std::move(binderPtr);
auto wsConn = std::move(wsConnPtr);

binder->handleRequest(
request,
[request, callback = std::move(callback)](const HttpResponsePtr &resp) {
AopAdvice::instance().passPostHandlingAdvices(request, resp);
callback(resp);
});

// TODO: more elegant?
static_cast<WebsocketControllerBinder *>(binderPtr.get())
->handleNewConnection(req, wsConnPtr);
static_cast<WebsocketControllerBinder *>(binder.get())
->handleNewConnection(request, wsConn);
}

void HttpServer::handleResponse(
Expand Down
6 changes: 5 additions & 1 deletion lib/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ if (BUILD_CTL)
set(INTEGRATION_TEST_SERVER_SOURCES
${INTEGRATION_TEST_SERVER_SOURCES}
integration_test/server/CoroFilter.cpp
integration_test/server/api_v1_CoroTest.cc)
integration_test/server/api_v1_CoroTest.cc
integration_test/server/WebSocketCoroTest.cc)
set(INTEGRATION_TEST_CLIENT_SOURCES
${INTEGRATION_TEST_CLIENT_SOURCES}
integration_test/client/WebSocketCoroTest.cc)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
endif(DROGON_CXX_STANDARD GREATER_EQUAL 20 AND HAS_COROUTINE)
Expand Down
93 changes: 93 additions & 0 deletions lib/tests/integration_test/client/WebSocketCoroTest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include <drogon/WebSocketClient.h>
#include <drogon/drogon_test.h>

#include <chrono>
#include <cstdio>
#include <future>

using namespace drogon;
using namespace std::chrono_literals;

DROGON_TEST(WebSocketCoroControllerTest)
{
#if defined(__cpp_impl_coroutine)
auto promise = std::make_shared<std::promise<void>>();
auto future = promise->get_future();

auto first = WebSocketClient::newWebSocketClient("127.0.0.1", 8848);
auto second = WebSocketClient::newWebSocketClient("127.0.0.1", 8848);

first->setMessageHandler(
[first, second, promise, TEST_CTX](const std::string &message,
const WebSocketClientPtr &,
const WebSocketMessageType &type) {
CHECK(type == WebSocketMessageType::Text);
if (message == "opened")
{
first->getConnection()->send("hello-coro");
}
else if (message == "coro:hello-coro")
{
first->getConnection()->shutdown();
}
});

first->setConnectionClosedHandler([second, promise, TEST_CTX](
const WebSocketClientPtr &) {
auto req2 = HttpRequest::newHttpRequest();
req2->setPath("/coro-chat");

second->setMessageHandler(
[second, promise, TEST_CTX](const std::string &message,
const WebSocketClientPtr &,
const WebSocketMessageType &type) {
CHECK(type == WebSocketMessageType::Text);
if (message == "opened")
{
second->getConnection()->send("stats");
return;
}

if (message.rfind("stats:", 0) == 0)
{
int opened = 0;
int msg = 0;
int closed = 0;
auto parsed = sscanf(message.c_str(),
"stats:%d:%d:%d",
&opened,
&msg,
&closed);
REQUIRE(parsed == 3);
CHECK(opened >= 2);
CHECK(msg >= 2);
CHECK(closed >= 1);
second->stop();
promise->set_value();
}
});

second->connectToServer(req2,
[second, TEST_CTX](ReqResult r,
const HttpResponsePtr &resp,
const WebSocketClientPtr &) {
REQUIRE(r == ReqResult::Ok);
REQUIRE(resp != nullptr);
CHECK(second->getConnection()->connected());
});
});

auto req = HttpRequest::newHttpRequest();
req->setPath("/coro-chat");
first->connectToServer(req,
[first, TEST_CTX](ReqResult r,
const HttpResponsePtr &resp,
const WebSocketClientPtr &) {
REQUIRE(r == ReqResult::Ok);
REQUIRE(resp != nullptr);
CHECK(first->getConnection()->connected());
});

REQUIRE(future.wait_for(5s) == std::future_status::ready);
#endif
}
53 changes: 53 additions & 0 deletions lib/tests/integration_test/server/WebSocketCoroTest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "WebSocketCoroTest.h"

#include <drogon/drogon.h>

using namespace example;

std::atomic<int> WebSocketCoroTest::openedCount_{0};
std::atomic<int> WebSocketCoroTest::messageCount_{0};
std::atomic<int> WebSocketCoroTest::closedCount_{0};

drogon::Task<> WebSocketCoroTest::handleNewConnectionCoro(
const drogon::HttpRequestPtr &,
const drogon::WebSocketConnectionPtr &conn)
{
++openedCount_;
co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001);
conn->send("opened");
co_return;
}

drogon::Task<> WebSocketCoroTest::handleNewMessageCoro(
const drogon::WebSocketConnectionPtr &conn,
std::string &&message,
const drogon::WebSocketMessageType &type)
{
if (type != drogon::WebSocketMessageType::Text)
{
co_return;
}

++messageCount_;
co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001);

if (message == "stats")
{
conn->send("stats:" + std::to_string(openedCount_.load()) + ":" +
std::to_string(messageCount_.load()) + ":" +
std::to_string(closedCount_.load()));
}
else
{
conn->send("coro:" + message);
}
co_return;
}

drogon::Task<> WebSocketCoroTest::handleConnectionClosedCoro(
const drogon::WebSocketConnectionPtr &)
{
++closedCount_;
co_await drogon::sleepCoro(drogon::app().getLoop(), 0.001);
co_return;
}
31 changes: 31 additions & 0 deletions lib/tests/integration_test/server/WebSocketCoroTest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include <drogon/WebSocketController.h>
#include <atomic>

namespace example
{
class WebSocketCoroTest
: public drogon::WebSocketCoroController<WebSocketCoroTest>
{
public:
drogon::Task<> handleNewMessageCoro(
const drogon::WebSocketConnectionPtr &,
std::string &&,
const drogon::WebSocketMessageType &) override;
drogon::Task<> handleConnectionClosedCoro(
const drogon::WebSocketConnectionPtr &) override;
drogon::Task<> handleNewConnectionCoro(
const drogon::HttpRequestPtr &,
const drogon::WebSocketConnectionPtr &) override;

WS_CORO_PATH_LIST_BEGIN
WS_CORO_PATH_ADD("/coro-chat", drogon::Get);
WS_CORO_PATH_LIST_END

private:
static std::atomic<int> openedCount_;
static std::atomic<int> messageCount_;
static std::atomic<int> closedCount_;
};
} // namespace example
Loading