diff --git a/README.md b/README.md index e6e0bc2..410dac2 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Supports: * WebSockets * HTTP fetcher * Modbus [TCP] -* MQTT [TCP] +* MQTT [TCP, WebSockets] SimpleSocket supports TLS via openssl. diff --git a/include/simple_socket/mqtt/MQTTBroker.hpp b/include/simple_socket/mqtt/MQTTBroker.hpp index f165929..409da7d 100644 --- a/include/simple_socket/mqtt/MQTTBroker.hpp +++ b/include/simple_socket/mqtt/MQTTBroker.hpp @@ -8,8 +8,11 @@ namespace simple_socket { class MQTTBroker { public: - explicit MQTTBroker(int port); + explicit MQTTBroker(int tcpPort); +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + MQTTBroker(int tcpPort, int wsPort); +#endif void start(); void stop(); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8c87bb4..e91b5ab 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -117,6 +117,9 @@ if (SIMPLE_SOCKET_WITH_TLS) target_compile_definitions(simple_socket PRIVATE SIMPLE_SOCKET_WITH_TLS=1) target_link_libraries(simple_socket PRIVATE OpenSSL::SSL OpenSSL::Crypto) endif () +if (SIMPLE_SOCKET_WITH_WEBSOCKETS) + target_compile_definitions(simple_socket PUBLIC SIMPLE_SOCKET_WITH_WEBSOCKETS=1) +endif () target_include_directories(simple_socket PUBLIC diff --git a/src/simple_socket/mqtt/MQTTBroker.cpp b/src/simple_socket/mqtt/MQTTBroker.cpp index 4801489..3e08716 100644 --- a/src/simple_socket/mqtt/MQTTBroker.cpp +++ b/src/simple_socket/mqtt/MQTTBroker.cpp @@ -4,8 +4,14 @@ #include "simple_socket/TCPSocket.hpp" #include "simple_socket/mqtt/mqtt_common.hpp" +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS +#include "simple_socket/ws/WebSocket.hpp" +#include "simple_socket/mqtt/WsMqttWrapper.hpp" +#endif + #include #include +#include #include #include #include @@ -20,14 +26,26 @@ struct MQTTBroker::Impl { explicit Impl(int port) : server_(port), stop_(false) {} +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + explicit Impl(int port, int wsPort) + : server_(port), ws_(wsPort), stop_(false) {} +#endif + void start() { listener_ = std::thread([this] { acceptLoop(); }); +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + if (ws_) wsListener_ = std::thread([this] { wsAcceptLoop(); }); +#endif } void stop() { stop_ = true; server_.close(); if (listener_.joinable()) listener_.join(); +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + if (ws_) ws_->stop(); + if (wsListener_.joinable()) wsListener_.join(); +#endif } private: @@ -37,14 +55,58 @@ struct MQTTBroker::Impl { std::unordered_set topics; }; +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + std::optional ws_; +#endif + TCPServer server_; + std::atomic stop_; std::thread listener_; + std::thread wsListener_; std::mutex subsMutex_; std::unordered_map> subscribers_; std::vector clients_; +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS + void wsAcceptLoop() { + + std::unordered_map connections; + + ws_->onOpen = [this, &connections](WebSocketConnection* conn) { + auto client = std::make_unique(); + Client* clientPtr = client.get(); + clients_.push_back(clientPtr); + auto wrapper = std::make_unique(conn); + connections[conn] = wrapper.get(); + client->conn = std::move(wrapper); + + std::thread(&Impl::handleClient, this, std::move(client)).detach(); + }; + ws_->onMessage = [&connections](WebSocketConnection* conn, const std::string& msg) { + connections[conn]->push_back(msg); + }; + ws_->onClose = [&connections](WebSocketConnection* conn) { + connections[conn]->close(); + }; + ws_->start(); + + while (!stop_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // remove closed connections + for (auto it = connections.begin(); it != connections.end();) { + if (it->second->closed()) { + it = connections.erase(it); + } else { + ++it; + } + } + } + ws_->stop(); + } +#endif void acceptLoop() { @@ -92,144 +154,10 @@ struct MQTTBroker::Impl { // CONNACK const std::vector connack = {CONNACK, 0x02, 0x00, 0x00}; - c->conn->write(connack); + if (!c->conn->write(connack)) return; // Main loop - bool running = true; - while (running && !stop_) { - uint8_t hdr = 0; - if (!c->conn->readExact(&hdr, 1)) break; - size_t rem = decodeRemainingLength(c->conn.get()); - std::vector buf(rem); - if (rem > 0 && !c->conn->readExact(buf.data(), rem)) break; - - const uint8_t typeNibble = static_cast(hdr & 0xF0); - const uint8_t flagsNibble = static_cast(hdr & 0x0F); - - switch (typeNibble) { - case (PUBLISH & 0xF0): { - size_t p = 0; - - // Topic length - if (buf.size() < 2) return; - uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; - p += 2; - if (p + tlen > buf.size()) return; - std::string topic(reinterpret_cast(&buf[p]), tlen); - p += tlen; - - // QoS from fixed header - const uint8_t qos = static_cast((hdr >> 1) & 0x03); - if (qos > 0) { - // Skip Packet Identifier for QoS1/2 - if (p + 2 > buf.size()) return; - p += 2; - // Optional: send PUBACK for QoS1, PUBREC/PUBREL/PUBCOMP for QoS2 - // (not implemented here) - } - - if (p > buf.size()) return; - std::string message(reinterpret_cast(&buf[p]), buf.size() - p); - - // Snapshot subscribers to avoid holding the lock during writes - std::vector targets; - { - std::lock_guard lock(subsMutex_); - auto it = subscribers_.find(topic); - if (it == subscribers_.end()) return; - targets = it->second; - } - - // Build QoS 0 PUBLISH to subscribers - auto packetTopic = encodeShortString(topic); - std::vector pl; - pl.reserve(packetTopic.size() + message.size()); - pl.insert(pl.end(), packetTopic.begin(), packetTopic.end()); - pl.insert(pl.end(), message.begin(), message.end()); - - std::vector packet; - packet.reserve(1 + 4 + pl.size()); - packet.push_back(PUBLISH);// 0x30 (QoS0) - auto len = encodeRemainingLength(pl.size()); - packet.insert(packet.end(), len.begin(), len.end()); - packet.insert(packet.end(), pl.begin(), pl.end()); - - for (auto* sub : targets) { - // std::cout << "MQTTBroker: delivering message on topic '" << topic << "' to client '" << sub->clientId << "'" << std::endl; - sub->conn->write(packet); - } - } break; - - case (SUBSCRIBE & 0xF0): { - // Must have flags 0x02 per spec - if (flagsNibble != 0x02) break; - size_t p = 0; - if (buf.size() < 5) break;// pid(2) + topic len(2) + qos(1) - uint16_t pid = (static_cast(buf[p]) << 8) | buf[p + 1]; - p += 2; - - uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; - p += 2; - if (p + tlen + 1 > buf.size()) break; - std::string topic(reinterpret_cast(&buf[p]), tlen); - p += tlen; - /*uint8_t reqQos =*/(void) buf[p++];// ignore, grant QoS 0 - - { - std::lock_guard lock(subsMutex_); - subscribers_[topic].push_back(c.get()); - c->topics.insert(topic); - } - - // SUBACK echoing Packet Identifier, grant QoS 0 - const std::vector suback = { - SUBACK, 0x03, - static_cast(pid >> 8), static_cast(pid & 0xFF), - 0x00}; - c->conn->write(suback); - } break; - - case (UNSUBSCRIBE & 0xF0): { - if (flagsNibble != 0x02) break; - size_t p = 0; - if (buf.size() < 4) break;// pid(2) + topic len(2) - uint16_t pid = (static_cast(buf[p]) << 8) | buf[p + 1]; - p += 2; - - uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; - p += 2; - if (p + tlen > buf.size()) break; - std::string topic(reinterpret_cast(&buf[p]), tlen); - - { - std::lock_guard lock(subsMutex_); - c->topics.erase(topic); - auto& vec = subscribers_[topic]; - std::erase(vec, c.get()); - } - - // UNSUBACK echoing Packet Identifier - const std::vector unsuback = { - UNSUBACK, 0x02, - static_cast(pid >> 8), static_cast(pid & 0xFF)}; - c->conn->write(unsuback); - } break; - - case PINGREQ: { - if (flagsNibble != 0x00) break; - const std::vector pingresp = {PINGRESP, 0x00}; - c->conn->write(pingresp); - } break; - - case DISCONNECT: - running = false; - break; - - default: - // ignore unsupported packets - break; - } - } + clientLoop(c.get()); cleanupClient(*c); } catch (const std::exception& e) { @@ -238,6 +166,169 @@ struct MQTTBroker::Impl { } } + void handlePublish(Client* c, uint8_t hdr, std::vector& buf) { + size_t p = 0; + + // Topic length + if (buf.size() < 2) return; + uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; + p += 2; + if (p + tlen > buf.size()) return; + std::string topic(reinterpret_cast(&buf[p]), tlen); + p += tlen; + + // QoS from fixed header + const auto qos = static_cast((hdr >> 1) & 0x03); + if (qos > 0) { + // Skip Packet Identifier for QoS1/2 + if (p + 2 > buf.size()) return; + p += 2; + // Optional: send PUBACK for QoS1, PUBREC/PUBREL/PUBCOMP for QoS2 + // (not implemented here) + } + + if (p > buf.size()) return; + std::string message(reinterpret_cast(&buf[p]), buf.size() - p); + + { + std::lock_guard lock(subsMutex_); + if (subscribers_.empty() || !subscribers_.contains(topic)) return; + } + + // Build QoS 0 PUBLISH to subscribers + auto packetTopic = encodeShortString(topic); + std::vector pl; + pl.reserve(packetTopic.size() + message.size()); + pl.insert(pl.end(), packetTopic.begin(), packetTopic.end()); + pl.insert(pl.end(), message.begin(), message.end()); + + std::vector packet; + packet.reserve(1 + 4 + pl.size()); + packet.push_back(PUBLISH);// 0x30 (QoS0) + auto len = encodeRemainingLength(pl.size()); + packet.insert(packet.end(), len.begin(), len.end()); + packet.insert(packet.end(), pl.begin(), pl.end()); + + std::lock_guard lock(subsMutex_); + for (auto it = subscribers_.begin(); it != subscribers_.end();) { + if (it->first == topic) { + auto& subs = it->second; + bool erased = false; + for (auto* sub : subs) { + if (!sub->conn->write(packet)) { + it = subscribers_.erase(it); + erased = true; + break;// exit inner loop + } + } + if (!erased) ++it; + } else { + ++it; + } + } + } + + void handleSubscriber(Client* c, uint8_t hdr, std::vector& buf, bool& running) { + + + size_t p = 0; + if (buf.size() < 5) return;// pid(2) + topic len(2) + qos(1) + uint16_t pid = (static_cast(buf[p]) << 8) | buf[p + 1]; + p += 2; + + uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; + p += 2; + if (p + tlen + 1 > buf.size()) return; + std::string topic(reinterpret_cast(&buf[p]), tlen); + p += tlen; + /*uint8_t reqQos =*/(void) buf[p++];// ignore, grant QoS 0 + + { + std::lock_guard lock(subsMutex_); + subscribers_[topic].push_back(c); + c->topics.insert(topic); + } + + // SUBACK echoing Packet Identifier, grant QoS 0 + const std::vector suback = { + SUBACK, 0x03, + static_cast(pid >> 8), static_cast(pid & 0xFF), + 0x00}; + if (!c->conn->write(suback)) { + running = false; + } + } + + void clientLoop(Client* c) { + bool running = true; + while (running && !stop_) { + uint8_t hdr = 0; + if (!c->conn->readExact(&hdr, 1)) break; + size_t rem = decodeRemainingLength(c->conn.get()); + std::vector buf(rem); + if (rem > 0 && !c->conn->readExact(buf.data(), rem)) break; + + const auto typeNibble = static_cast(hdr & 0xF0); + const auto flagsNibble = static_cast(hdr & 0x0F); + + switch (typeNibble) { + case (PUBLISH & 0xF0): { + handlePublish(c, hdr, buf); + } break; + + case (SUBSCRIBE & 0xF0): { + // Must have flags 0x02 per spec + if (flagsNibble != 0x02) break; + handleSubscriber(c, hdr, buf, running); + } break; + + case (UNSUBSCRIBE & 0xF0): { + if (flagsNibble != 0x02) break; + size_t p = 0; + if (buf.size() < 4) break;// pid(2) + topic len(2) + uint16_t pid = (static_cast(buf[p]) << 8) | buf[p + 1]; + p += 2; + + uint16_t tlen = (static_cast(buf[p]) << 8) | buf[p + 1]; + p += 2; + if (p + tlen > buf.size()) break; + std::string topic(reinterpret_cast(&buf[p]), tlen); + + { + std::lock_guard lock(subsMutex_); + c->topics.erase(topic); + auto& vec = subscribers_[topic]; + std::erase(vec, c); + } + + // UNSUBACK echoing Packet Identifier + const std::vector unsuback = { + UNSUBACK, 0x02, + static_cast(pid >> 8), static_cast(pid & 0xFF)}; + if (!c->conn->write(unsuback)) { + running = false; + } + } break; + + case PINGREQ: { + if (flagsNibble != 0x00) break; + const std::vector pingresp = {PINGRESP, 0x00}; + if (!c->conn->write(pingresp)) { + running = false; + } + } break; + + case DISCONNECT: + running = false; + break; + + default: + // ignore unsupported packets + break; + } + } + } + void cleanupClient(Client& c) { std::lock_guard lock(subsMutex_); for (auto& topic : c.topics) { @@ -249,8 +340,13 @@ struct MQTTBroker::Impl { }; -MQTTBroker::MQTTBroker(int port) - : pimpl_(std::make_unique(port)) {} +MQTTBroker::MQTTBroker(int tcpPort) + : pimpl_(std::make_unique(tcpPort)) {} + +#ifdef SIMPLE_SOCKET_WITH_WEBSOCKETS +MQTTBroker::MQTTBroker(int tcpPort, int wsPort) + : pimpl_(std::make_unique(tcpPort, wsPort)) {} +#endif void MQTTBroker::start() { pimpl_->start(); diff --git a/src/simple_socket/mqtt/WsMqttWrapper.hpp b/src/simple_socket/mqtt/WsMqttWrapper.hpp new file mode 100644 index 0000000..95fc6ea --- /dev/null +++ b/src/simple_socket/mqtt/WsMqttWrapper.hpp @@ -0,0 +1,78 @@ + +#ifndef SIMPLE_SOCKET_WSMQTTWRAPPER_HPP +#define SIMPLE_SOCKET_WSMQTTWRAPPER_HPP + + +#include "simple_socket/SimpleConnection.hpp" + +#include "simple_socket/ws/WebSocket.hpp" + +#include +#include +#include +#include +#include + +namespace simple_socket { + class WebSocketConnection; + + struct WsWrapper: SimpleConnection { + + explicit WsWrapper(WebSocketConnection* c): connection(c) {} + + int read(uint8_t* buffer, size_t size) override { + std::unique_lock lock(m_); + cv_.wait(lock, [&] { return closed_ || !queue_.empty(); }); + if (queue_.empty()) return -1;// closed and no data + + std::string msg = std::move(queue_.front()); + queue_.pop_front(); + + size_t toCopy = std::min(size, msg.size()); + std::memcpy(buffer, msg.data(), toCopy); + + if (toCopy < msg.size()) { + // put the remainder back to the front so next read continues it + queue_.push_front(msg.substr(toCopy)); + } + + return static_cast(toCopy); + } + + bool write(const uint8_t* data, size_t size) override { + if (closed_) return false; + return connection->send(data, size); + } + + void close() override { + { + std::lock_guard lock(m_); + closed_ = true; + } + cv_.notify_all(); + } + + void push_back(const std::string& msg) { + { + std::lock_guard lock(m_); + queue_.push_back(msg); + } + cv_.notify_one(); + } + + [[nodiscard]] bool closed() const { + return closed_; + } + + private: + std::atomic_bool closed_{false}; + std::deque queue_; + std::mutex m_; + std::condition_variable cv_; + WebSocketConnection* connection; + }; + +}// namespace simple_socket + + +#endif//SIMPLE_SOCKET_WSMQTTWRAPPER_HPP diff --git a/src/simple_socket/ws/WebSocket.cpp b/src/simple_socket/ws/WebSocket.cpp index e75caf6..4db8a4d 100644 --- a/src/simple_socket/ws/WebSocket.cpp +++ b/src/simple_socket/ws/WebSocket.cpp @@ -2,7 +2,6 @@ #include "simple_socket/ws/WebSocket.hpp" #include "simple_socket/TCPSocket.hpp" -#include "simple_socket/socket_common.hpp" #include "simple_socket/ws/WebSocketHandshakeKeyGen.hpp" #include "simple_socket/util/uuid.hpp" @@ -36,8 +35,13 @@ namespace { std::ostringstream response; response << "HTTP/1.1 101 Switching Protocols\r\n" << "Upgrade: websocket\r\n" - << "Connection: Upgrade\r\n" - << "Sec-WebSocket-Accept: " << secWebSocketAccept << "\r\n\r\n"; + << "Connection: Upgrade\r\n"; + if (auto* val = http.get("sec-websocket-protocol")) { + if (val != nullptr && toLower(*val).find("mqtt") != std::string::npos) { + response << "Sec-WebSocket-Protocol: mqtt\r\n"; + } + } + response << "Sec-WebSocket-Accept: " << secWebSocketAccept << "\r\n\r\n"; const std::string responseStr = response.str(); if (!conn.write(responseStr)) { @@ -59,8 +63,11 @@ struct WebSocket::Impl { try { WebSocketCallbacks callbacks{scope->onOpen, scope->onClose, scope->onMessage}; - auto ws = std::make_unique(callbacks, socket.accept(), WebSocketConnectionImpl::Role::Server); - ws->run(handshake); + auto conn = socket.accept(); + handshake(*conn); + auto ws = std::make_unique(callbacks, std::move(conn), WebSocketConnectionImpl::Role::Server); + + ws->run(); connections.emplace_back(std::move(ws)); } catch (std::exception&) { // std::cerr << ex.what() << std::endl; diff --git a/src/simple_socket/ws/WebSocketClient.cpp b/src/simple_socket/ws/WebSocketClient.cpp index 9565e5c..69eea24 100644 --- a/src/simple_socket/ws/WebSocketClient.cpp +++ b/src/simple_socket/ws/WebSocketClient.cpp @@ -148,13 +148,13 @@ struct WebSocketClient::Impl { const auto [host, port] = parseWebSocketURL(url); auto c = ctx_.connect(host, port, useTLS); + performHandshake(*c, url, host, port); WebSocketCallbacks callbacks{scope_->onOpen, scope_->onClose, scope_->onMessage}; conn = std::make_unique(callbacks, std::move(c), WebSocketConnectionImpl::Role::Client); conn->setBufferSize(bufferSize); - conn->run([url, host, port](SimpleConnection& conn) { - performHandshake(conn, url, host, port); - }); + + conn->run(); } bool send(const std::string& message) { diff --git a/src/simple_socket/ws/WebSocketConnection.hpp b/src/simple_socket/ws/WebSocketConnection.hpp index d70410d..729425a 100644 --- a/src/simple_socket/ws/WebSocketConnection.hpp +++ b/src/simple_socket/ws/WebSocketConnection.hpp @@ -54,9 +54,8 @@ namespace simple_socket { buffer.resize(size); } - void run(const std::function& handshake) { + void run() { - handshake(*conn_); if (callbacks_.onOpen) { callbacks_.onOpen(this); } @@ -65,7 +64,6 @@ namespace simple_socket { listen(); }); } - bool send(const std::string& message) override { const auto frame = buildText(message, role_); std::lock_guard lg(tx_mtx_); @@ -115,7 +113,6 @@ namespace simple_socket { std::mutex tx_mtx_;// serialize writes only std::atomic_bool closed_{false}; - WebSocket* socket_{}; std::unique_ptr conn_; WebSocketCallbacks callbacks_; std::thread thread_; @@ -179,6 +176,7 @@ namespace simple_socket { } void listen() { + std::vector rx; // accumulated bytes from socket std::vector message;// assembling fragmented messages bool continued = false; diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 5d71ba3..e342a5b 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -37,5 +37,7 @@ if (SIMPLE_SOCKET_WITH_MQTT) add_executable(run_mqtt_client_broker run_mqtt_client_broker.cpp) target_link_libraries(run_mqtt_client_broker PRIVATE simple_socket) + + file(COPY mqtt_client.html DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) endif () diff --git a/tests/integration/mqtt_client.html b/tests/integration/mqtt_client.html new file mode 100644 index 0000000..2721e2d --- /dev/null +++ b/tests/integration/mqtt_client.html @@ -0,0 +1,54 @@ + + + + + MQTT client + + + +
+ +

+
+ + + + + + + + diff --git a/tests/integration/run_mqtt_client_broker.cpp b/tests/integration/run_mqtt_client_broker.cpp index bc88ed8..baf84e7 100644 --- a/tests/integration/run_mqtt_client_broker.cpp +++ b/tests/integration/run_mqtt_client_broker.cpp @@ -10,13 +10,14 @@ using namespace simple_socket; int main() { - int port = 1883; + int tcpPort = 1883; + int wsPort = tcpPort+1; - MQTTBroker broker(port); + MQTTBroker broker(tcpPort, wsPort); broker.start(); try { - MQTTClient client("127.0.0.1", port, "SimpleSocketClient"); + MQTTClient client("127.0.0.1", tcpPort, "SimpleSocketClient"); client.connect(false); std::string topic1 = "simple_socket/topic1"; @@ -30,14 +31,18 @@ int main() { std::cout << "[" << topic2 << "] Got: " << msg << std::endl; }); - std::atomic_bool stop = false; - std::thread([&client, &stop, topic1, topic2] { + client.subscribe("simple_socket/slider", [](const auto& msg) { + std::cout << "[simple_socket/slider] Got: " << msg << std::endl; + }); + + std::atomic_bool stop; + auto clientThread = std::thread([&client, topic1, topic2, &stop] { while (!stop) { std::this_thread::sleep_for(std::chrono::seconds(1)); client.publish(topic1, "Hello from SimpleSocket MQTT!"); client.publish(topic2, "Another hello from SimpleSocket MQTT!"); } - }).detach(); + }); client.run(); @@ -46,9 +51,16 @@ int main() { client.unsubscribe(topic2); }).detach(); +#ifdef _WIN32 + system("start mqtt_client.html"); +#endif + std::cout << "Press any key to exit..." << std::endl; std::cin.get(); stop = true; + if (clientThread.joinable()) { + clientThread.join(); + } client.close(); broker.stop();