From 8992f3983d804e3a000e39036401392de48cbda1 Mon Sep 17 00:00:00 2001 From: sav-da Date: Thu, 26 Feb 2026 12:45:50 +0300 Subject: [PATCH 01/10] feat: improved RabbitMQ driver - added headers to ConsumedMessage - added headers to Envelop - added heartbeat to queue --- .../userver/urabbitmq/client_settings.hpp | 5 ++ .../include/userver/urabbitmq/typedefs.hpp | 7 +++ rabbitmq/src/urabbitmq/client_settings.cpp | 5 ++ rabbitmq/src/urabbitmq/component.yaml | 5 ++ rabbitmq/src/urabbitmq/connection.cpp | 11 +++- rabbitmq/src/urabbitmq/connection.hpp | 1 + rabbitmq/src/urabbitmq/connection_pool.cpp | 1 + rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 7 +++ rabbitmq/src/urabbitmq/impl/amqp_channel.cpp | 15 ++++- .../impl/amqp_connection_handler.cpp | 58 ++++++++++++++++++- .../impl/amqp_connection_handler.hpp | 10 ++++ 11 files changed, 120 insertions(+), 5 deletions(-) diff --git a/rabbitmq/include/userver/urabbitmq/client_settings.hpp b/rabbitmq/include/userver/urabbitmq/client_settings.hpp index 5a0122bb10c1..93683c5c4904 100644 --- a/rabbitmq/include/userver/urabbitmq/client_settings.hpp +++ b/rabbitmq/include/userver/urabbitmq/client_settings.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -74,6 +75,10 @@ struct PoolSettings final { /// (tcp error/protocol error/write timeout) leads to a errors burst: /// all outstanding request will fails at once size_t max_in_flight_requests = 5; + + /// Requested AMQP heartbeat interval in seconds. + /// Set to 0 to disable heartbeats. + size_t heartbeat_interval_seconds = 30; }; class TestsHelper; diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index 4714339b802d..b8d1efc26c0e 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -4,6 +4,7 @@ /// @brief Convenient typedefs for RabbitMQ entities. #include +#include #include @@ -67,6 +68,8 @@ enum class MessageType { /// metadata fields. This struct is used to pass messages to the end user, /// hiding the actual AMQP message object implementation. struct ConsumedMessage { + using Headers = std::unordered_map; + struct Metadata { std::string exchange; std::string routingKey; @@ -75,17 +78,21 @@ struct ConsumedMessage { Metadata metadata; std::optional reply_to{}; std::optional correlation_id{}; + Headers headers{}; }; /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages from the end user, /// hiding the actual AMQP message object implementation. struct Envelope { + using Headers = std::unordered_map; + std::string message; MessageType type; std::optional reply_to{}; std::optional correlation_id{}; std::optional expiration{}; + std::optional headers{}; }; } // namespace urabbitmq diff --git a/rabbitmq/src/urabbitmq/client_settings.cpp b/rabbitmq/src/urabbitmq/client_settings.cpp index 85d25ba5adf0..ba7682449243 100644 --- a/rabbitmq/src/urabbitmq/client_settings.cpp +++ b/rabbitmq/src/urabbitmq/client_settings.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -100,9 +101,13 @@ PoolSettings Parse(const yaml_config::YamlConfig& config, formats::parse::To(result.min_pool_size); result.max_pool_size = config["max_pool_size"].As(result.max_pool_size); result.max_in_flight_requests = config["max_in_flight_requests"].As(result.max_in_flight_requests); + result.heartbeat_interval_seconds = + config["heartbeat_interval_seconds"].As(result.heartbeat_interval_seconds); UINVARIANT(result.min_pool_size <= result.max_pool_size, "max_pool_size is less than min_pool_size"); UINVARIANT(result.max_pool_size > 0, "max_pool_size is set to zero"); + UINVARIANT(result.heartbeat_interval_seconds <= std::numeric_limits::max(), + "heartbeat_interval_seconds is too large"); return result; } diff --git a/rabbitmq/src/urabbitmq/component.yaml b/rabbitmq/src/urabbitmq/component.yaml index cdc35b89df77..757c5500a6c2 100644 --- a/rabbitmq/src/urabbitmq/component.yaml +++ b/rabbitmq/src/urabbitmq/component.yaml @@ -23,6 +23,11 @@ properties: description: | per-connection limit for requests awaiting response from the broker default: 5 + heartbeat_interval_seconds: + type: integer + description: | + requested AMQP heartbeat interval in seconds; 0 disables heartbeats + default: 30 use_secure_connection: type: boolean description: whether to use TLS for connections diff --git a/rabbitmq/src/urabbitmq/connection.cpp b/rabbitmq/src/urabbitmq/connection.cpp index f54e752b5bf4..4d8fcac026ed 100644 --- a/rabbitmq/src/urabbitmq/connection.cpp +++ b/rabbitmq/src/urabbitmq/connection.cpp @@ -9,11 +9,20 @@ Connection::Connection( const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline ) - : handler_{resolver, endpoint, auth_settings, secure, stats, deadline}, + : handler_{ + resolver, + endpoint, + auth_settings, + heartbeat_interval_seconds, + secure, + stats, + deadline, + }, connection_{handler_, max_in_flight_requests, deadline}, channel_{connection_}, reliable_channel_{connection_} diff --git a/rabbitmq/src/urabbitmq/connection.hpp b/rabbitmq/src/urabbitmq/connection.hpp index 372596ffadb9..a53796b948a6 100644 --- a/rabbitmq/src/urabbitmq/connection.hpp +++ b/rabbitmq/src/urabbitmq/connection.hpp @@ -29,6 +29,7 @@ class Connection final { const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline diff --git a/rabbitmq/src/urabbitmq/connection_pool.cpp b/rabbitmq/src/urabbitmq/connection_pool.cpp index e1a9baf3a245..febb66583cc4 100644 --- a/rabbitmq/src/urabbitmq/connection_pool.cpp +++ b/rabbitmq/src/urabbitmq/connection_pool.cpp @@ -92,6 +92,7 @@ ConnectionPool::ConnectionUniquePtr ConnectionPool::DoCreateConnection(engine::D endpoint_info_, auth_settings_, pool_settings_.max_in_flight_requests, + pool_settings_.heartbeat_interval_seconds, use_secure_connection_, stats_, deadline diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 3bb22d1ea41b..ef348902c4d3 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -1,5 +1,6 @@ #include "consumer_base_impl.hpp" +#include #include #include @@ -106,6 +107,12 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery if (message.hasCorrelationID()) { consumed.correlation_id = message.correlationID(); } + const auto& headers = message.headers(); + for (const auto& key : headers.keys()) { + std::ostringstream stream; + stream << headers.get(key); + consumed.headers.emplace(key, stream.str()); + } bts_.Detach(engine::AsyncNoSpan( dispatcher_, diff --git a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp index c374dc41bd5d..d6aa45d5b2ef 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp @@ -94,6 +94,17 @@ AMQP::Table CreateHeaders() { return headers; } +AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { + auto headers = CreateHeaders(); + if (envelope.headers.has_value()) { + for (const auto& [key, value] : envelope.headers.value()) { + headers[key] = value; + } + } + + return headers; +} + } // namespace AmqpChannel::AmqpChannel(AmqpConnection& conn) @@ -196,7 +207,7 @@ void AmqpChannel::Publish( ) { AMQP::Envelope native_envelope{envelope.message.data(), envelope.message.size()}; native_envelope.setPersistent(envelope.type == MessageType::kPersistent); - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); if (envelope.reply_to.has_value()) { native_envelope.setReplyTo(envelope.reply_to.value().c_str()); } @@ -285,7 +296,7 @@ ResponseAwaiter AmqpReliableChannel::Publish( if (envelope.expiration.has_value()) { native_envelope.setExpiration(std::to_string(envelope.expiration.value().count())); } - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); auto awaiter = conn_.GetAwaiter(deadline); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index e0857e5510da..68ddded55c44 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -1,5 +1,7 @@ #include "amqp_connection_handler.hpp" +#include +#include #include #include @@ -11,6 +13,7 @@ #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -18,6 +21,7 @@ USERVER_NAMESPACE_BEGIN namespace urabbitmq::impl { namespace { +constexpr std::chrono::milliseconds kHeartbeatSendTimeout{200}; engine::io::Socket CreateSocket(engine::io::Sockaddr& addr, engine::Deadline deadline) { engine::io::Socket socket{addr.Domain(), engine::io::SocketType::kTcp}; @@ -91,6 +95,7 @@ AmqpConnectionHandler::AmqpConnectionHandler( clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -98,10 +103,15 @@ AmqpConnectionHandler::AmqpConnectionHandler( : address_{ToAmqpAddress(endpoint, auth_settings, secure)}, socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)}, reader_{*this, *socket_}, + configured_heartbeat_seconds_{ + static_cast(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, stats_{stats} {} -AmqpConnectionHandler::~AmqpConnectionHandler() { reader_.Stop(); } +AmqpConnectionHandler::~AmqpConnectionHandler() { + heartbeat_task_.Stop(); + reader_.Stop(); +} void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, AMQP::Table& client) { client["product"] = "uServer AMQP library"; @@ -109,6 +119,18 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, client["information"] = "https://userver.tech/dd/de2/rabbitmq_driver.html"; } +uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) { + if (interval == 0 || configured_heartbeat_seconds_ == 0) { + negotiated_heartbeat_seconds_.store(0, std::memory_order_relaxed); + return 0; + } + + const auto negotiated = static_cast(std::min(interval, configured_heartbeat_seconds_)); + negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); + LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; + return negotiated; +} + void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, size_t size) { if (IsBroken()) { // No further actions can be done @@ -160,6 +182,7 @@ void AmqpConnectionHandler::onReady(AMQP::Connection*) { } void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engine::Deadline deadline) { + connection_ = connection; reader_.Start(connection); if (!connection_ready_event_.WaitForEventUntil(deadline)) { @@ -169,11 +192,22 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi if (error_.has_value()) { reader_.Stop(); + connection_ = nullptr; throw ConnectionSetupError{"Failed to setup a connection: " + *error_}; } + + const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed); + if (heartbeat_seconds > 0) { + const auto heartbeat_period = std::chrono::seconds{std::max(1, heartbeat_seconds / 2)}; + heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period}, [this] { SendHeartbeat(); }); + } } -void AmqpConnectionHandler::OnConnectionDestruction() { reader_.Stop(); } +void AmqpConnectionHandler::OnConnectionDestruction() { + heartbeat_task_.Stop(); + connection_ = nullptr; + reader_.Stop(); +} void AmqpConnectionHandler::Invalidate() { broken_ = true; } @@ -189,6 +223,26 @@ statistics::ConnectionStatistics& AmqpConnectionHandler::GetStatistics() { retur const AMQP::Address& AmqpConnectionHandler::GetAddress() const { return address_; } +void AmqpConnectionHandler::SendHeartbeat() { + if (IsBroken() || connection_ == nullptr) { + return; + } + + try { + const auto deadline = engine::Deadline::FromDuration(kHeartbeatSendTimeout); + auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); + connection_->SetOperationDeadline(deadline); + connection_->GetNative().heartbeat(); + } catch (const std::exception& ex) { + LOG_WARNING() << "Failed to send AMQP heartbeat: " << ex.what(); + Invalidate(); + if (connection_ != nullptr) { + auto lock = AmqpConnectionLocker{*connection_}.Lock({}); + connection_->GetNative().fail("Underlying connection broke."); + } + } +} + } // namespace urabbitmq::impl USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index 31b7a8a88806..5f8af1c347a1 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -7,6 +8,7 @@ #include #include +#include #include @@ -45,6 +47,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -52,6 +55,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { ~AmqpConnectionHandler() override; void onProperties(AMQP::Connection* connection, const AMQP::Table& server, AMQP::Table& client) override; + uint16_t onNegotiate(AMQP::Connection* connection, uint16_t interval) override; void onData(AMQP::Connection* connection, const char* buffer, size_t size) override; @@ -77,9 +81,15 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { const AMQP::Address& GetAddress() const; private: + void SendHeartbeat(); + AMQP::Address address_; std::unique_ptr socket_; io::SocketReader reader_; + utils::PeriodicTask heartbeat_task_; + AmqpConnection* connection_{nullptr}; + std::atomic negotiated_heartbeat_seconds_{0}; + uint16_t configured_heartbeat_seconds_{0}; engine::SingleConsumerEvent connection_ready_event_; std::atomic broken_{false}; From 16f894186ef5aae6df14d083430f2f62e6d23807 Mon Sep 17 00:00:00 2001 From: sav-da Date: Fri, 27 Feb 2026 07:38:19 +0000 Subject: [PATCH 02/10] refactor: review points resolved - added tests --- .../src/tests/publish_consume_rmqtest.cpp | 179 ++++++++++++++++++ rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 35 ++-- 2 files changed, 197 insertions(+), 17 deletions(-) diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 976274cf255c..cac2772257f6 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -1,6 +1,9 @@ #include "utils_rmqtest.hpp" #include +#include + +#include #include #include @@ -78,6 +81,41 @@ class ThrowingConsumer final : public urabbitmq::ConsumerBase { engine::ConditionVariable cond_; }; +class MetadataConsumer final : public urabbitmq::ConsumerBase { +public: + using urabbitmq::ConsumerBase::ConsumerBase; + ~MetadataConsumer() override { Stop(); } + + void Process(urabbitmq::ConsumedMessage message) override { + { + auto locked = messages_.Lock(); + locked->emplace_back(std::move(message)); + } + + if (++consumed_ == expected_consumed_) { + event_.Send(); + } + } + + void ExpectConsume(size_t count) { expected_consumed_ = count; } + + std::vector Wait() { + if (expected_consumed_ != 0) { + [[maybe_unused]] auto res = + event_.WaitForEventFor(utest::kMaxTestWaitTime); + } + + auto locked = messages_.Lock(); + return *locked; + } + +private: + concurrent::Variable> messages_; + std::atomic expected_consumed_{0}; + std::atomic consumed_{0}; + engine::SingleConsumerEvent event_; +}; + } // namespace UTEST(Consumer, CreateOnInvalidQueueWorks) { @@ -227,4 +265,145 @@ UTEST(Consumer, ForDifferentQueuesWork) { client->GetAdminChannel(client.GetDeadline()).RemoveQueue(second_queue, client.GetDeadline()); } +UTEST(Consumer, ConsumeMetadataAndHeadersWork) { + ClientWrapper client{}; + client.SetupRmqEntities(); + const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; + + struct Case { + std::string name; + std::optional reply_to; + std::optional correlation_id; + urabbitmq::Envelope::Headers headers; + }; + + const std::vector cases{ + {"no-user-headers", std::nullopt, std::nullopt, {}}, + { + "simple-user-headers", + "reply-queue", + "corr-id", + { + {"x-custom-header", "custom-value"}, + {"x-custom-int", "42"}, + }, + }, + { + "many-user-headers", + "reply-many", + "corr-many", + { + {"x-empty", ""}, + {"x-spaces", "a b c"}, + {"x-symbols", R"(!@#$%^&*()[]{}<>/?\\|;:'\",.~-_=+)"}, + {"x-long", std::string(128, 'x')}, + }, + }, + { + "trace-headers-override", + "reply-override", + "corr-override", + { + {"u-trace-id", "trace-from-user"}, + {"u-parent-span-id", "parent-from-user"}, + {"x-another", "value"}, + }, + }, + }; + + for (const auto &case_data : cases) { + urabbitmq::Envelope envelope{ + "payload-" + case_data.name, + urabbitmq::MessageType::kTransient, + }; + envelope.reply_to = case_data.reply_to; + envelope.correlation_id = case_data.correlation_id; + envelope.headers = case_data.headers; + client->PublishReliable(client.GetExchange(), client.GetRoutingKey(), + envelope, client.GetDeadline()); + } + + MetadataConsumer consumer{client.Get(), settings}; + consumer.ExpectConsume(cases.size()); + consumer.Start(); + auto consumed = consumer.Wait(); + + ASSERT_EQ(consumed.size(), cases.size()); + std::unordered_map + consumed_by_payload; + consumed_by_payload.reserve(consumed.size()); + for (const auto &msg : consumed) { + consumed_by_payload.emplace(msg.message, &msg); + } + + for (const auto &case_data : cases) { + const auto payload = "payload-" + case_data.name; + const auto it = consumed_by_payload.find(payload); + ASSERT_NE(it, consumed_by_payload.end()) + << "Missing consumed payload: " << payload; + + const auto &msg = *it->second; + EXPECT_EQ(msg.message, payload); + EXPECT_EQ(msg.metadata.exchange, client.GetExchange().GetUnderlying()); + EXPECT_EQ(msg.metadata.routingKey, client.GetRoutingKey()); + EXPECT_EQ(msg.reply_to, case_data.reply_to); + EXPECT_EQ(msg.correlation_id, case_data.correlation_id); + + for (const auto &[header_key, header_value] : case_data.headers) { + ASSERT_EQ(msg.headers.count(header_key), 1) + << "Missing header '" << header_key << "' in " << payload; + const auto &actual = msg.headers.at(header_key); + EXPECT_NE(actual.find(header_value), std::string::npos) + << "Unexpected value for header '" << header_key << "' in " << payload + << ": " << actual; + } + + ASSERT_EQ(msg.headers.count("u-trace-id"), 1) + << "Missing u-trace-id in " << payload; + ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) + << "Missing u-parent-span-id in " << payload; + EXPECT_FALSE(msg.headers.at("u-trace-id").empty()); + EXPECT_FALSE(msg.headers.at("u-parent-span-id").empty()); + } +} + +UTEST(Consumer, HeaderFieldStringConversionInvariants) { + // rabbitmq/src/urabbitmq/consumer_base_impl.cpp:112 + AMQP::Array array_field; + array_field.push_back(AMQP::LongString{"arr-string"}); + array_field.push_back(AMQP::Long{123}); + array_field.push_back(AMQP::BooleanSet{true}); + + AMQP::Table nested_table; + nested_table.set("nested-string", "nested-value"); + nested_table.set("nested-int", 7); + + AMQP::Table headers; + headers.set("string", "value"); + headers.set("empty-string", ""); + headers.set("bool-true", true); + headers.set("bool-false", false); + headers.set("uint8", static_cast(255)); + headers.set("int8", static_cast(-100)); + headers.set("uint16", static_cast(65000)); + headers.set("int16", static_cast(-30000)); + headers.set("uint32", static_cast(4000000000U)); + headers.set("int32", static_cast(-2000000000)); + headers.set("uint64", static_cast(9000000000000000000ULL)); + headers.set("int64", static_cast(-9000000000000000000LL)); + headers.set("float", AMQP::Float{3.14f}); + headers.set("double", AMQP::Double{2.718281828}); + headers.set("decimal", AMQP::DecimalField{2, 12345}); + headers.set("void", nullptr); + headers.set("array", array_field); + headers.set("table", nested_table); + + for (const auto &key : headers.keys()) { + EXPECT_NO_THROW({ + [[maybe_unused]] const auto value = std::string(headers.get(key)); + }) << "Failed for key: " + << key; + } +} + USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index ef348902c4d3..1d5723b9b724 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -1,6 +1,5 @@ #include "consumer_base_impl.hpp" -#include #include #include @@ -94,24 +93,26 @@ void ConsumerBaseImpl::Stop() { bool ConsumerBaseImpl::IsBroken() const { return broken_ || !connection_ptr_.IsUsable(); } void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery_tag) { - std::string span_name{fmt::format("consume_{}_{}", queue_name_, consumer_tag_.value_or("ctag:unknown"))}; - std::string trace_id = message.headers().get("u-trace-id"); - std::string parent_span_id = message.headers().get("u-parent-span-id"); - ConsumedMessage consumed; - consumed.message = std::string(message.body(), message.bodySize()); - consumed.metadata.exchange = message.exchange(); - consumed.metadata.routingKey = message.routingkey(); - if (message.hasReplyTo()) { - consumed.reply_to = message.replyTo(); - } + const auto &headers = message.headers(); + std::string span_name{fmt::format("consume_{}_{}", queue_name_, + consumer_tag_.value_or("ctag:unknown"))}; + std::string trace_id = headers.get("u-trace-id"); + std::string parent_span_id = headers.get("u-parent-span-id"); + ConsumedMessage consumed; + consumed.message = std::string(message.body(), message.bodySize()); + consumed.metadata.exchange = message.exchange(); + consumed.metadata.routingKey = message.routingkey(); + if (message.hasReplyTo()) { + consumed.reply_to = message.replyTo(); + } if (message.hasCorrelationID()) { - consumed.correlation_id = message.correlationID(); + consumed.correlation_id = message.correlationID(); } - const auto& headers = message.headers(); - for (const auto& key : headers.keys()) { - std::ostringstream stream; - stream << headers.get(key); - consumed.headers.emplace(key, stream.str()); + + const auto keys = headers.keys(); + consumed.headers.reserve(keys.size()); + for (const auto &key : keys) { + consumed.headers.emplace(key, std::string(headers.get(key))); } bts_.Detach(engine::AsyncNoSpan( From d67a42ef6ea20e606d7886de1648baf5f32bd342 Mon Sep 17 00:00:00 2001 From: sav-da Date: Fri, 27 Feb 2026 15:51:04 +0000 Subject: [PATCH 03/10] fix: fixes after review --- .../include/userver/urabbitmq/typedefs.hpp | 8 +- .../src/tests/publish_consume_rmqtest.cpp | 330 ++++++++---------- rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 32 +- .../impl/amqp_connection_handler.cpp | 30 +- 4 files changed, 184 insertions(+), 216 deletions(-) diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index b8d1efc26c0e..ecb6cf588618 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -68,8 +68,6 @@ enum class MessageType { /// metadata fields. This struct is used to pass messages to the end user, /// hiding the actual AMQP message object implementation. struct ConsumedMessage { - using Headers = std::unordered_map; - struct Metadata { std::string exchange; std::string routingKey; @@ -78,21 +76,19 @@ struct ConsumedMessage { Metadata metadata; std::optional reply_to{}; std::optional correlation_id{}; - Headers headers{}; + std::unordered_map headers{}; }; /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages from the end user, /// hiding the actual AMQP message object implementation. struct Envelope { - using Headers = std::unordered_map; - std::string message; MessageType type; std::optional reply_to{}; std::optional correlation_id{}; std::optional expiration{}; - std::optional headers{}; + std::optional> headers{}; }; } // namespace urabbitmq diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index cac2772257f6..a0c8f41837dc 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -21,9 +21,14 @@ class Consumer final : public urabbitmq::ConsumerBase { using urabbitmq::ConsumerBase::ConsumerBase; ~Consumer() override { Stop(); } - void Process(std::string message) override { + void Process(urabbitmq::ConsumedMessage message) override { + const auto plain_message = message.message; { auto locked = messages_.Lock(); + locked->emplace_back(plain_message); + } + { + auto locked = messages_with_metadata_.Lock(); locked->emplace_back(std::move(message)); } @@ -34,22 +39,25 @@ class Consumer final : public urabbitmq::ConsumerBase { void ExpectConsume(size_t count) { expected_consumed_ = count; } - std::vector Wait() { + void Wait() { if (expected_consumed_ != 0) { [[maybe_unused]] auto res = event_.WaitForEventFor(utest::kMaxTestWaitTime); } - - return Get(); } - std::vector Get() { + std::vector GetMessages() { auto locked = messages_.Lock(); return *locked; } + std::vector GetMessagesWithMetadata() { + auto locked = messages_with_metadata_.Lock(); + return *locked; + } + private: - std::optional single_expected_message_; concurrent::Variable> messages_; + concurrent::Variable> messages_with_metadata_; std::atomic expected_consumed_{0}; std::atomic consumed_{0}; engine::SingleConsumerEvent event_; @@ -81,41 +89,6 @@ class ThrowingConsumer final : public urabbitmq::ConsumerBase { engine::ConditionVariable cond_; }; -class MetadataConsumer final : public urabbitmq::ConsumerBase { -public: - using urabbitmq::ConsumerBase::ConsumerBase; - ~MetadataConsumer() override { Stop(); } - - void Process(urabbitmq::ConsumedMessage message) override { - { - auto locked = messages_.Lock(); - locked->emplace_back(std::move(message)); - } - - if (++consumed_ == expected_consumed_) { - event_.Send(); - } - } - - void ExpectConsume(size_t count) { expected_consumed_ = count; } - - std::vector Wait() { - if (expected_consumed_ != 0) { - [[maybe_unused]] auto res = - event_.WaitForEventFor(utest::kMaxTestWaitTime); - } - - auto locked = messages_.Lock(); - return *locked; - } - -private: - concurrent::Variable> messages_; - std::atomic expected_consumed_{0}; - std::atomic consumed_{0}; - engine::SingleConsumerEvent event_; -}; - } // namespace UTEST(Consumer, CreateOnInvalidQueueWorks) { @@ -146,7 +119,8 @@ UTEST(Consumer, ConsumeWorks) { consumer.ExpectConsume(1); consumer.Start(); - auto consumed = consumer.Wait(); + consumer.Wait(); + auto consumed = consumer.GetMessages(); ASSERT_EQ(consumed.size(), 1); EXPECT_EQ(consumed[0], envelope.message); @@ -205,12 +179,13 @@ UTEST(Consumer, ThrowsReturnsToQueue) { good_consumer.Start(); engine::InterruptibleSleepFor(std::chrono::milliseconds{200}); - auto consumed = good_consumer.Get(); + auto consumed = good_consumer.GetMessages(); EXPECT_LT(consumed.size(), messages_count); throwing_consumer.Throw(); throwing_consumer.Stop(); - EXPECT_EQ(good_consumer.Wait().size(), messages_count); + good_consumer.Wait(); + EXPECT_EQ(good_consumer.GetMessages().size(), messages_count); } UTEST(Consumer, MultipleConcurrentWork) { @@ -231,8 +206,8 @@ UTEST(Consumer, MultipleConcurrentWork) { second_consumer.Start(); engine::InterruptibleSleepFor(std::chrono::milliseconds{200}); - EXPECT_GT(first_consumer.Get().size(), 0); - EXPECT_GT(second_consumer.Get().size(), 0); + EXPECT_GT(first_consumer.GetMessages().size(), 0); + EXPECT_GT(second_consumer.GetMessages().size(), 0); } UTEST(Consumer, ForDifferentQueuesWork) { @@ -259,151 +234,146 @@ UTEST(Consumer, ForDifferentQueuesWork) { second_consumer.ExpectConsume(messages_count); second_consumer.Start(); - EXPECT_EQ(first_consumer.Wait().size(), messages_count); - EXPECT_EQ(second_consumer.Wait().size(), messages_count); + first_consumer.Wait(); + second_consumer.Wait(); + EXPECT_EQ(first_consumer.GetMessages().size(), messages_count); + EXPECT_EQ(second_consumer.GetMessages().size(), messages_count); client->GetAdminChannel(client.GetDeadline()).RemoveQueue(second_queue, client.GetDeadline()); } UTEST(Consumer, ConsumeMetadataAndHeadersWork) { - ClientWrapper client{}; - client.SetupRmqEntities(); - const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; - - struct Case { - std::string name; - std::optional reply_to; - std::optional correlation_id; - urabbitmq::Envelope::Headers headers; - }; - - const std::vector cases{ - {"no-user-headers", std::nullopt, std::nullopt, {}}, - { - "simple-user-headers", - "reply-queue", - "corr-id", - { - {"x-custom-header", "custom-value"}, - {"x-custom-int", "42"}, - }, - }, - { - "many-user-headers", - "reply-many", - "corr-many", - { - {"x-empty", ""}, - {"x-spaces", "a b c"}, - {"x-symbols", R"(!@#$%^&*()[]{}<>/?\\|;:'\",.~-_=+)"}, - {"x-long", std::string(128, 'x')}, - }, - }, - { - "trace-headers-override", - "reply-override", - "corr-override", - { - {"u-trace-id", "trace-from-user"}, - {"u-parent-span-id", "parent-from-user"}, - {"x-another", "value"}, - }, - }, - }; - - for (const auto &case_data : cases) { - urabbitmq::Envelope envelope{ - "payload-" + case_data.name, - urabbitmq::MessageType::kTransient, + ClientWrapper client{}; + client.SetupRmqEntities(); + const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; + + struct Case { + std::string name; + std::optional reply_to; + std::optional correlation_id; + std::unordered_map headers; }; - envelope.reply_to = case_data.reply_to; - envelope.correlation_id = case_data.correlation_id; - envelope.headers = case_data.headers; - client->PublishReliable(client.GetExchange(), client.GetRoutingKey(), - envelope, client.GetDeadline()); - } - - MetadataConsumer consumer{client.Get(), settings}; - consumer.ExpectConsume(cases.size()); - consumer.Start(); - auto consumed = consumer.Wait(); - - ASSERT_EQ(consumed.size(), cases.size()); - std::unordered_map - consumed_by_payload; - consumed_by_payload.reserve(consumed.size()); - for (const auto &msg : consumed) { - consumed_by_payload.emplace(msg.message, &msg); - } - - for (const auto &case_data : cases) { - const auto payload = "payload-" + case_data.name; - const auto it = consumed_by_payload.find(payload); - ASSERT_NE(it, consumed_by_payload.end()) - << "Missing consumed payload: " << payload; - - const auto &msg = *it->second; - EXPECT_EQ(msg.message, payload); - EXPECT_EQ(msg.metadata.exchange, client.GetExchange().GetUnderlying()); - EXPECT_EQ(msg.metadata.routingKey, client.GetRoutingKey()); - EXPECT_EQ(msg.reply_to, case_data.reply_to); - EXPECT_EQ(msg.correlation_id, case_data.correlation_id); - - for (const auto &[header_key, header_value] : case_data.headers) { - ASSERT_EQ(msg.headers.count(header_key), 1) - << "Missing header '" << header_key << "' in " << payload; - const auto &actual = msg.headers.at(header_key); - EXPECT_NE(actual.find(header_value), std::string::npos) - << "Unexpected value for header '" << header_key << "' in " << payload - << ": " << actual; + + const std::vector cases{ + {"no-user-headers", std::nullopt, std::nullopt, {}}, + { + "simple-user-headers", + "reply-queue", + "corr-id", + { + {"x-custom-header", "custom-value"}, + {"x-custom-int", "42"}, + }, + }, + { + "many-user-headers", + "reply-many", + "corr-many", + { + {"x-empty", ""}, + {"x-spaces", "a b c"}, + {"x-symbols", R"(!@#$%^&*()[]{}<>/?\\|;:'\",.~-_=+)"}, + {"x-long", std::string(128, 'x')}, + }, + }, + { + "trace-headers-override", + "reply-override", + "corr-override", + { + {"u-trace-id", "trace-from-user"}, + {"u-parent-span-id", "parent-from-user"}, + {"x-another", "value"}, + }, + }, + }; + + for (const auto& case_data : cases) { + urabbitmq::Envelope envelope{ + "payload-" + case_data.name, + urabbitmq::MessageType::kTransient, + }; + envelope.reply_to = case_data.reply_to; + envelope.correlation_id = case_data.correlation_id; + envelope.headers = case_data.headers; + client->PublishReliable(client.GetExchange(), client.GetRoutingKey(), envelope, client.GetDeadline()); } - ASSERT_EQ(msg.headers.count("u-trace-id"), 1) - << "Missing u-trace-id in " << payload; - ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) - << "Missing u-parent-span-id in " << payload; - EXPECT_FALSE(msg.headers.at("u-trace-id").empty()); - EXPECT_FALSE(msg.headers.at("u-parent-span-id").empty()); - } + Consumer consumer{client.Get(), settings}; + consumer.ExpectConsume(cases.size()); + consumer.Start(); + consumer.Wait(); + auto consumed = consumer.GetMessagesWithMetadata(); + + ASSERT_EQ(consumed.size(), cases.size()); + std::unordered_map consumed_by_payload; + consumed_by_payload.reserve(consumed.size()); + for (const auto& msg : consumed) { + consumed_by_payload.emplace(msg.message, &msg); + } + + for (const auto& case_data : cases) { + const auto payload = "payload-" + case_data.name; + const auto it = consumed_by_payload.find(payload); + ASSERT_NE(it, consumed_by_payload.end()) << "Missing consumed payload: " << payload; + + const auto& msg = *it->second; + EXPECT_EQ(msg.message, payload); + EXPECT_EQ(msg.metadata.exchange, client.GetExchange().GetUnderlying()); + EXPECT_EQ(msg.metadata.routingKey, client.GetRoutingKey()); + EXPECT_EQ(msg.reply_to, case_data.reply_to); + EXPECT_EQ(msg.correlation_id, case_data.correlation_id); + + for (const auto& [header_key, header_value] : case_data.headers) { + ASSERT_EQ(msg.headers.count(header_key), 1) << "Missing header '" << header_key << "' in " << payload; + const auto& actual = msg.headers.at(header_key); + EXPECT_NE(actual.find(header_value), std::string::npos) + << "Unexpected value for header '" << header_key << "' in " << payload << ": " << actual; + } + + ASSERT_EQ(msg.headers.count("u-trace-id"), 1) << "Missing u-trace-id in " << payload; + ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) << "Missing u-parent-span-id in " << payload; + EXPECT_FALSE(msg.headers.at("u-trace-id").empty()); + EXPECT_FALSE(msg.headers.at("u-parent-span-id").empty()); + } } UTEST(Consumer, HeaderFieldStringConversionInvariants) { - // rabbitmq/src/urabbitmq/consumer_base_impl.cpp:112 - AMQP::Array array_field; - array_field.push_back(AMQP::LongString{"arr-string"}); - array_field.push_back(AMQP::Long{123}); - array_field.push_back(AMQP::BooleanSet{true}); - - AMQP::Table nested_table; - nested_table.set("nested-string", "nested-value"); - nested_table.set("nested-int", 7); - - AMQP::Table headers; - headers.set("string", "value"); - headers.set("empty-string", ""); - headers.set("bool-true", true); - headers.set("bool-false", false); - headers.set("uint8", static_cast(255)); - headers.set("int8", static_cast(-100)); - headers.set("uint16", static_cast(65000)); - headers.set("int16", static_cast(-30000)); - headers.set("uint32", static_cast(4000000000U)); - headers.set("int32", static_cast(-2000000000)); - headers.set("uint64", static_cast(9000000000000000000ULL)); - headers.set("int64", static_cast(-9000000000000000000LL)); - headers.set("float", AMQP::Float{3.14f}); - headers.set("double", AMQP::Double{2.718281828}); - headers.set("decimal", AMQP::DecimalField{2, 12345}); - headers.set("void", nullptr); - headers.set("array", array_field); - headers.set("table", nested_table); - - for (const auto &key : headers.keys()) { - EXPECT_NO_THROW({ - [[maybe_unused]] const auto value = std::string(headers.get(key)); - }) << "Failed for key: " - << key; - } + // rabbitmq/src/urabbitmq/consumer_base_impl.cpp:112 + AMQP::Array array_field; + array_field.push_back(AMQP::LongString{"arr-string"}); + array_field.push_back(AMQP::Long{123}); + array_field.push_back(AMQP::BooleanSet{true}); + + AMQP::Table nested_table; + nested_table.set("nested-string", "nested-value"); + nested_table.set("nested-int", 7); + + AMQP::Table headers; + headers.set("string", "value"); + headers.set("empty-string", ""); + headers.set("bool-true", true); + headers.set("bool-false", false); + headers.set("uint8", static_cast(255)); + headers.set("int8", static_cast(-100)); + headers.set("uint16", static_cast(65000)); + headers.set("int16", static_cast(-30000)); + headers.set("uint32", std::numeric_limits::max()); + headers.set("int32", std::numeric_limits::min()); + headers.set("uint64", std::numeric_limits::max()); + headers.set("int64", std::numeric_limits::min()); + headers.set("float", AMQP::Float{3.14f}); + headers.set("double", AMQP::Double{2.718281828}); + headers.set("decimal", AMQP::DecimalField{2, 12345}); + headers.set("void", nullptr); + headers.set("array", array_field); + headers.set("table", nested_table); + + for (const auto& key : headers.keys()) { + EXPECT_NO_THROW({ [[maybe_unused]] const auto value = std::string(headers.get(key)); } + ) << "Failed for key: " + << key; + } } USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 1d5723b9b724..96723823500f 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -28,8 +28,7 @@ ConsumerBaseImpl::ConsumerBaseImpl(ConnectionPtr&& connection, const ConsumerSet queue_name_{settings.queue.GetUnderlying()}, prefetch_count_{settings.prefetch_count}, connection_ptr_{std::move(connection)}, - channel_{connection_ptr_->GetChannel()} -{ + channel_{connection_ptr_->GetChannel()} { // We take ownership of the connection, because if it remains pooled // things get messy with lifetimes and callbacks connection_ptr_.Adopt(); @@ -93,26 +92,25 @@ void ConsumerBaseImpl::Stop() { bool ConsumerBaseImpl::IsBroken() const { return broken_ || !connection_ptr_.IsUsable(); } void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery_tag) { - const auto &headers = message.headers(); - std::string span_name{fmt::format("consume_{}_{}", queue_name_, - consumer_tag_.value_or("ctag:unknown"))}; - std::string trace_id = headers.get("u-trace-id"); - std::string parent_span_id = headers.get("u-parent-span-id"); - ConsumedMessage consumed; - consumed.message = std::string(message.body(), message.bodySize()); - consumed.metadata.exchange = message.exchange(); - consumed.metadata.routingKey = message.routingkey(); - if (message.hasReplyTo()) { - consumed.reply_to = message.replyTo(); - } + const auto& headers = message.headers(); + std::string span_name{fmt::format("consume_{}_{}", queue_name_, consumer_tag_.value_or("ctag:unknown"))}; + std::string trace_id = headers.get("u-trace-id"); + std::string parent_span_id = headers.get("u-parent-span-id"); + ConsumedMessage consumed; + consumed.message = std::string(message.body(), message.bodySize()); + consumed.metadata.exchange = message.exchange(); + consumed.metadata.routingKey = message.routingkey(); + if (message.hasReplyTo()) { + consumed.reply_to = message.replyTo(); + } if (message.hasCorrelationID()) { - consumed.correlation_id = message.correlationID(); + consumed.correlation_id = message.correlationID(); } const auto keys = headers.keys(); consumed.headers.reserve(keys.size()); - for (const auto &key : keys) { - consumed.headers.emplace(key, std::string(headers.get(key))); + for (const auto& key : keys) { + consumed.headers.emplace(key, std::string(headers.get(key))); } bts_.Detach(engine::AsyncNoSpan( diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index 68ddded55c44..70c9610c70f4 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -1,9 +1,10 @@ #include "amqp_connection_handler.hpp" #include -#include + #include #include +#include #include #include @@ -21,6 +22,7 @@ USERVER_NAMESPACE_BEGIN namespace urabbitmq::impl { namespace { + constexpr std::chrono::milliseconds kHeartbeatSendTimeout{200}; engine::io::Socket CreateSocket(engine::io::Sockaddr& addr, engine::Deadline deadline) { @@ -95,7 +97,7 @@ AmqpConnectionHandler::AmqpConnectionHandler( clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, - size_t heartbeat_interval_seconds, + std::size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -103,10 +105,9 @@ AmqpConnectionHandler::AmqpConnectionHandler( : address_{ToAmqpAddress(endpoint, auth_settings, secure)}, socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)}, reader_{*this, *socket_}, - configured_heartbeat_seconds_{ - static_cast(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, - stats_{stats} -{} + configured_heartbeat_seconds_{static_cast< + std::uint16_t>(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, + stats_{stats} {} AmqpConnectionHandler::~AmqpConnectionHandler() { heartbeat_task_.Stop(); @@ -120,11 +121,6 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, } uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) { - if (interval == 0 || configured_heartbeat_seconds_ == 0) { - negotiated_heartbeat_seconds_.store(0, std::memory_order_relaxed); - return 0; - } - const auto negotiated = static_cast(std::min(interval, configured_heartbeat_seconds_)); negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; @@ -187,6 +183,7 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi if (!connection_ready_event_.WaitForEventUntil(deadline)) { reader_.Stop(); + connection_ = nullptr; throw ConnectionSetupTimeout{"Failed to setup a connection within specified deadline"}; } @@ -196,10 +193,17 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi throw ConnectionSetupError{"Failed to setup a connection: " + *error_}; } + using namespace std::chrono_literals; + const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed); if (heartbeat_seconds > 0) { - const auto heartbeat_period = std::chrono::seconds{std::max(1, heartbeat_seconds / 2)}; - heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period}, [this] { SendHeartbeat(); }); + const auto half_interval = + std::chrono::duration_cast(std::chrono::seconds{heartbeat_seconds}) / 2; + const auto heartbeat_period = std::max(500ms, half_interval); + + heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period, utils::PeriodicTask::Flags::kNow}, [this] { + SendHeartbeat(); + }); } } From 5a1de2b16a98548ba3f64d612b638d91afbe683f Mon Sep 17 00:00:00 2001 From: sav-da Date: Mon, 2 Mar 2026 11:39:23 +0000 Subject: [PATCH 04/10] fix: fixes after review --- rabbitmq/src/urabbitmq/component.yaml | 2 +- .../src/urabbitmq/impl/amqp_connection_handler.cpp | 11 ++++------- .../src/urabbitmq/impl/amqp_connection_handler.hpp | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/rabbitmq/src/urabbitmq/component.yaml b/rabbitmq/src/urabbitmq/component.yaml index 757c5500a6c2..998116d3d320 100644 --- a/rabbitmq/src/urabbitmq/component.yaml +++ b/rabbitmq/src/urabbitmq/component.yaml @@ -27,7 +27,7 @@ properties: type: integer description: | requested AMQP heartbeat interval in seconds; 0 disables heartbeats - default: 30 + default: 60 use_secure_connection: type: boolean description: whether to use TLS for connections diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index 70c9610c70f4..515ec21fd0c2 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -23,8 +23,6 @@ namespace urabbitmq::impl { namespace { -constexpr std::chrono::milliseconds kHeartbeatSendTimeout{200}; - engine::io::Socket CreateSocket(engine::io::Sockaddr& addr, engine::Deadline deadline) { engine::io::Socket socket{addr.Domain(), engine::io::SocketType::kTcp}; socket.SetOption(IPPROTO_TCP, TCP_NODELAY, 1); @@ -106,7 +104,7 @@ AmqpConnectionHandler::AmqpConnectionHandler( socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)}, reader_{*this, *socket_}, configured_heartbeat_seconds_{static_cast< - std::uint16_t>(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, + std::uint16_t>(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, stats_{stats} {} AmqpConnectionHandler::~AmqpConnectionHandler() { @@ -121,7 +119,7 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, } uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) { - const auto negotiated = static_cast(std::min(interval, configured_heartbeat_seconds_)); + const auto negotiated = std::min(interval, configured_heartbeat_seconds_); negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; return negotiated; @@ -199,9 +197,8 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi if (heartbeat_seconds > 0) { const auto half_interval = std::chrono::duration_cast(std::chrono::seconds{heartbeat_seconds}) / 2; - const auto heartbeat_period = std::max(500ms, half_interval); - heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period, utils::PeriodicTask::Flags::kNow}, [this] { + heartbeat_task_.Start("amqp_heartbeat", {half_interval, utils::PeriodicTask::Flags::kNow}, [this] { SendHeartbeat(); }); } @@ -233,7 +230,7 @@ void AmqpConnectionHandler::SendHeartbeat() { } try { - const auto deadline = engine::Deadline::FromDuration(kHeartbeatSendTimeout); + const auto deadline = engine::Deadline::FromDuration(std::chrono::seconds{configured_heartbeat_seconds_} / 2); auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); connection_->SetOperationDeadline(deadline); connection_->GetNative().heartbeat(); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index 5f8af1c347a1..30167e9cf16a 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -89,7 +89,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { utils::PeriodicTask heartbeat_task_; AmqpConnection* connection_{nullptr}; std::atomic negotiated_heartbeat_seconds_{0}; - uint16_t configured_heartbeat_seconds_{0}; + std::uint16_t configured_heartbeat_seconds_{0}; engine::SingleConsumerEvent connection_ready_event_; std::atomic broken_{false}; From dd78f5fc4d34b783f38cd4c3d32bc1e62b79ea4f Mon Sep 17 00:00:00 2001 From: sav-da Date: Tue, 3 Mar 2026 22:00:41 +0000 Subject: [PATCH 05/10] fix: fixes after review --- rabbitmq/src/tests/publish_consume_rmqtest.cpp | 16 ++++++++-------- rabbitmq/src/urabbitmq/client_settings.cpp | 17 ++++++++++------- .../urabbitmq/impl/amqp_connection_handler.cpp | 6 ++++-- .../urabbitmq/impl/amqp_connection_handler.hpp | 3 ++- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index a0c8f41837dc..780badf2a60e 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -354,14 +354,14 @@ UTEST(Consumer, HeaderFieldStringConversionInvariants) { headers.set("empty-string", ""); headers.set("bool-true", true); headers.set("bool-false", false); - headers.set("uint8", static_cast(255)); - headers.set("int8", static_cast(-100)); - headers.set("uint16", static_cast(65000)); - headers.set("int16", static_cast(-30000)); - headers.set("uint32", std::numeric_limits::max()); - headers.set("int32", std::numeric_limits::min()); - headers.set("uint64", std::numeric_limits::max()); - headers.set("int64", std::numeric_limits::min()); + headers.set("uint8", static_cast(255)); + headers.set("int8", static_cast(-100)); + headers.set("uint16", static_cast(65000)); + headers.set("int16", static_cast(-30000)); + headers.set("uint32", std::numeric_limits::max()); + headers.set("int32", std::numeric_limits::min()); + headers.set("uint64", std::numeric_limits::max()); + headers.set("int64", std::numeric_limits::min()); headers.set("float", AMQP::Float{3.14f}); headers.set("double", AMQP::Double{2.718281828}); headers.set("decimal", AMQP::DecimalField{2, 12345}); diff --git a/rabbitmq/src/urabbitmq/client_settings.cpp b/rabbitmq/src/urabbitmq/client_settings.cpp index ba7682449243..ede927fb60e4 100644 --- a/rabbitmq/src/urabbitmq/client_settings.cpp +++ b/rabbitmq/src/urabbitmq/client_settings.cpp @@ -1,7 +1,8 @@ #include -#include +#include #include +#include #include #include #include @@ -101,13 +102,16 @@ PoolSettings Parse(const yaml_config::YamlConfig& config, formats::parse::To(result.min_pool_size); result.max_pool_size = config["max_pool_size"].As(result.max_pool_size); result.max_in_flight_requests = config["max_in_flight_requests"].As(result.max_in_flight_requests); - result.heartbeat_interval_seconds = - config["heartbeat_interval_seconds"].As(result.heartbeat_interval_seconds); + result + .heartbeat_interval_seconds = config["heartbeat_interval_seconds"].As(result.heartbeat_interval_seconds + ); UINVARIANT(result.min_pool_size <= result.max_pool_size, "max_pool_size is less than min_pool_size"); UINVARIANT(result.max_pool_size > 0, "max_pool_size is set to zero"); - UINVARIANT(result.heartbeat_interval_seconds <= std::numeric_limits::max(), - "heartbeat_interval_seconds is too large"); + UINVARIANT( + result.heartbeat_interval_seconds <= std::numeric_limits::max(), + "heartbeat_interval_seconds is too large" + ); return result; } @@ -117,8 +121,7 @@ ClientSettings::ClientSettings() = default; ClientSettings::ClientSettings(const components::ComponentConfig& config, const RabbitEndpoints& rabbit_endpoints) : pool_settings{config.As()}, endpoints{rabbit_endpoints}, - use_secure_connection{config["use_secure_connection"].As(true)} -{} + use_secure_connection{config["use_secure_connection"].As(true)} {} RabbitEndpointsMulti::RabbitEndpointsMulti(const formats::json::Value& doc) { const auto rabbitmq_settings = doc["rabbitmq_settings"]; diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index 515ec21fd0c2..96a7a0db6c2a 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -1,10 +1,10 @@ #include "amqp_connection_handler.hpp" #include +#include #include #include -#include #include #include @@ -230,7 +230,9 @@ void AmqpConnectionHandler::SendHeartbeat() { } try { - const auto deadline = engine::Deadline::FromDuration(std::chrono::seconds{configured_heartbeat_seconds_} / 2); + const auto deadline = engine::Deadline::FromDuration(std::chrono::duration_cast( + std::chrono::seconds{configured_heartbeat_seconds_} / 2.0 + )); auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); connection_->SetOperationDeadline(deadline); connection_->GetNative().heartbeat(); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index 30167e9cf16a..c9f87c541833 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -88,7 +89,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { io::SocketReader reader_; utils::PeriodicTask heartbeat_task_; AmqpConnection* connection_{nullptr}; - std::atomic negotiated_heartbeat_seconds_{0}; + std::atomic negotiated_heartbeat_seconds_{0}; std::uint16_t configured_heartbeat_seconds_{0}; engine::SingleConsumerEvent connection_ready_event_; From 7c384d8c2aec9d7153f95db7cebe67aeea53305b Mon Sep 17 00:00:00 2001 From: sav-da Date: Wed, 4 Mar 2026 11:41:05 +0000 Subject: [PATCH 06/10] refactor: full rework consume message headers --- .../userver/urabbitmq/client_settings.hpp | 2 +- .../src/tests/publish_consume_rmqtest.cpp | 113 +++++++++++++++--- rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 39 +++++- 3 files changed, 136 insertions(+), 18 deletions(-) diff --git a/rabbitmq/include/userver/urabbitmq/client_settings.hpp b/rabbitmq/include/userver/urabbitmq/client_settings.hpp index 93683c5c4904..5ebee8b3d538 100644 --- a/rabbitmq/include/userver/urabbitmq/client_settings.hpp +++ b/rabbitmq/include/userver/urabbitmq/client_settings.hpp @@ -78,7 +78,7 @@ struct PoolSettings final { /// Requested AMQP heartbeat interval in seconds. /// Set to 0 to disable heartbeats. - size_t heartbeat_interval_seconds = 30; + size_t heartbeat_interval_seconds = 60; }; class TestsHelper; diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 780badf2a60e..52d01fdfe807 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -10,8 +10,13 @@ #include #include #include +#include #include +#include +#include +#include + USERVER_NAMESPACE_BEGIN namespace { @@ -89,6 +94,60 @@ class ThrowingConsumer final : public urabbitmq::ConsumerBase { engine::ConditionVariable cond_; }; +class RawPublisher final { +public: + explicit RawPublisher(engine::Deadline deadline) + : resolver_{engine::current_task::GetTaskProcessor(), {}}, + settings_{urabbitmq::TestsHelper::CreateSettings()}, + handler_{ + resolver_, + settings_.endpoints.endpoints.front(), + settings_.endpoints.auth, + settings_.pool_settings.heartbeat_interval_seconds, + settings_.use_secure_connection, + stats_, + deadline, + }, + connection_{handler_, settings_.pool_settings.max_in_flight_requests, deadline} + {} + + void PublishReliable( + const urabbitmq::Exchange& exchange, + const std::string& routing_key, + std::string_view message, + const AMQP::Table& headers, + engine::Deadline deadline + ) { + AMQP::Envelope envelope{message.data(), message.size()}; + envelope.setHeaders(headers); + + std::optional error; + engine::SingleConsumerEvent published; + + auto reliable = connection_.GetReliableChannel(deadline); + reliable->publish(exchange.GetUnderlying(), routing_key, envelope) + .onAck([&published] { published.Send(); }) + .onError([&published, &error](const char* message) { + error = message; + published.Send(); + }); + + if (!published.WaitForEventFor(utest::kMaxTestWaitTime)) { + throw std::runtime_error{"Timed out waiting for publish ack"}; + } + if (error.has_value()) { + throw std::runtime_error{*error}; + } + } + +private: + clients::dns::Resolver resolver_; + const urabbitmq::ClientSettings settings_; + urabbitmq::statistics::ConnectionStatistics stats_; + urabbitmq::impl::AmqpConnectionHandler handler_; + urabbitmq::impl::AmqpConnection connection_; +}; + } // namespace UTEST(Consumer, CreateOnInvalidQueueWorks) { @@ -339,15 +398,9 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { } UTEST(Consumer, HeaderFieldStringConversionInvariants) { - // rabbitmq/src/urabbitmq/consumer_base_impl.cpp:112 - AMQP::Array array_field; - array_field.push_back(AMQP::LongString{"arr-string"}); - array_field.push_back(AMQP::Long{123}); - array_field.push_back(AMQP::BooleanSet{true}); - - AMQP::Table nested_table; - nested_table.set("nested-string", "nested-value"); - nested_table.set("nested-int", 7); + ClientWrapper client{}; + client.SetupRmqEntities(); + const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; AMQP::Table headers; headers.set("string", "value"); @@ -364,15 +417,43 @@ UTEST(Consumer, HeaderFieldStringConversionInvariants) { headers.set("int64", std::numeric_limits::min()); headers.set("float", AMQP::Float{3.14f}); headers.set("double", AMQP::Double{2.718281828}); - headers.set("decimal", AMQP::DecimalField{2, 12345}); headers.set("void", nullptr); - headers.set("array", array_field); - headers.set("table", nested_table); - for (const auto& key : headers.keys()) { - EXPECT_NO_THROW({ [[maybe_unused]] const auto value = std::string(headers.get(key)); } - ) << "Failed for key: " - << key; + const std::unordered_map expected_values{ + {"string", "value"}, + {"empty-string", ""}, + {"bool-true", "true"}, + {"bool-false", "false"}, + {"uint8", "255"}, + {"int8", "-100"}, + {"uint16", "65000"}, + {"int16", "-30000"}, + {"uint32", "4294967295"}, + {"int32", "-2147483648"}, + {"uint64", "18446744073709551615"}, + {"int64", "-9223372036854775808"}, + {"float", "3.14"}, + {"double", "2.718281828"}, + {"void", ""}, + }; + + Consumer consumer{client.Get(), settings}; + consumer.ExpectConsume(1); + consumer.Start(); + + RawPublisher publisher{client.GetDeadline()}; + publisher.PublishReliable(client.GetExchange(), client.GetRoutingKey(), "payload-header-conversion", headers, client.GetDeadline()); + + consumer.Wait(); + const auto consumed = consumer.GetMessagesWithMetadata(); + + ASSERT_EQ(consumed.size(), 1); + EXPECT_EQ(consumed[0].message, "payload-header-conversion"); + ASSERT_EQ(consumed[0].headers.size(), expected_values.size()); + + for (const auto& [key, expected_value] : expected_values) { + ASSERT_EQ(consumed[0].headers.count(key), 1) << "Missing header: " << key; + EXPECT_EQ(consumed[0].headers.at(key), expected_value) << "Unexpected converted value for key: " << key; } } diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 96723823500f..97248d190849 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -21,6 +21,43 @@ namespace { constexpr std::chrono::milliseconds kStartTimeout{2000}; +std::string ToString(const AMQP::Field& field) { + const auto format = [](const auto value) { return fmt::format("{}", value); }; + + // AMQP-CPP field type codes returned by AMQP::Field::typeID(): + // string: 's'/'S', bool: 't', integers: 'b'/'B'/'U'/'u'/'I'/'i'/'L'/'l', float/double: 'f'/'d'. + switch (field.typeID()) { + case 'S': + case 's': + return static_cast(field); + case 't': + return dynamic_cast(field).value() ? "true" : "false"; + case 'B': + return format(static_cast(static_cast(field))); + case 'b': + return format(static_cast(static_cast(field))); + case 'u': + return format(static_cast(field)); + case 'U': + return format(static_cast(field)); + case 'i': + return format(static_cast(field)); + case 'I': + return format(static_cast(field)); + case 'l': + case 'T': + return format(static_cast(field)); + case 'L': + return format(static_cast(field)); + case 'f': + return format(static_cast(field)); + case 'd': + return format(static_cast(field)); + default: + return std::string(field); + } +} + } // namespace ConsumerBaseImpl::ConsumerBaseImpl(ConnectionPtr&& connection, const ConsumerSettings& settings) @@ -110,7 +147,7 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery const auto keys = headers.keys(); consumed.headers.reserve(keys.size()); for (const auto& key : keys) { - consumed.headers.emplace(key, std::string(headers.get(key))); + consumed.headers.emplace(key, ToString(headers.get(key))); } bts_.Detach(engine::AsyncNoSpan( From 47156b77e7b90973bba3d20c806caac693ea022a Mon Sep 17 00:00:00 2001 From: sav-da Date: Wed, 4 Mar 2026 11:52:54 +0000 Subject: [PATCH 07/10] refactor: extract AMQP field string conversion helper --- .../src/tests/field_to_string_rmqtest.cpp | 58 +++++++++ .../src/tests/publish_consume_rmqtest.cpp | 119 ------------------ rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 40 +----- .../src/urabbitmq/impl/field_to_string.cpp | 51 ++++++++ .../src/urabbitmq/impl/field_to_string.hpp | 15 +++ 5 files changed, 126 insertions(+), 157 deletions(-) create mode 100644 rabbitmq/src/tests/field_to_string_rmqtest.cpp create mode 100644 rabbitmq/src/urabbitmq/impl/field_to_string.cpp create mode 100644 rabbitmq/src/urabbitmq/impl/field_to_string.hpp diff --git a/rabbitmq/src/tests/field_to_string_rmqtest.cpp b/rabbitmq/src/tests/field_to_string_rmqtest.cpp new file mode 100644 index 000000000000..5d0e68b8a669 --- /dev/null +++ b/rabbitmq/src/tests/field_to_string_rmqtest.cpp @@ -0,0 +1,58 @@ +#include + +#include +#include +#include +#include + +#include + +#include + +USERVER_NAMESPACE_BEGIN + +UTEST(FieldToString, BasicTypes) { + AMQP::Table headers; + headers.set("string", "value"); + headers.set("empty-string", ""); + headers.set("bool-true", true); + headers.set("bool-false", false); + headers.set("uint8", static_cast(255)); + headers.set("int8", static_cast(-100)); + headers.set("uint16", static_cast(65000)); + headers.set("int16", static_cast(-30000)); + headers.set("uint32", std::numeric_limits::max()); + headers.set("int32", std::numeric_limits::min()); + headers.set("uint64", std::numeric_limits::max()); + headers.set("int64", std::numeric_limits::min()); + headers.set("float", AMQP::Float{3.14f}); + headers.set("double", AMQP::Double{2.718281828}); + headers.set("void", nullptr); + + const std::unordered_map expected_values{ + {"string", "value"}, + {"empty-string", ""}, + {"bool-true", "true"}, + {"bool-false", "false"}, + {"uint8", "255"}, + {"int8", "-100"}, + {"uint16", "65000"}, + {"int16", "-30000"}, + {"uint32", "4294967295"}, + {"int32", "-2147483648"}, + {"uint64", "18446744073709551615"}, + {"int64", "-9223372036854775808"}, + {"float", "3.14"}, + {"double", "2.718281828"}, + {"void", ""}, + }; + + ASSERT_EQ(headers.keys().size(), expected_values.size()); + for (const auto& [key, expected_value] : expected_values) { + ASSERT_TRUE(headers.contains(key)) << "Missing header: " << key; + EXPECT_EQ(urabbitmq::impl::FieldToString(headers.get(key)), expected_value) + << "Unexpected converted value for key: " << key; + } +} + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 52d01fdfe807..2ed810e82e40 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -10,13 +10,8 @@ #include #include #include -#include #include -#include -#include -#include - USERVER_NAMESPACE_BEGIN namespace { @@ -94,60 +89,6 @@ class ThrowingConsumer final : public urabbitmq::ConsumerBase { engine::ConditionVariable cond_; }; -class RawPublisher final { -public: - explicit RawPublisher(engine::Deadline deadline) - : resolver_{engine::current_task::GetTaskProcessor(), {}}, - settings_{urabbitmq::TestsHelper::CreateSettings()}, - handler_{ - resolver_, - settings_.endpoints.endpoints.front(), - settings_.endpoints.auth, - settings_.pool_settings.heartbeat_interval_seconds, - settings_.use_secure_connection, - stats_, - deadline, - }, - connection_{handler_, settings_.pool_settings.max_in_flight_requests, deadline} - {} - - void PublishReliable( - const urabbitmq::Exchange& exchange, - const std::string& routing_key, - std::string_view message, - const AMQP::Table& headers, - engine::Deadline deadline - ) { - AMQP::Envelope envelope{message.data(), message.size()}; - envelope.setHeaders(headers); - - std::optional error; - engine::SingleConsumerEvent published; - - auto reliable = connection_.GetReliableChannel(deadline); - reliable->publish(exchange.GetUnderlying(), routing_key, envelope) - .onAck([&published] { published.Send(); }) - .onError([&published, &error](const char* message) { - error = message; - published.Send(); - }); - - if (!published.WaitForEventFor(utest::kMaxTestWaitTime)) { - throw std::runtime_error{"Timed out waiting for publish ack"}; - } - if (error.has_value()) { - throw std::runtime_error{*error}; - } - } - -private: - clients::dns::Resolver resolver_; - const urabbitmq::ClientSettings settings_; - urabbitmq::statistics::ConnectionStatistics stats_; - urabbitmq::impl::AmqpConnectionHandler handler_; - urabbitmq::impl::AmqpConnection connection_; -}; - } // namespace UTEST(Consumer, CreateOnInvalidQueueWorks) { @@ -397,64 +338,4 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { } } -UTEST(Consumer, HeaderFieldStringConversionInvariants) { - ClientWrapper client{}; - client.SetupRmqEntities(); - const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; - - AMQP::Table headers; - headers.set("string", "value"); - headers.set("empty-string", ""); - headers.set("bool-true", true); - headers.set("bool-false", false); - headers.set("uint8", static_cast(255)); - headers.set("int8", static_cast(-100)); - headers.set("uint16", static_cast(65000)); - headers.set("int16", static_cast(-30000)); - headers.set("uint32", std::numeric_limits::max()); - headers.set("int32", std::numeric_limits::min()); - headers.set("uint64", std::numeric_limits::max()); - headers.set("int64", std::numeric_limits::min()); - headers.set("float", AMQP::Float{3.14f}); - headers.set("double", AMQP::Double{2.718281828}); - headers.set("void", nullptr); - - const std::unordered_map expected_values{ - {"string", "value"}, - {"empty-string", ""}, - {"bool-true", "true"}, - {"bool-false", "false"}, - {"uint8", "255"}, - {"int8", "-100"}, - {"uint16", "65000"}, - {"int16", "-30000"}, - {"uint32", "4294967295"}, - {"int32", "-2147483648"}, - {"uint64", "18446744073709551615"}, - {"int64", "-9223372036854775808"}, - {"float", "3.14"}, - {"double", "2.718281828"}, - {"void", ""}, - }; - - Consumer consumer{client.Get(), settings}; - consumer.ExpectConsume(1); - consumer.Start(); - - RawPublisher publisher{client.GetDeadline()}; - publisher.PublishReliable(client.GetExchange(), client.GetRoutingKey(), "payload-header-conversion", headers, client.GetDeadline()); - - consumer.Wait(); - const auto consumed = consumer.GetMessagesWithMetadata(); - - ASSERT_EQ(consumed.size(), 1); - EXPECT_EQ(consumed[0].message, "payload-header-conversion"); - ASSERT_EQ(consumed[0].headers.size(), expected_values.size()); - - for (const auto& [key, expected_value] : expected_values) { - ASSERT_EQ(consumed[0].headers.count(key), 1) << "Missing header: " << key; - EXPECT_EQ(consumed[0].headers.at(key), expected_value) << "Unexpected converted value for key: " << key; - } -} - USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 97248d190849..f51c46e2bce2 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -12,6 +12,7 @@ #include #include #include +#include USERVER_NAMESPACE_BEGIN @@ -21,43 +22,6 @@ namespace { constexpr std::chrono::milliseconds kStartTimeout{2000}; -std::string ToString(const AMQP::Field& field) { - const auto format = [](const auto value) { return fmt::format("{}", value); }; - - // AMQP-CPP field type codes returned by AMQP::Field::typeID(): - // string: 's'/'S', bool: 't', integers: 'b'/'B'/'U'/'u'/'I'/'i'/'L'/'l', float/double: 'f'/'d'. - switch (field.typeID()) { - case 'S': - case 's': - return static_cast(field); - case 't': - return dynamic_cast(field).value() ? "true" : "false"; - case 'B': - return format(static_cast(static_cast(field))); - case 'b': - return format(static_cast(static_cast(field))); - case 'u': - return format(static_cast(field)); - case 'U': - return format(static_cast(field)); - case 'i': - return format(static_cast(field)); - case 'I': - return format(static_cast(field)); - case 'l': - case 'T': - return format(static_cast(field)); - case 'L': - return format(static_cast(field)); - case 'f': - return format(static_cast(field)); - case 'd': - return format(static_cast(field)); - default: - return std::string(field); - } -} - } // namespace ConsumerBaseImpl::ConsumerBaseImpl(ConnectionPtr&& connection, const ConsumerSettings& settings) @@ -147,7 +111,7 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery const auto keys = headers.keys(); consumed.headers.reserve(keys.size()); for (const auto& key : keys) { - consumed.headers.emplace(key, ToString(headers.get(key))); + consumed.headers.emplace(key, impl::FieldToString(headers.get(key))); } bts_.Detach(engine::AsyncNoSpan( diff --git a/rabbitmq/src/urabbitmq/impl/field_to_string.cpp b/rabbitmq/src/urabbitmq/impl/field_to_string.cpp new file mode 100644 index 000000000000..3f5ea2216ac2 --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/field_to_string.cpp @@ -0,0 +1,51 @@ +#include "field_to_string.hpp" + +#include +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +std::string FieldToString(const AMQP::Field& field) { + const auto format = [](const auto value) { return fmt::format("{}", value); }; + + // AMQP-CPP field type codes returned by AMQP::Field::typeID(): + // string: 's'/'S', bool: 't', integers: 'b'/'B'/'U'/'u'/'I'/'i'/'L'/'l', float/double: 'f'/'d'. + switch (field.typeID()) { + case 'S': + case 's': + return static_cast(field); + case 't': + return static_cast(field).value() ? "true" : "false"; + case 'B': + return format(static_cast(static_cast(field))); + case 'b': + return format(static_cast(static_cast(field))); + case 'u': + return format(static_cast(field)); + case 'U': + return format(static_cast(field)); + case 'i': + return format(static_cast(field)); + case 'I': + return format(static_cast(field)); + case 'l': + case 'T': + return format(static_cast(field)); + case 'L': + return format(static_cast(field)); + case 'f': + return format(static_cast(field)); + case 'd': + return format(static_cast(field)); + default: + return std::string(field); + } +} + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/field_to_string.hpp b/rabbitmq/src/urabbitmq/impl/field_to_string.hpp new file mode 100644 index 000000000000..0bbb04f58c72 --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/field_to_string.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +std::string FieldToString(const AMQP::Field& field); + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END From d93c9e877b375d01f93d149376bdad78e80fe3a0 Mon Sep 17 00:00:00 2001 From: sav-da Date: Wed, 4 Mar 2026 13:23:19 +0000 Subject: [PATCH 08/10] feat: rabbitmq support integral header values in publish --- .../include/userver/urabbitmq/typedefs.hpp | 15 ++++++++- .../src/tests/publish_consume_rmqtest.cpp | 33 +++++++++++++++++-- rabbitmq/src/urabbitmq/impl/amqp_channel.cpp | 26 +++++++++++---- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index ecb6cf588618..672f668a446b 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -4,7 +4,9 @@ /// @brief Convenient typedefs for RabbitMQ entities. #include +#include #include +#include #include @@ -79,6 +81,17 @@ struct ConsumedMessage { std::unordered_map headers{}; }; +using HeaderValue = std::variant< + std::string, + std::int8_t, + std::uint8_t, + std::int16_t, + std::uint16_t, + std::int32_t, + std::uint32_t, + std::int64_t, + std::uint64_t>; + /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages from the end user, /// hiding the actual AMQP message object implementation. @@ -88,7 +101,7 @@ struct Envelope { std::optional reply_to{}; std::optional correlation_id{}; std::optional expiration{}; - std::optional> headers{}; + std::optional> headers{}; }; } // namespace urabbitmq diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 2ed810e82e40..795e809aa026 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -1,7 +1,11 @@ #include "utils_rmqtest.hpp" +#include +#include #include +#include #include +#include #include @@ -10,6 +14,7 @@ #include #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -251,7 +256,21 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { std::string name; std::optional reply_to; std::optional correlation_id; - std::unordered_map headers; + std::unordered_map headers; + }; + + const auto header_value_to_string = [](const urabbitmq::HeaderValue& value) { + return std::visit( + utils::Overloaded{ + [](const std::string& typed_value) { return typed_value; }, + [](const auto typed_value) { + using T = std::decay_t; + static_assert(std::is_integral_v, "Only integral header values are supported"); + return std::to_string(typed_value); + }, + }, + value + ); }; const std::vector cases{ @@ -262,7 +281,14 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { "corr-id", { {"x-custom-header", "custom-value"}, - {"x-custom-int", "42"}, + {"x-int8", std::numeric_limits::min()}, + {"x-uint8", std::numeric_limits::max()}, + {"x-int16", std::numeric_limits::min()}, + {"x-uint16", std::numeric_limits::max()}, + {"x-int32", std::numeric_limits::min()}, + {"x-uint32", std::numeric_limits::max()}, + {"x-int64", std::numeric_limits::min()}, + {"x-uint64", std::numeric_limits::max()}, }, }, { @@ -327,7 +353,8 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { for (const auto& [header_key, header_value] : case_data.headers) { ASSERT_EQ(msg.headers.count(header_key), 1) << "Missing header '" << header_key << "' in " << payload; const auto& actual = msg.headers.at(header_key); - EXPECT_NE(actual.find(header_value), std::string::npos) + const auto expected = header_value_to_string(header_value); + EXPECT_NE(actual.find(expected), std::string::npos) << "Unexpected value for header '" << header_key << "' in " << payload << ": " << actual; } diff --git a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp index d6aa45d5b2ef..4cb1dc1a0119 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp @@ -1,9 +1,11 @@ #include "amqp_channel.hpp" #include +#include #include #include +#include #include #include @@ -98,7 +100,21 @@ AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { auto headers = CreateHeaders(); if (envelope.headers.has_value()) { for (const auto& [key, value] : envelope.headers.value()) { - headers[key] = value; + std::visit( + utils::Overloaded{ + [&headers, &key](const std::string& typed_value) { headers[key] = typed_value; }, + [&headers, &key](const auto typed_value) { + using T = std::decay_t; + static_assert(std::is_integral_v, "Only integral header values are supported"); + if constexpr (std::is_signed_v) { + headers[key] = static_cast(typed_value); + } else { + headers[key] = static_cast(typed_value); + } + }, + }, + value + ); } } @@ -107,9 +123,7 @@ AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { } // namespace -AmqpChannel::AmqpChannel(AmqpConnection& conn) - : conn_{conn} -{} +AmqpChannel::AmqpChannel(AmqpConnection& conn) : conn_{conn} {} AmqpChannel::~AmqpChannel() = default; @@ -273,9 +287,7 @@ void AmqpChannel::CancelConsumer(const std::optional& consumer_tag) void AmqpChannel::AccountMessageConsumed() { conn_.GetStatistics().AccountMessageConsumed(); } -AmqpReliableChannel::AmqpReliableChannel(AmqpConnection& conn) - : conn_{conn} -{} +AmqpReliableChannel::AmqpReliableChannel(AmqpConnection& conn) : conn_{conn} {} AmqpReliableChannel::~AmqpReliableChannel() = default; From 329cd44ec0e0f3dc8a30852376cee5b9cc6e9261 Mon Sep 17 00:00:00 2001 From: sav-da Date: Thu, 19 Mar 2026 06:28:04 +0000 Subject: [PATCH 09/10] feat: improved rabbitmq headers --- .../include/userver/urabbitmq/typedefs.hpp | 21 +-- .../src/tests/field_to_string_rmqtest.cpp | 58 ------ rabbitmq/src/tests/header_value_rmqtest.cpp | 115 ++++++++++++ .../src/tests/publish_consume_rmqtest.cpp | 97 +++++----- rabbitmq/src/urabbitmq/consumer_base_impl.cpp | 8 +- rabbitmq/src/urabbitmq/impl/amqp_channel.cpp | 21 +-- .../src/urabbitmq/impl/field_to_string.cpp | 51 ------ .../src/urabbitmq/impl/field_to_string.hpp | 15 -- rabbitmq/src/urabbitmq/impl/header_value.cpp | 168 ++++++++++++++++++ rabbitmq/src/urabbitmq/impl/header_value.hpp | 22 +++ 10 files changed, 368 insertions(+), 208 deletions(-) delete mode 100644 rabbitmq/src/tests/field_to_string_rmqtest.cpp create mode 100644 rabbitmq/src/tests/header_value_rmqtest.cpp delete mode 100644 rabbitmq/src/urabbitmq/impl/field_to_string.cpp delete mode 100644 rabbitmq/src/urabbitmq/impl/field_to_string.hpp create mode 100644 rabbitmq/src/urabbitmq/impl/header_value.cpp create mode 100644 rabbitmq/src/urabbitmq/impl/header_value.hpp diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index 672f668a446b..0339868d5a42 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -5,9 +5,11 @@ #include #include +#include +#include #include -#include +#include #include USERVER_NAMESPACE_BEGIN @@ -66,6 +68,10 @@ enum class MessageType { kTransient, }; +/// JSON-like representation of an AMQP header value. +/// This is not JSON, but a convenient tree representation for AMQP field values. +using HeaderValue = formats::json::Value; + /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages to the end user, /// hiding the actual AMQP message object implementation. @@ -78,20 +84,9 @@ struct ConsumedMessage { Metadata metadata; std::optional reply_to{}; std::optional correlation_id{}; - std::unordered_map headers{}; + std::unordered_map headers{}; }; -using HeaderValue = std::variant< - std::string, - std::int8_t, - std::uint8_t, - std::int16_t, - std::uint16_t, - std::int32_t, - std::uint32_t, - std::int64_t, - std::uint64_t>; - /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages from the end user, /// hiding the actual AMQP message object implementation. diff --git a/rabbitmq/src/tests/field_to_string_rmqtest.cpp b/rabbitmq/src/tests/field_to_string_rmqtest.cpp deleted file mode 100644 index 5d0e68b8a669..000000000000 --- a/rabbitmq/src/tests/field_to_string_rmqtest.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include - -#include -#include -#include -#include - -#include - -#include - -USERVER_NAMESPACE_BEGIN - -UTEST(FieldToString, BasicTypes) { - AMQP::Table headers; - headers.set("string", "value"); - headers.set("empty-string", ""); - headers.set("bool-true", true); - headers.set("bool-false", false); - headers.set("uint8", static_cast(255)); - headers.set("int8", static_cast(-100)); - headers.set("uint16", static_cast(65000)); - headers.set("int16", static_cast(-30000)); - headers.set("uint32", std::numeric_limits::max()); - headers.set("int32", std::numeric_limits::min()); - headers.set("uint64", std::numeric_limits::max()); - headers.set("int64", std::numeric_limits::min()); - headers.set("float", AMQP::Float{3.14f}); - headers.set("double", AMQP::Double{2.718281828}); - headers.set("void", nullptr); - - const std::unordered_map expected_values{ - {"string", "value"}, - {"empty-string", ""}, - {"bool-true", "true"}, - {"bool-false", "false"}, - {"uint8", "255"}, - {"int8", "-100"}, - {"uint16", "65000"}, - {"int16", "-30000"}, - {"uint32", "4294967295"}, - {"int32", "-2147483648"}, - {"uint64", "18446744073709551615"}, - {"int64", "-9223372036854775808"}, - {"float", "3.14"}, - {"double", "2.718281828"}, - {"void", ""}, - }; - - ASSERT_EQ(headers.keys().size(), expected_values.size()); - for (const auto& [key, expected_value] : expected_values) { - ASSERT_TRUE(headers.contains(key)) << "Missing header: " << key; - EXPECT_EQ(urabbitmq::impl::FieldToString(headers.get(key)), expected_value) - << "Unexpected converted value for key: " << key; - } -} - -USERVER_NAMESPACE_END diff --git a/rabbitmq/src/tests/header_value_rmqtest.cpp b/rabbitmq/src/tests/header_value_rmqtest.cpp new file mode 100644 index 000000000000..3912ad33ba69 --- /dev/null +++ b/rabbitmq/src/tests/header_value_rmqtest.cpp @@ -0,0 +1,115 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace { + +template +formats::json::Value MakeHeaderValue(T&& value) { + return formats::json::ValueBuilder{std::forward(value)}.ExtractValue(); +} + +formats::json::Value MakeNestedArrayValue() { + formats::json::ValueBuilder builder{formats::common::Type::kArray}; + builder.PushBack(std::int64_t{-7}); + builder.PushBack("array-value"); + + formats::json::ValueBuilder nested_object{formats::common::Type::kObject}; + nested_object["enabled"] = false; + nested_object["nullable"] = formats::json::ValueBuilder{}; + builder.PushBack(std::move(nested_object)); + + return builder.ExtractValue(); +} + +formats::json::Value MakeNestedObjectValue() { + formats::json::ValueBuilder builder{formats::common::Type::kObject}; + builder["count"] = std::uint64_t{42}; + builder["name"] = "nested-object"; + builder["array"] = formats::json::ValueBuilder{MakeNestedArrayValue()}; + + return builder.ExtractValue(); +} + +void ExpectHeadersEqual( + const std::unordered_map& actual, + const std::unordered_map& expected +) { + ASSERT_EQ(actual.size(), expected.size()); + for (const auto& [key, expected_value] : expected) { + const auto it = actual.find(key); + ASSERT_NE(it, actual.end()) << "Missing key: " << key; + EXPECT_EQ(it->second, expected_value) << "Unexpected value for key: " << key; + } +} + +} // namespace + +UTEST(HeaderValue, ConvertsNestedAmqpTypes) { + AMQP::Table headers; + headers.set("string", "value"); + headers.set("bool", true); + headers.set("signed", static_cast(-10)); + headers.set("unsigned", static_cast(10)); + headers.set("double", AMQP::Double{1.5}); + headers.set("null", nullptr); + + AMQP::Array nested_array; + nested_array.push_back(AMQP::LongLong{-7}); + nested_array.push_back(AMQP::LongString{"array-value"}); + AMQP::Table nested_array_object; + nested_array_object.set("enabled", false); + nested_array_object.set("nullable", nullptr); + nested_array.push_back(nested_array_object); + headers.set("array", nested_array); + + AMQP::Table nested_object; + nested_object.set("count", static_cast(42)); + nested_object.set("name", "nested-object"); + nested_object.set("array", nested_array); + headers.set("object", nested_object); + + const std::unordered_map expected{ + {"string", MakeHeaderValue("value")}, + {"bool", MakeHeaderValue(true)}, + {"signed", MakeHeaderValue(std::int64_t{-10})}, + {"unsigned", MakeHeaderValue(std::uint64_t{10})}, + {"double", MakeHeaderValue(1.5)}, + {"null", formats::json::ValueBuilder{}.ExtractValue()}, + {"array", MakeNestedArrayValue()}, + {"object", MakeNestedObjectValue()}, + }; + + ExpectHeadersEqual(urabbitmq::impl::TableToHeaders(headers), expected); +} + +UTEST(HeaderValue, RoundTripsHeaders) { + const std::unordered_map expected{ + {"string", MakeHeaderValue("value")}, + {"bool", MakeHeaderValue(false)}, + {"signed", MakeHeaderValue(std::int64_t{-123456789})}, + {"unsigned", MakeHeaderValue(std::uint64_t{123456789})}, + {"double", MakeHeaderValue(3.25)}, + {"null", formats::json::ValueBuilder{}.ExtractValue()}, + {"array", MakeNestedArrayValue()}, + {"object", MakeNestedObjectValue()}, + }; + + AMQP::Table table; + urabbitmq::impl::AddHeadersToTable(table, expected); + + ExpectHeadersEqual(urabbitmq::impl::TableToHeaders(table), expected); +} + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 795e809aa026..f9d9be91b9ad 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -1,26 +1,50 @@ #include "utils_rmqtest.hpp" #include -#include #include -#include #include -#include - -#include +#include #include #include #include #include #include -#include +#include +#include #include USERVER_NAMESPACE_BEGIN namespace { +template +formats::json::Value MakeHeaderValue(T&& value) { + return formats::json::ValueBuilder{std::forward(value)}.ExtractValue(); +} + +formats::json::Value MakeNestedArrayValue() { + formats::json::ValueBuilder builder{formats::common::Type::kArray}; + builder.PushBack(std::int64_t{-7}); + builder.PushBack("array-value"); + + formats::json::ValueBuilder nested_object{formats::common::Type::kObject}; + nested_object["enabled"] = false; + nested_object["nullable"] = formats::json::ValueBuilder{}; + builder.PushBack(std::move(nested_object)); + + return builder.ExtractValue(); +} + +formats::json::Value MakeNestedObjectValue() { + formats::json::ValueBuilder builder{formats::common::Type::kObject}; + builder["count"] = std::uint64_t{42}; + builder["name"] = "nested-object"; + builder["array"] = formats::json::ValueBuilder{MakeNestedArrayValue()}; + + return builder.ExtractValue(); +} + class Consumer final : public urabbitmq::ConsumerBase { public: using urabbitmq::ConsumerBase::ConsumerBase; @@ -259,47 +283,28 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { std::unordered_map headers; }; - const auto header_value_to_string = [](const urabbitmq::HeaderValue& value) { - return std::visit( - utils::Overloaded{ - [](const std::string& typed_value) { return typed_value; }, - [](const auto typed_value) { - using T = std::decay_t; - static_assert(std::is_integral_v, "Only integral header values are supported"); - return std::to_string(typed_value); - }, - }, - value - ); - }; - const std::vector cases{ {"no-user-headers", std::nullopt, std::nullopt, {}}, { - "simple-user-headers", + "scalar-user-headers", "reply-queue", "corr-id", { - {"x-custom-header", "custom-value"}, - {"x-int8", std::numeric_limits::min()}, - {"x-uint8", std::numeric_limits::max()}, - {"x-int16", std::numeric_limits::min()}, - {"x-uint16", std::numeric_limits::max()}, - {"x-int32", std::numeric_limits::min()}, - {"x-uint32", std::numeric_limits::max()}, - {"x-int64", std::numeric_limits::min()}, - {"x-uint64", std::numeric_limits::max()}, + {"x-custom-header", MakeHeaderValue("custom-value")}, + {"x-bool", MakeHeaderValue(true)}, + {"x-int64", MakeHeaderValue(std::int64_t{-10})}, + {"x-uint64", MakeHeaderValue(std::uint64_t{10})}, + {"x-double", MakeHeaderValue(2.5)}, + {"x-null", formats::json::ValueBuilder{}.ExtractValue()}, }, }, { - "many-user-headers", - "reply-many", - "corr-many", + "nested-user-headers", + "reply-nested", + "corr-nested", { - {"x-empty", ""}, - {"x-spaces", "a b c"}, - {"x-symbols", R"(!@#$%^&*()[]{}<>/?\\|;:'\",.~-_=+)"}, - {"x-long", std::string(128, 'x')}, + {"x-array", MakeNestedArrayValue()}, + {"x-object", MakeNestedObjectValue()}, }, }, { @@ -307,9 +312,9 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { "reply-override", "corr-override", { - {"u-trace-id", "trace-from-user"}, - {"u-parent-span-id", "parent-from-user"}, - {"x-another", "value"}, + {"u-trace-id", MakeHeaderValue("trace-from-user")}, + {"u-parent-span-id", MakeHeaderValue("parent-from-user")}, + {"x-another", MakeHeaderValue("value")}, }, }, }; @@ -352,16 +357,16 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { for (const auto& [header_key, header_value] : case_data.headers) { ASSERT_EQ(msg.headers.count(header_key), 1) << "Missing header '" << header_key << "' in " << payload; - const auto& actual = msg.headers.at(header_key); - const auto expected = header_value_to_string(header_value); - EXPECT_NE(actual.find(expected), std::string::npos) - << "Unexpected value for header '" << header_key << "' in " << payload << ": " << actual; + EXPECT_EQ(msg.headers.at(header_key), header_value) + << "Unexpected value for header '" << header_key << "' in " << payload; } ASSERT_EQ(msg.headers.count("u-trace-id"), 1) << "Missing u-trace-id in " << payload; ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) << "Missing u-parent-span-id in " << payload; - EXPECT_FALSE(msg.headers.at("u-trace-id").empty()); - EXPECT_FALSE(msg.headers.at("u-parent-span-id").empty()); + ASSERT_TRUE(msg.headers.at("u-trace-id").IsString()); + ASSERT_TRUE(msg.headers.at("u-parent-span-id").IsString()); + EXPECT_FALSE(msg.headers.at("u-trace-id").As().empty()); + EXPECT_FALSE(msg.headers.at("u-parent-span-id").As().empty()); } } diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index f51c46e2bce2..dc40e20ae969 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include USERVER_NAMESPACE_BEGIN @@ -108,11 +108,7 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery consumed.correlation_id = message.correlationID(); } - const auto keys = headers.keys(); - consumed.headers.reserve(keys.size()); - for (const auto& key : keys) { - consumed.headers.emplace(key, impl::FieldToString(headers.get(key))); - } + consumed.headers = impl::TableToHeaders(headers); bts_.Detach(engine::AsyncNoSpan( dispatcher_, diff --git a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp index 4cb1dc1a0119..69ce974ed119 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp @@ -1,14 +1,13 @@ #include "amqp_channel.hpp" #include -#include #include #include -#include #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -99,23 +98,7 @@ AMQP::Table CreateHeaders() { AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { auto headers = CreateHeaders(); if (envelope.headers.has_value()) { - for (const auto& [key, value] : envelope.headers.value()) { - std::visit( - utils::Overloaded{ - [&headers, &key](const std::string& typed_value) { headers[key] = typed_value; }, - [&headers, &key](const auto typed_value) { - using T = std::decay_t; - static_assert(std::is_integral_v, "Only integral header values are supported"); - if constexpr (std::is_signed_v) { - headers[key] = static_cast(typed_value); - } else { - headers[key] = static_cast(typed_value); - } - }, - }, - value - ); - } + AddHeadersToTable(headers, *envelope.headers); } return headers; diff --git a/rabbitmq/src/urabbitmq/impl/field_to_string.cpp b/rabbitmq/src/urabbitmq/impl/field_to_string.cpp deleted file mode 100644 index 3f5ea2216ac2..000000000000 --- a/rabbitmq/src/urabbitmq/impl/field_to_string.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "field_to_string.hpp" - -#include -#include - -#include - -USERVER_NAMESPACE_BEGIN - -namespace urabbitmq::impl { - -std::string FieldToString(const AMQP::Field& field) { - const auto format = [](const auto value) { return fmt::format("{}", value); }; - - // AMQP-CPP field type codes returned by AMQP::Field::typeID(): - // string: 's'/'S', bool: 't', integers: 'b'/'B'/'U'/'u'/'I'/'i'/'L'/'l', float/double: 'f'/'d'. - switch (field.typeID()) { - case 'S': - case 's': - return static_cast(field); - case 't': - return static_cast(field).value() ? "true" : "false"; - case 'B': - return format(static_cast(static_cast(field))); - case 'b': - return format(static_cast(static_cast(field))); - case 'u': - return format(static_cast(field)); - case 'U': - return format(static_cast(field)); - case 'i': - return format(static_cast(field)); - case 'I': - return format(static_cast(field)); - case 'l': - case 'T': - return format(static_cast(field)); - case 'L': - return format(static_cast(field)); - case 'f': - return format(static_cast(field)); - case 'd': - return format(static_cast(field)); - default: - return std::string(field); - } -} - -} // namespace urabbitmq::impl - -USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/field_to_string.hpp b/rabbitmq/src/urabbitmq/impl/field_to_string.hpp deleted file mode 100644 index 0bbb04f58c72..000000000000 --- a/rabbitmq/src/urabbitmq/impl/field_to_string.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include - -#include - -USERVER_NAMESPACE_BEGIN - -namespace urabbitmq::impl { - -std::string FieldToString(const AMQP::Field& field); - -} // namespace urabbitmq::impl - -USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/header_value.cpp b/rabbitmq/src/urabbitmq/impl/header_value.cpp new file mode 100644 index 000000000000..440627809466 --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/header_value.cpp @@ -0,0 +1,168 @@ +#include "header_value.hpp" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +namespace { + +formats::json::Value ToJsonValue(const AMQP::Array& array); +formats::json::Value ToJsonValue(const AMQP::Table& table); +std::unique_ptr ToAmqpField(const HeaderValue& value); + +[[noreturn]] void ThrowUnsupportedAmqpField(char type_id) { + throw std::runtime_error{fmt::format("Unsupported AMQP header field type '{}'", type_id)}; +} + +[[noreturn]] void ThrowUnsupportedHeaderValue(const HeaderValue& value) { + throw std::runtime_error{fmt::format("Unsupported RabbitMQ header value at '{}'", value.GetPath())}; +} + +formats::json::Value ToJsonValue(const AMQP::Array& array) { + formats::json::ValueBuilder builder{formats::common::Type::kArray}; + const auto count = array.count(); + for (std::uint32_t index = 0; index < count; ++index) { + builder.PushBack(formats::json::ValueBuilder{FieldToHeaderValue(array[static_cast(index)])}); + } + + return builder.ExtractValue(); +} + +formats::json::Value ToJsonValue(const AMQP::Table& table) { + formats::json::ValueBuilder builder{formats::common::Type::kObject}; + for (const auto& key : table.keys()) { + builder.EmplaceNocheck(key, formats::json::ValueBuilder{FieldToHeaderValue(table[key])}); + } + + return builder.ExtractValue(); +} + +AMQP::Array ToAmqpArray(const HeaderValue& value) { + AMQP::Array array; + for (const auto& item : value) { + array.push_back(*ToAmqpField(item)); + } + + return array; +} + +AMQP::Table ToAmqpTable(const HeaderValue& value) { + AMQP::Table table; + for (const auto& [key, item] : formats::common::Items(value)) { + table.set(key, *ToAmqpField(item)); + } + + return table; +} + +std::unique_ptr ToAmqpField(const HeaderValue& value) { + if (value.IsNull()) { + return std::make_unique(); + } + if (value.IsBool()) { + return std::make_unique(value.As()); + } + if (value.IsInt() || value.IsInt64()) { + return std::make_unique(value.As()); + } + if (value.IsUInt() || value.IsUInt64()) { + return std::make_unique(value.As()); + } + if (value.IsDouble()) { + return std::make_unique(value.As()); + } + if (value.IsString()) { + return std::make_unique(value.As()); + } + if (value.IsArray()) { + return std::make_unique(ToAmqpArray(value)); + } + if (value.IsObject()) { + return std::make_unique(ToAmqpTable(value)); + } + + ThrowUnsupportedHeaderValue(value); +} + +} // namespace + +HeaderValue FieldToHeaderValue(const AMQP::Field& field) { + switch (field.typeID()) { + case 'S': + case 's': + return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + case 't': + return formats::json::ValueBuilder{static_cast(field).value() != 0}.ExtractValue(); + case 'B': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'b': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'u': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'U': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'i': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'I': + return formats::json::ValueBuilder{static_cast(static_cast(field))} + .ExtractValue(); + case 'l': + case 'T': + return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + case 'L': + return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + case 'f': + return formats::json::ValueBuilder{static_cast(static_cast(field))}.ExtractValue(); + case 'd': + case 'D': + return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + case 'A': + return ToJsonValue(static_cast(field)); + case 'F': + return ToJsonValue(static_cast(field)); + case 'V': + return formats::json::ValueBuilder{}.ExtractValue(); + } + + ThrowUnsupportedAmqpField(field.typeID()); +} + +std::unordered_map TableToHeaders(const AMQP::Table& table) { + const auto keys = table.keys(); + + std::unordered_map headers; + headers.reserve(keys.size()); + + for (const auto& key : keys) { + headers.emplace(key, FieldToHeaderValue(table[key])); + } + + return headers; +} + +void AddHeadersToTable(AMQP::Table& table, const std::unordered_map& headers) { + for (const auto& [key, value] : headers) { + table.set(key, *ToAmqpField(value)); + } +} + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/header_value.hpp b/rabbitmq/src/urabbitmq/impl/header_value.hpp new file mode 100644 index 000000000000..abea1f2f04f6 --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/header_value.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +HeaderValue FieldToHeaderValue(const AMQP::Field& field); + +std::unordered_map TableToHeaders(const AMQP::Table& table); + +void AddHeadersToTable(AMQP::Table& table, const std::unordered_map& headers); + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END From fa7151cca3489115cd760cb650362d1edb366898 Mon Sep 17 00:00:00 2001 From: sav-da Date: Wed, 25 Mar 2026 06:26:59 +0000 Subject: [PATCH 10/10] fix: fixes after review --- .../basic_chaos/rabbitmq_service.cpp | 27 ++-- .../basic_chaos/static_config.yaml | 1 + .../basic_chaos/tests/test_rabbitmq.py | 88 ++++++++++- rabbitmq/src/tests/header_value_rmqtest.cpp | 44 +++--- .../src/tests/publish_consume_rmqtest.cpp | 22 +-- .../impl/amqp_connection_handler.cpp | 28 ++-- .../impl/amqp_connection_handler.hpp | 10 +- rabbitmq/src/urabbitmq/impl/header_value.cpp | 142 +++++++++--------- rabbitmq/src/urabbitmq/impl/header_value.hpp | 2 +- 9 files changed, 231 insertions(+), 133 deletions(-) diff --git a/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp b/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp index 32d00633ba41..b1faaae3473d 100644 --- a/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp +++ b/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -26,8 +28,7 @@ class ChaosProducer final : public components::LoggableComponentBase { ChaosProducer(const components::ComponentConfig& config, const components::ComponentContext& context) : components::LoggableComponentBase{config, context}, - rabbit_client_{context.FindComponent("chaos-rabbit").GetClient()} - { + rabbit_client_{context.FindComponent("chaos-rabbit").GetClient()} { const auto setup_deadline = engine::Deadline::FromDuration(kDefaultOperationTimeout); auto admin_channel = rabbit_client_->GetAdminChannel(setup_deadline); @@ -77,9 +78,7 @@ class ChaosConsumer final : public components::ComponentBase { static constexpr std::string_view kName{"chaos-consumer"}; ChaosConsumer(const components::ComponentConfig& config, const components::ComponentContext& context) - : components::ComponentBase{config, context}, - consumer_{config, context, messages_} - { + : components::ComponentBase{config, context}, consumer_{config, context, messages_} { Start(); } @@ -119,8 +118,7 @@ class ChaosConsumer final : public components::ComponentBase { ) : urabbitmq:: ConsumerBase{context.FindComponent(config["rabbit_name"].As()).GetClient(), ParseSettings(config)}, - messages_{messages} - {} + messages_{messages} {} protected: void Process(urabbitmq::ConsumedMessage msg) override { @@ -150,8 +148,7 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { ChaosHandler(const components::ComponentConfig& config, const components::ComponentContext& context) : server::handlers::HttpHandlerBase{config, context}, producer_{context.FindComponent()}, - consumer_{context.FindComponent()} - {} + consumer_{context.FindComponent()} {} std::string HandleRequestThrow(const server::http::HttpRequest& request, server::request::RequestContext&) const override { @@ -178,6 +175,13 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { throw server::handlers::ClientError{server::handlers::ExternalBody{"No 'message' query argument"}}; } urabbitmq::Envelope envelope{message, urabbitmq::MessageType::kTransient, {}, {}, {}}; + if (!request.RequestBody().empty()) { + const auto request_json = formats::json::FromString(request.RequestBody()); + if (request_json.HasMember("headers")) { + envelope + .headers = request_json["headers"].As>(); + } + } const auto& correlation_id = request.GetArg("correlation_id"); if (!correlation_id.empty()) { envelope.correlation_id = correlation_id; @@ -219,9 +223,9 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { } std::string HandleGet() const { - formats::json::ValueBuilder messages_builder; + urabbitmq::HeaderValue::Builder messages_builder; for (const auto& item : consumer_.GetMessages()) { - formats::json::ValueBuilder item_builder; + urabbitmq::HeaderValue::Builder item_builder; item_builder["message"] = item.message; if (item.correlation_id.has_value()) { item_builder["correlation_id"] = item.correlation_id; @@ -229,6 +233,7 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { if (item.reply_to.has_value()) { item_builder["reply_to"] = item.reply_to; } + item_builder["headers"] = item.headers; messages_builder.PushBack(std::move(item_builder)); } return formats::json::ToString(messages_builder.ExtractValue()); diff --git a/rabbitmq/functional_tests/basic_chaos/static_config.yaml b/rabbitmq/functional_tests/basic_chaos/static_config.yaml index e509f78aa359..3da63680e54d 100644 --- a/rabbitmq/functional_tests/basic_chaos/static_config.yaml +++ b/rabbitmq/functional_tests/basic_chaos/static_config.yaml @@ -17,6 +17,7 @@ components_manager: min_pool_size: 1 max_pool_size: 1 max_in_flight_requests: 5 + heartbeat_interval_seconds: 1 use_secure_connection: false secdist: {} # Component that stores configuration of hosts and passwords diff --git a/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py b/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py index 9718e0d35e39..1db29a2eb3cb 100644 --- a/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py +++ b/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py @@ -37,6 +37,10 @@ async def _clear_messages(service_client): assert response.status_code == 200 +def _strip_headers(messages): + return [{key: value for key, value in message.items() if key != 'headers'} for message in messages] + + async def _publish_and_consume(testpoint, client): @testpoint('message_consumed') def message_consumed(data): @@ -51,7 +55,7 @@ def message_consumed(data): response = await client.get('/v1/messages') assert response.status_code == 200 - assert response.json() == MESSAGES + assert _strip_headers(response.json()) == MESSAGES async def test_rabbitmq_happy(testpoint, service_client, gate): @@ -60,6 +64,55 @@ async def test_rabbitmq_happy(testpoint, service_client, gate): await _publish_and_consume(testpoint, service_client) +async def test_rabbitmq_headers(testpoint, service_client, gate): + await _clear_messages(service_client) + + @testpoint('message_consumed') + def message_consumed(data): + pass + + expected_headers = { + 'x-bool': True, + 'x-int': -10, + 'x-uint': 10, + 'x-double': 2.5, + 'x-array': [-7, 'array-value', {'enabled': False, 'nullable': None}], + 'x-object': { + 'count': 42, + 'name': 'nested-object', + 'array': [-7, 'array-value', {'enabled': False, 'nullable': None}], + }, + 'x-null': None, + } + + response = await service_client.post( + '/v1/messages?message=headers&reliable=1&reply_to=reply&correlation_id=corr-id', + json={'headers': expected_headers}, + ) + assert response.status_code == 200 + + await message_consumed.wait_call() + + response = await service_client.get('/v1/messages') + assert response.status_code == 200 + messages = response.json() + assert len(messages) == 1 + + consumed = messages[0] + assert consumed['message'] == 'headers' + assert consumed['reply_to'] == 'reply' + assert consumed['correlation_id'] == 'corr-id' + assert consumed['headers']['x-bool'] is True + assert consumed['headers']['x-int'] == -10 + assert consumed['headers']['x-uint'] == 10 + assert consumed['headers']['x-double'] == 2.5 + assert consumed['headers']['x-array'] == expected_headers['x-array'] + assert consumed['headers']['x-object'] == expected_headers['x-object'] + assert consumed['headers']['x-null'] is None + assert consumed['headers']['u-trace-id'] + assert consumed['headers']['u-parent-span-id'] + + @pytest.mark.skip(reason='std::terminate is called, fix in TAXICOMMON-6995') async def test_consumer_reconnects(testpoint, service_client, gate): await _clear_messages(service_client) @@ -85,4 +138,35 @@ def message_consumed(data): response = await service_client.get('/v1/messages') assert response.status_code == 200 - assert response.json() == MESSAGES + assert _strip_headers(response.json()) == MESSAGES + + +async def test_rabbitmq_heartbeat_reconnects(testpoint, service_client, gate): + await _clear_messages(service_client) + + @testpoint('message_consumed') + def message_consumed(data): + pass + + response = await service_client.post('/v1/messages?message=before-heartbeat') + assert response.status_code == 200 + await message_consumed.wait_call() + + await gate.to_server_noop() + await gate.to_client_noop() + await asyncio.sleep(2.5) + await gate.to_server_pass() + await gate.to_client_pass() + await gate.sockets_close() + + await gate.wait_for_connections(timeout=10.0) + await asyncio.sleep(1.0) + + response = await service_client.post('/v1/messages?message=after-heartbeat') + assert response.status_code == 200 + + await message_consumed.wait_call() + + response = await service_client.get('/v1/messages') + assert response.status_code == 200 + assert any(message['message'] == 'after-heartbeat' for message in response.json()) diff --git a/rabbitmq/src/tests/header_value_rmqtest.cpp b/rabbitmq/src/tests/header_value_rmqtest.cpp index 3912ad33ba69..bddb480443d2 100644 --- a/rabbitmq/src/tests/header_value_rmqtest.cpp +++ b/rabbitmq/src/tests/header_value_rmqtest.cpp @@ -16,28 +16,28 @@ USERVER_NAMESPACE_BEGIN namespace { template -formats::json::Value MakeHeaderValue(T&& value) { - return formats::json::ValueBuilder{std::forward(value)}.ExtractValue(); +urabbitmq::HeaderValue MakeHeaderValue(T&& value) { + return urabbitmq::HeaderValue::Builder{std::forward(value)}.ExtractValue(); } -formats::json::Value MakeNestedArrayValue() { - formats::json::ValueBuilder builder{formats::common::Type::kArray}; +urabbitmq::HeaderValue MakeNestedArrayValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kArray}; builder.PushBack(std::int64_t{-7}); builder.PushBack("array-value"); - formats::json::ValueBuilder nested_object{formats::common::Type::kObject}; + urabbitmq::HeaderValue::Builder nested_object{formats::common::Type::kObject}; nested_object["enabled"] = false; - nested_object["nullable"] = formats::json::ValueBuilder{}; + nested_object["nullable"] = urabbitmq::HeaderValue::Builder{}; builder.PushBack(std::move(nested_object)); return builder.ExtractValue(); } -formats::json::Value MakeNestedObjectValue() { - formats::json::ValueBuilder builder{formats::common::Type::kObject}; +urabbitmq::HeaderValue MakeNestedObjectValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kObject}; builder["count"] = std::uint64_t{42}; builder["name"] = "nested-object"; - builder["array"] = formats::json::ValueBuilder{MakeNestedArrayValue()}; + builder["array"] = urabbitmq::HeaderValue::Builder{MakeNestedArrayValue()}; return builder.ExtractValue(); } @@ -83,25 +83,30 @@ UTEST(HeaderValue, ConvertsNestedAmqpTypes) { const std::unordered_map expected{ {"string", MakeHeaderValue("value")}, {"bool", MakeHeaderValue(true)}, - {"signed", MakeHeaderValue(std::int64_t{-10})}, - {"unsigned", MakeHeaderValue(std::uint64_t{10})}, + {"signed", MakeHeaderValue(-10)}, + {"unsigned", MakeHeaderValue(10u)}, {"double", MakeHeaderValue(1.5)}, - {"null", formats::json::ValueBuilder{}.ExtractValue()}, + {"null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, {"array", MakeNestedArrayValue()}, {"object", MakeNestedObjectValue()}, }; - ExpectHeadersEqual(urabbitmq::impl::TableToHeaders(headers), expected); + const auto actual = urabbitmq::impl::TableToHeaders(headers); + ExpectHeadersEqual(actual, expected); + EXPECT_TRUE(actual.at("signed").IsInt()); + EXPECT_TRUE(actual.at("unsigned").IsUInt()); } UTEST(HeaderValue, RoundTripsHeaders) { const std::unordered_map expected{ {"string", MakeHeaderValue("value")}, {"bool", MakeHeaderValue(false)}, - {"signed", MakeHeaderValue(std::int64_t{-123456789})}, - {"unsigned", MakeHeaderValue(std::uint64_t{123456789})}, + {"signed", MakeHeaderValue(-123456789)}, + {"signed64", MakeHeaderValue(std::int64_t{-1234567890123})}, + {"unsigned", MakeHeaderValue(123456789u)}, + {"unsigned64", MakeHeaderValue(std::uint64_t{1234567890123})}, {"double", MakeHeaderValue(3.25)}, - {"null", formats::json::ValueBuilder{}.ExtractValue()}, + {"null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, {"array", MakeNestedArrayValue()}, {"object", MakeNestedObjectValue()}, }; @@ -109,7 +114,12 @@ UTEST(HeaderValue, RoundTripsHeaders) { AMQP::Table table; urabbitmq::impl::AddHeadersToTable(table, expected); - ExpectHeadersEqual(urabbitmq::impl::TableToHeaders(table), expected); + const auto actual = urabbitmq::impl::TableToHeaders(table); + ExpectHeadersEqual(actual, expected); + EXPECT_TRUE(actual.at("signed").IsInt()); + EXPECT_TRUE(actual.at("signed64").IsInt64()); + EXPECT_TRUE(actual.at("unsigned").IsUInt()); + EXPECT_TRUE(actual.at("unsigned64").IsUInt64()); } USERVER_NAMESPACE_END diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index f9d9be91b9ad..e40399878898 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -19,28 +19,28 @@ USERVER_NAMESPACE_BEGIN namespace { template -formats::json::Value MakeHeaderValue(T&& value) { - return formats::json::ValueBuilder{std::forward(value)}.ExtractValue(); +urabbitmq::HeaderValue MakeHeaderValue(T&& value) { + return urabbitmq::HeaderValue::Builder{std::forward(value)}.ExtractValue(); } -formats::json::Value MakeNestedArrayValue() { - formats::json::ValueBuilder builder{formats::common::Type::kArray}; +urabbitmq::HeaderValue MakeNestedArrayValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kArray}; builder.PushBack(std::int64_t{-7}); builder.PushBack("array-value"); - formats::json::ValueBuilder nested_object{formats::common::Type::kObject}; + urabbitmq::HeaderValue::Builder nested_object{formats::common::Type::kObject}; nested_object["enabled"] = false; - nested_object["nullable"] = formats::json::ValueBuilder{}; + nested_object["nullable"] = urabbitmq::HeaderValue::Builder{}; builder.PushBack(std::move(nested_object)); return builder.ExtractValue(); } -formats::json::Value MakeNestedObjectValue() { - formats::json::ValueBuilder builder{formats::common::Type::kObject}; +urabbitmq::HeaderValue MakeNestedObjectValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kObject}; builder["count"] = std::uint64_t{42}; builder["name"] = "nested-object"; - builder["array"] = formats::json::ValueBuilder{MakeNestedArrayValue()}; + builder["array"] = urabbitmq::HeaderValue::Builder{MakeNestedArrayValue()}; return builder.ExtractValue(); } @@ -292,10 +292,12 @@ UTEST(Consumer, ConsumeMetadataAndHeadersWork) { { {"x-custom-header", MakeHeaderValue("custom-value")}, {"x-bool", MakeHeaderValue(true)}, + {"x-int", MakeHeaderValue(-10)}, {"x-int64", MakeHeaderValue(std::int64_t{-10})}, + {"x-uint", MakeHeaderValue(10u)}, {"x-uint64", MakeHeaderValue(std::uint64_t{10})}, {"x-double", MakeHeaderValue(2.5)}, - {"x-null", formats::json::ValueBuilder{}.ExtractValue()}, + {"x-null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, }, }, { diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index 96a7a0db6c2a..81975639bd33 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -89,6 +89,10 @@ AMQP::Address ToAmqpAddress(const EndpointInfo& endpoint, const AuthSettings& se return {endpoint.host, endpoint.port, AMQP::Login{settings.login, settings.password}, settings.vhost, secure}; } +std::chrono::milliseconds HalfInterval(std::uint16_t interval_seconds) { + return std::chrono::duration_cast(std::chrono::seconds{interval_seconds} / 2.0); +} + } // namespace AmqpConnectionHandler::AmqpConnectionHandler( @@ -118,14 +122,14 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, client["information"] = "https://userver.tech/dd/de2/rabbitmq_driver.html"; } -uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) { +std::uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, std::uint16_t interval) { const auto negotiated = std::min(interval, configured_heartbeat_seconds_); negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; return negotiated; } -void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, size_t size) { +void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, std::size_t size) { if (IsBroken()) { // No further actions can be done return; @@ -191,16 +195,12 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi throw ConnectionSetupError{"Failed to setup a connection: " + *error_}; } - using namespace std::chrono_literals; - const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed); if (heartbeat_seconds > 0) { - const auto half_interval = - std::chrono::duration_cast(std::chrono::seconds{heartbeat_seconds}) / 2; - - heartbeat_task_.Start("amqp_heartbeat", {half_interval, utils::PeriodicTask::Flags::kNow}, [this] { - SendHeartbeat(); - }); + heartbeat_task_ + .Start("amqp_heartbeat", {HalfInterval(heartbeat_seconds), utils::PeriodicTask::Flags::kNow}, [this] { + SendHeartbeat(); + }); } } @@ -214,9 +214,9 @@ void AmqpConnectionHandler::Invalidate() { broken_ = true; } bool AmqpConnectionHandler::IsBroken() const { return broken_.load(); } -void AmqpConnectionHandler::AccountRead(size_t size) { stats_.AccountRead(size); } +void AmqpConnectionHandler::AccountRead(std::size_t size) { stats_.AccountRead(size); } -void AmqpConnectionHandler::AccountWrite(size_t size) { stats_.AccountWrite(size); } +void AmqpConnectionHandler::AccountWrite(std::size_t size) { stats_.AccountWrite(size); } void AmqpConnectionHandler::SetOperationDeadline(engine::Deadline deadline) { operation_deadline_ = deadline; } @@ -230,9 +230,7 @@ void AmqpConnectionHandler::SendHeartbeat() { } try { - const auto deadline = engine::Deadline::FromDuration(std::chrono::duration_cast( - std::chrono::seconds{configured_heartbeat_seconds_} / 2.0 - )); + const auto deadline = engine::Deadline::FromDuration(HalfInterval(configured_heartbeat_seconds_)); auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); connection_->SetOperationDeadline(deadline); connection_->GetNative().heartbeat(); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index c9f87c541833..9ad5738b51f2 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -48,7 +48,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, - size_t heartbeat_interval_seconds, + std::size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -56,9 +56,9 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { ~AmqpConnectionHandler() override; void onProperties(AMQP::Connection* connection, const AMQP::Table& server, AMQP::Table& client) override; - uint16_t onNegotiate(AMQP::Connection* connection, uint16_t interval) override; + std::uint16_t onNegotiate(AMQP::Connection* connection, std::uint16_t interval) override; - void onData(AMQP::Connection* connection, const char* buffer, size_t size) override; + void onData(AMQP::Connection* connection, const char* buffer, std::size_t size) override; void onError(AMQP::Connection* connection, const char* message) override; @@ -74,8 +74,8 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { void SetOperationDeadline(engine::Deadline deadline); - void AccountRead(size_t size); - void AccountWrite(size_t size); + void AccountRead(std::size_t size); + void AccountWrite(std::size_t size); statistics::ConnectionStatistics& GetStatistics(); diff --git a/rabbitmq/src/urabbitmq/impl/header_value.cpp b/rabbitmq/src/urabbitmq/impl/header_value.cpp index 440627809466..87d16b3075ac 100644 --- a/rabbitmq/src/urabbitmq/impl/header_value.cpp +++ b/rabbitmq/src/urabbitmq/impl/header_value.cpp @@ -1,10 +1,10 @@ #include "header_value.hpp" #include -#include #include #include #include +#include #include @@ -18,32 +18,56 @@ namespace urabbitmq::impl { namespace { -formats::json::Value ToJsonValue(const AMQP::Array& array); -formats::json::Value ToJsonValue(const AMQP::Table& table); -std::unique_ptr ToAmqpField(const HeaderValue& value); - -[[noreturn]] void ThrowUnsupportedAmqpField(char type_id) { - throw std::runtime_error{fmt::format("Unsupported AMQP header field type '{}'", type_id)}; +template +HeaderValue MakeHeaderValue(T&& value) { + return HeaderValue::Builder{std::forward(value)}.ExtractValue(); } -[[noreturn]] void ThrowUnsupportedHeaderValue(const HeaderValue& value) { - throw std::runtime_error{fmt::format("Unsupported RabbitMQ header value at '{}'", value.GetPath())}; -} +AMQP::Array ToAmqpArray(const HeaderValue& value); +AMQP::Table ToAmqpTable(const HeaderValue& value); -formats::json::Value ToJsonValue(const AMQP::Array& array) { - formats::json::ValueBuilder builder{formats::common::Type::kArray}; - const auto count = array.count(); - for (std::uint32_t index = 0; index < count; ++index) { - builder.PushBack(formats::json::ValueBuilder{FieldToHeaderValue(array[static_cast(index)])}); +template +decltype(auto) WithAmqpField(const HeaderValue& value, Func&& func) { + if (value.IsNull()) { + return std::forward(func)(AMQP::VoidField{}); + } + if (value.IsBool()) { + return std::forward(func)(AMQP::BooleanSet{value.As()}); + } + if (value.IsInt()) { + return std::forward(func)(AMQP::Long{value.As()}); + } + if (value.IsInt64()) { + return std::forward(func)(AMQP::LongLong{value.As()}); + } + if (value.IsUInt()) { + return std::forward(func)(AMQP::ULong{value.As()}); + } + if (value.IsUInt64()) { + return std::forward(func)(AMQP::ULongLong{value.As()}); + } + if (value.IsDouble()) { + return std::forward(func)(AMQP::Double{value.As()}); + } + if (value.IsString()) { + return std::forward(func)(AMQP::LongString{value.As()}); + } + if (value.IsArray()) { + auto array = ToAmqpArray(value); + return std::forward(func)(array); + } + if (value.IsObject()) { + auto table = ToAmqpTable(value); + return std::forward(func)(table); } - return builder.ExtractValue(); + throw std::runtime_error{fmt::format("Unsupported RabbitMQ header value at '{}'", value.GetPath())}; } -formats::json::Value ToJsonValue(const AMQP::Table& table) { - formats::json::ValueBuilder builder{formats::common::Type::kObject}; - for (const auto& key : table.keys()) { - builder.EmplaceNocheck(key, formats::json::ValueBuilder{FieldToHeaderValue(table[key])}); +HeaderValue ToHeaderValueFromArray(const AMQP::Array& array) { + HeaderValue::Builder builder{formats::common::Type::kArray}; + for (std::uint32_t index = 0; index < array.count(); ++index) { + builder.PushBack(impl::ToHeaderValue(array[index])); } return builder.ExtractValue(); @@ -52,7 +76,7 @@ formats::json::Value ToJsonValue(const AMQP::Table& table) { AMQP::Array ToAmqpArray(const HeaderValue& value) { AMQP::Array array; for (const auto& item : value) { - array.push_back(*ToAmqpField(item)); + WithAmqpField(item, [&array](const AMQP::Field& field) { array.push_back(field); }); } return array; @@ -61,87 +85,61 @@ AMQP::Array ToAmqpArray(const HeaderValue& value) { AMQP::Table ToAmqpTable(const HeaderValue& value) { AMQP::Table table; for (const auto& [key, item] : formats::common::Items(value)) { - table.set(key, *ToAmqpField(item)); + WithAmqpField(item, [&table, &key](const AMQP::Field& field) { table.set(key, field); }); } return table; } -std::unique_ptr ToAmqpField(const HeaderValue& value) { - if (value.IsNull()) { - return std::make_unique(); - } - if (value.IsBool()) { - return std::make_unique(value.As()); - } - if (value.IsInt() || value.IsInt64()) { - return std::make_unique(value.As()); - } - if (value.IsUInt() || value.IsUInt64()) { - return std::make_unique(value.As()); - } - if (value.IsDouble()) { - return std::make_unique(value.As()); - } - if (value.IsString()) { - return std::make_unique(value.As()); - } - if (value.IsArray()) { - return std::make_unique(ToAmqpArray(value)); - } - if (value.IsObject()) { - return std::make_unique(ToAmqpTable(value)); +HeaderValue ToHeaderValueFromTable(const AMQP::Table& table) { + HeaderValue::Builder builder{formats::common::Type::kObject}; + for (const auto& key : table.keys()) { + builder.EmplaceNocheck(key, impl::ToHeaderValue(table[key])); } - ThrowUnsupportedHeaderValue(value); + return builder.ExtractValue(); } } // namespace -HeaderValue FieldToHeaderValue(const AMQP::Field& field) { +HeaderValue ToHeaderValue(const AMQP::Field& field) { switch (field.typeID()) { case 'S': case 's': - return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 't': - return formats::json::ValueBuilder{static_cast(field).value() != 0}.ExtractValue(); + return MakeHeaderValue(static_cast(field).value() != 0); case 'B': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'b': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'u': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'U': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'i': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'I': - return formats::json::ValueBuilder{static_cast(static_cast(field))} - .ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'l': case 'T': - return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'L': - return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'f': - return formats::json::ValueBuilder{static_cast(static_cast(field))}.ExtractValue(); + return MakeHeaderValue(static_cast(static_cast(field))); case 'd': case 'D': - return formats::json::ValueBuilder{static_cast(field)}.ExtractValue(); + return MakeHeaderValue(static_cast(field)); case 'A': - return ToJsonValue(static_cast(field)); + return ToHeaderValueFromArray(static_cast(field)); case 'F': - return ToJsonValue(static_cast(field)); + return ToHeaderValueFromTable(static_cast(field)); case 'V': - return formats::json::ValueBuilder{}.ExtractValue(); + return HeaderValue::Builder{}.ExtractValue(); } - ThrowUnsupportedAmqpField(field.typeID()); + throw std::runtime_error{fmt::format("Unsupported AMQP header field type '{}'", field.typeID())}; } std::unordered_map TableToHeaders(const AMQP::Table& table) { @@ -151,7 +149,7 @@ std::unordered_map TableToHeaders(const AMQP::Table& t headers.reserve(keys.size()); for (const auto& key : keys) { - headers.emplace(key, FieldToHeaderValue(table[key])); + headers.emplace(key, ToHeaderValue(table[key])); } return headers; @@ -159,7 +157,7 @@ std::unordered_map TableToHeaders(const AMQP::Table& t void AddHeadersToTable(AMQP::Table& table, const std::unordered_map& headers) { for (const auto& [key, value] : headers) { - table.set(key, *ToAmqpField(value)); + WithAmqpField(value, [&table, &key](const AMQP::Field& field) { table.set(key, field); }); } } diff --git a/rabbitmq/src/urabbitmq/impl/header_value.hpp b/rabbitmq/src/urabbitmq/impl/header_value.hpp index abea1f2f04f6..9c0d0705f0a2 100644 --- a/rabbitmq/src/urabbitmq/impl/header_value.hpp +++ b/rabbitmq/src/urabbitmq/impl/header_value.hpp @@ -11,7 +11,7 @@ USERVER_NAMESPACE_BEGIN namespace urabbitmq::impl { -HeaderValue FieldToHeaderValue(const AMQP::Field& field); +HeaderValue ToHeaderValue(const AMQP::Field& field); std::unordered_map TableToHeaders(const AMQP::Table& table);