diff --git a/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp b/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp index 32d00633ba41..b1faaae3473d 100644 --- a/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp +++ b/rabbitmq/functional_tests/basic_chaos/rabbitmq_service.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -26,8 +28,7 @@ class ChaosProducer final : public components::LoggableComponentBase { ChaosProducer(const components::ComponentConfig& config, const components::ComponentContext& context) : components::LoggableComponentBase{config, context}, - rabbit_client_{context.FindComponent("chaos-rabbit").GetClient()} - { + rabbit_client_{context.FindComponent("chaos-rabbit").GetClient()} { const auto setup_deadline = engine::Deadline::FromDuration(kDefaultOperationTimeout); auto admin_channel = rabbit_client_->GetAdminChannel(setup_deadline); @@ -77,9 +78,7 @@ class ChaosConsumer final : public components::ComponentBase { static constexpr std::string_view kName{"chaos-consumer"}; ChaosConsumer(const components::ComponentConfig& config, const components::ComponentContext& context) - : components::ComponentBase{config, context}, - consumer_{config, context, messages_} - { + : components::ComponentBase{config, context}, consumer_{config, context, messages_} { Start(); } @@ -119,8 +118,7 @@ class ChaosConsumer final : public components::ComponentBase { ) : urabbitmq:: ConsumerBase{context.FindComponent(config["rabbit_name"].As()).GetClient(), ParseSettings(config)}, - messages_{messages} - {} + messages_{messages} {} protected: void Process(urabbitmq::ConsumedMessage msg) override { @@ -150,8 +148,7 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { ChaosHandler(const components::ComponentConfig& config, const components::ComponentContext& context) : server::handlers::HttpHandlerBase{config, context}, producer_{context.FindComponent()}, - consumer_{context.FindComponent()} - {} + consumer_{context.FindComponent()} {} std::string HandleRequestThrow(const server::http::HttpRequest& request, server::request::RequestContext&) const override { @@ -178,6 +175,13 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { throw server::handlers::ClientError{server::handlers::ExternalBody{"No 'message' query argument"}}; } urabbitmq::Envelope envelope{message, urabbitmq::MessageType::kTransient, {}, {}, {}}; + if (!request.RequestBody().empty()) { + const auto request_json = formats::json::FromString(request.RequestBody()); + if (request_json.HasMember("headers")) { + envelope + .headers = request_json["headers"].As>(); + } + } const auto& correlation_id = request.GetArg("correlation_id"); if (!correlation_id.empty()) { envelope.correlation_id = correlation_id; @@ -219,9 +223,9 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { } std::string HandleGet() const { - formats::json::ValueBuilder messages_builder; + urabbitmq::HeaderValue::Builder messages_builder; for (const auto& item : consumer_.GetMessages()) { - formats::json::ValueBuilder item_builder; + urabbitmq::HeaderValue::Builder item_builder; item_builder["message"] = item.message; if (item.correlation_id.has_value()) { item_builder["correlation_id"] = item.correlation_id; @@ -229,6 +233,7 @@ class ChaosHandler final : public server::handlers::HttpHandlerBase { if (item.reply_to.has_value()) { item_builder["reply_to"] = item.reply_to; } + item_builder["headers"] = item.headers; messages_builder.PushBack(std::move(item_builder)); } return formats::json::ToString(messages_builder.ExtractValue()); diff --git a/rabbitmq/functional_tests/basic_chaos/static_config.yaml b/rabbitmq/functional_tests/basic_chaos/static_config.yaml index e509f78aa359..3da63680e54d 100644 --- a/rabbitmq/functional_tests/basic_chaos/static_config.yaml +++ b/rabbitmq/functional_tests/basic_chaos/static_config.yaml @@ -17,6 +17,7 @@ components_manager: min_pool_size: 1 max_pool_size: 1 max_in_flight_requests: 5 + heartbeat_interval_seconds: 1 use_secure_connection: false secdist: {} # Component that stores configuration of hosts and passwords diff --git a/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py b/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py index 9718e0d35e39..1db29a2eb3cb 100644 --- a/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py +++ b/rabbitmq/functional_tests/basic_chaos/tests/test_rabbitmq.py @@ -37,6 +37,10 @@ async def _clear_messages(service_client): assert response.status_code == 200 +def _strip_headers(messages): + return [{key: value for key, value in message.items() if key != 'headers'} for message in messages] + + async def _publish_and_consume(testpoint, client): @testpoint('message_consumed') def message_consumed(data): @@ -51,7 +55,7 @@ def message_consumed(data): response = await client.get('/v1/messages') assert response.status_code == 200 - assert response.json() == MESSAGES + assert _strip_headers(response.json()) == MESSAGES async def test_rabbitmq_happy(testpoint, service_client, gate): @@ -60,6 +64,55 @@ async def test_rabbitmq_happy(testpoint, service_client, gate): await _publish_and_consume(testpoint, service_client) +async def test_rabbitmq_headers(testpoint, service_client, gate): + await _clear_messages(service_client) + + @testpoint('message_consumed') + def message_consumed(data): + pass + + expected_headers = { + 'x-bool': True, + 'x-int': -10, + 'x-uint': 10, + 'x-double': 2.5, + 'x-array': [-7, 'array-value', {'enabled': False, 'nullable': None}], + 'x-object': { + 'count': 42, + 'name': 'nested-object', + 'array': [-7, 'array-value', {'enabled': False, 'nullable': None}], + }, + 'x-null': None, + } + + response = await service_client.post( + '/v1/messages?message=headers&reliable=1&reply_to=reply&correlation_id=corr-id', + json={'headers': expected_headers}, + ) + assert response.status_code == 200 + + await message_consumed.wait_call() + + response = await service_client.get('/v1/messages') + assert response.status_code == 200 + messages = response.json() + assert len(messages) == 1 + + consumed = messages[0] + assert consumed['message'] == 'headers' + assert consumed['reply_to'] == 'reply' + assert consumed['correlation_id'] == 'corr-id' + assert consumed['headers']['x-bool'] is True + assert consumed['headers']['x-int'] == -10 + assert consumed['headers']['x-uint'] == 10 + assert consumed['headers']['x-double'] == 2.5 + assert consumed['headers']['x-array'] == expected_headers['x-array'] + assert consumed['headers']['x-object'] == expected_headers['x-object'] + assert consumed['headers']['x-null'] is None + assert consumed['headers']['u-trace-id'] + assert consumed['headers']['u-parent-span-id'] + + @pytest.mark.skip(reason='std::terminate is called, fix in TAXICOMMON-6995') async def test_consumer_reconnects(testpoint, service_client, gate): await _clear_messages(service_client) @@ -85,4 +138,35 @@ def message_consumed(data): response = await service_client.get('/v1/messages') assert response.status_code == 200 - assert response.json() == MESSAGES + assert _strip_headers(response.json()) == MESSAGES + + +async def test_rabbitmq_heartbeat_reconnects(testpoint, service_client, gate): + await _clear_messages(service_client) + + @testpoint('message_consumed') + def message_consumed(data): + pass + + response = await service_client.post('/v1/messages?message=before-heartbeat') + assert response.status_code == 200 + await message_consumed.wait_call() + + await gate.to_server_noop() + await gate.to_client_noop() + await asyncio.sleep(2.5) + await gate.to_server_pass() + await gate.to_client_pass() + await gate.sockets_close() + + await gate.wait_for_connections(timeout=10.0) + await asyncio.sleep(1.0) + + response = await service_client.post('/v1/messages?message=after-heartbeat') + assert response.status_code == 200 + + await message_consumed.wait_call() + + response = await service_client.get('/v1/messages') + assert response.status_code == 200 + assert any(message['message'] == 'after-heartbeat' for message in response.json()) diff --git a/rabbitmq/include/userver/urabbitmq/client_settings.hpp b/rabbitmq/include/userver/urabbitmq/client_settings.hpp index 5a0122bb10c1..5ebee8b3d538 100644 --- a/rabbitmq/include/userver/urabbitmq/client_settings.hpp +++ b/rabbitmq/include/userver/urabbitmq/client_settings.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -74,6 +75,10 @@ struct PoolSettings final { /// (tcp error/protocol error/write timeout) leads to a errors burst: /// all outstanding request will fails at once size_t max_in_flight_requests = 5; + + /// Requested AMQP heartbeat interval in seconds. + /// Set to 0 to disable heartbeats. + size_t heartbeat_interval_seconds = 60; }; class TestsHelper; diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index 4714339b802d..0339868d5a42 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -4,7 +4,12 @@ /// @brief Convenient typedefs for RabbitMQ entities. #include +#include +#include +#include +#include +#include #include USERVER_NAMESPACE_BEGIN @@ -63,6 +68,10 @@ enum class MessageType { kTransient, }; +/// JSON-like representation of an AMQP header value. +/// This is not JSON, but a convenient tree representation for AMQP field values. +using HeaderValue = formats::json::Value; + /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages to the end user, /// hiding the actual AMQP message object implementation. @@ -75,6 +84,7 @@ struct ConsumedMessage { Metadata metadata; std::optional reply_to{}; std::optional correlation_id{}; + std::unordered_map headers{}; }; /// @brief Structure holding an AMQP message body along with some of its @@ -86,6 +96,7 @@ struct Envelope { std::optional reply_to{}; std::optional correlation_id{}; std::optional expiration{}; + std::optional> headers{}; }; } // namespace urabbitmq diff --git a/rabbitmq/src/tests/header_value_rmqtest.cpp b/rabbitmq/src/tests/header_value_rmqtest.cpp new file mode 100644 index 000000000000..bddb480443d2 --- /dev/null +++ b/rabbitmq/src/tests/header_value_rmqtest.cpp @@ -0,0 +1,125 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace { + +template +urabbitmq::HeaderValue MakeHeaderValue(T&& value) { + return urabbitmq::HeaderValue::Builder{std::forward(value)}.ExtractValue(); +} + +urabbitmq::HeaderValue MakeNestedArrayValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kArray}; + builder.PushBack(std::int64_t{-7}); + builder.PushBack("array-value"); + + urabbitmq::HeaderValue::Builder nested_object{formats::common::Type::kObject}; + nested_object["enabled"] = false; + nested_object["nullable"] = urabbitmq::HeaderValue::Builder{}; + builder.PushBack(std::move(nested_object)); + + return builder.ExtractValue(); +} + +urabbitmq::HeaderValue MakeNestedObjectValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kObject}; + builder["count"] = std::uint64_t{42}; + builder["name"] = "nested-object"; + builder["array"] = urabbitmq::HeaderValue::Builder{MakeNestedArrayValue()}; + + return builder.ExtractValue(); +} + +void ExpectHeadersEqual( + const std::unordered_map& actual, + const std::unordered_map& expected +) { + ASSERT_EQ(actual.size(), expected.size()); + for (const auto& [key, expected_value] : expected) { + const auto it = actual.find(key); + ASSERT_NE(it, actual.end()) << "Missing key: " << key; + EXPECT_EQ(it->second, expected_value) << "Unexpected value for key: " << key; + } +} + +} // namespace + +UTEST(HeaderValue, ConvertsNestedAmqpTypes) { + AMQP::Table headers; + headers.set("string", "value"); + headers.set("bool", true); + headers.set("signed", static_cast(-10)); + headers.set("unsigned", static_cast(10)); + headers.set("double", AMQP::Double{1.5}); + headers.set("null", nullptr); + + AMQP::Array nested_array; + nested_array.push_back(AMQP::LongLong{-7}); + nested_array.push_back(AMQP::LongString{"array-value"}); + AMQP::Table nested_array_object; + nested_array_object.set("enabled", false); + nested_array_object.set("nullable", nullptr); + nested_array.push_back(nested_array_object); + headers.set("array", nested_array); + + AMQP::Table nested_object; + nested_object.set("count", static_cast(42)); + nested_object.set("name", "nested-object"); + nested_object.set("array", nested_array); + headers.set("object", nested_object); + + const std::unordered_map expected{ + {"string", MakeHeaderValue("value")}, + {"bool", MakeHeaderValue(true)}, + {"signed", MakeHeaderValue(-10)}, + {"unsigned", MakeHeaderValue(10u)}, + {"double", MakeHeaderValue(1.5)}, + {"null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, + {"array", MakeNestedArrayValue()}, + {"object", MakeNestedObjectValue()}, + }; + + const auto actual = urabbitmq::impl::TableToHeaders(headers); + ExpectHeadersEqual(actual, expected); + EXPECT_TRUE(actual.at("signed").IsInt()); + EXPECT_TRUE(actual.at("unsigned").IsUInt()); +} + +UTEST(HeaderValue, RoundTripsHeaders) { + const std::unordered_map expected{ + {"string", MakeHeaderValue("value")}, + {"bool", MakeHeaderValue(false)}, + {"signed", MakeHeaderValue(-123456789)}, + {"signed64", MakeHeaderValue(std::int64_t{-1234567890123})}, + {"unsigned", MakeHeaderValue(123456789u)}, + {"unsigned64", MakeHeaderValue(std::uint64_t{1234567890123})}, + {"double", MakeHeaderValue(3.25)}, + {"null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, + {"array", MakeNestedArrayValue()}, + {"object", MakeNestedObjectValue()}, + }; + + AMQP::Table table; + urabbitmq::impl::AddHeadersToTable(table, expected); + + const auto actual = urabbitmq::impl::TableToHeaders(table); + ExpectHeadersEqual(actual, expected); + EXPECT_TRUE(actual.at("signed").IsInt()); + EXPECT_TRUE(actual.at("signed64").IsInt64()); + EXPECT_TRUE(actual.at("unsigned").IsUInt()); + EXPECT_TRUE(actual.at("unsigned64").IsUInt64()); +} + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 976274cf255c..e40399878898 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -1,26 +1,63 @@ #include "utils_rmqtest.hpp" +#include #include +#include +#include #include #include #include #include #include +#include +#include #include USERVER_NAMESPACE_BEGIN namespace { +template +urabbitmq::HeaderValue MakeHeaderValue(T&& value) { + return urabbitmq::HeaderValue::Builder{std::forward(value)}.ExtractValue(); +} + +urabbitmq::HeaderValue MakeNestedArrayValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kArray}; + builder.PushBack(std::int64_t{-7}); + builder.PushBack("array-value"); + + urabbitmq::HeaderValue::Builder nested_object{formats::common::Type::kObject}; + nested_object["enabled"] = false; + nested_object["nullable"] = urabbitmq::HeaderValue::Builder{}; + builder.PushBack(std::move(nested_object)); + + return builder.ExtractValue(); +} + +urabbitmq::HeaderValue MakeNestedObjectValue() { + urabbitmq::HeaderValue::Builder builder{formats::common::Type::kObject}; + builder["count"] = std::uint64_t{42}; + builder["name"] = "nested-object"; + builder["array"] = urabbitmq::HeaderValue::Builder{MakeNestedArrayValue()}; + + return builder.ExtractValue(); +} + class Consumer final : public urabbitmq::ConsumerBase { public: using urabbitmq::ConsumerBase::ConsumerBase; ~Consumer() override { Stop(); } - void Process(std::string message) override { + void Process(urabbitmq::ConsumedMessage message) override { + const auto plain_message = message.message; { auto locked = messages_.Lock(); + locked->emplace_back(plain_message); + } + { + auto locked = messages_with_metadata_.Lock(); locked->emplace_back(std::move(message)); } @@ -31,22 +68,25 @@ class Consumer final : public urabbitmq::ConsumerBase { void ExpectConsume(size_t count) { expected_consumed_ = count; } - std::vector Wait() { + void Wait() { if (expected_consumed_ != 0) { [[maybe_unused]] auto res = event_.WaitForEventFor(utest::kMaxTestWaitTime); } - - return Get(); } - std::vector Get() { + std::vector GetMessages() { auto locked = messages_.Lock(); return *locked; } + std::vector GetMessagesWithMetadata() { + auto locked = messages_with_metadata_.Lock(); + return *locked; + } + private: - std::optional single_expected_message_; concurrent::Variable> messages_; + concurrent::Variable> messages_with_metadata_; std::atomic expected_consumed_{0}; std::atomic consumed_{0}; engine::SingleConsumerEvent event_; @@ -108,7 +148,8 @@ UTEST(Consumer, ConsumeWorks) { consumer.ExpectConsume(1); consumer.Start(); - auto consumed = consumer.Wait(); + consumer.Wait(); + auto consumed = consumer.GetMessages(); ASSERT_EQ(consumed.size(), 1); EXPECT_EQ(consumed[0], envelope.message); @@ -167,12 +208,13 @@ UTEST(Consumer, ThrowsReturnsToQueue) { good_consumer.Start(); engine::InterruptibleSleepFor(std::chrono::milliseconds{200}); - auto consumed = good_consumer.Get(); + auto consumed = good_consumer.GetMessages(); EXPECT_LT(consumed.size(), messages_count); throwing_consumer.Throw(); throwing_consumer.Stop(); - EXPECT_EQ(good_consumer.Wait().size(), messages_count); + good_consumer.Wait(); + EXPECT_EQ(good_consumer.GetMessages().size(), messages_count); } UTEST(Consumer, MultipleConcurrentWork) { @@ -193,8 +235,8 @@ UTEST(Consumer, MultipleConcurrentWork) { second_consumer.Start(); engine::InterruptibleSleepFor(std::chrono::milliseconds{200}); - EXPECT_GT(first_consumer.Get().size(), 0); - EXPECT_GT(second_consumer.Get().size(), 0); + EXPECT_GT(first_consumer.GetMessages().size(), 0); + EXPECT_GT(second_consumer.GetMessages().size(), 0); } UTEST(Consumer, ForDifferentQueuesWork) { @@ -221,10 +263,113 @@ UTEST(Consumer, ForDifferentQueuesWork) { second_consumer.ExpectConsume(messages_count); second_consumer.Start(); - EXPECT_EQ(first_consumer.Wait().size(), messages_count); - EXPECT_EQ(second_consumer.Wait().size(), messages_count); + first_consumer.Wait(); + second_consumer.Wait(); + EXPECT_EQ(first_consumer.GetMessages().size(), messages_count); + EXPECT_EQ(second_consumer.GetMessages().size(), messages_count); client->GetAdminChannel(client.GetDeadline()).RemoveQueue(second_queue, client.GetDeadline()); } +UTEST(Consumer, ConsumeMetadataAndHeadersWork) { + ClientWrapper client{}; + client.SetupRmqEntities(); + const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; + + struct Case { + std::string name; + std::optional reply_to; + std::optional correlation_id; + std::unordered_map headers; + }; + + const std::vector cases{ + {"no-user-headers", std::nullopt, std::nullopt, {}}, + { + "scalar-user-headers", + "reply-queue", + "corr-id", + { + {"x-custom-header", MakeHeaderValue("custom-value")}, + {"x-bool", MakeHeaderValue(true)}, + {"x-int", MakeHeaderValue(-10)}, + {"x-int64", MakeHeaderValue(std::int64_t{-10})}, + {"x-uint", MakeHeaderValue(10u)}, + {"x-uint64", MakeHeaderValue(std::uint64_t{10})}, + {"x-double", MakeHeaderValue(2.5)}, + {"x-null", urabbitmq::HeaderValue::Builder{}.ExtractValue()}, + }, + }, + { + "nested-user-headers", + "reply-nested", + "corr-nested", + { + {"x-array", MakeNestedArrayValue()}, + {"x-object", MakeNestedObjectValue()}, + }, + }, + { + "trace-headers-override", + "reply-override", + "corr-override", + { + {"u-trace-id", MakeHeaderValue("trace-from-user")}, + {"u-parent-span-id", MakeHeaderValue("parent-from-user")}, + {"x-another", MakeHeaderValue("value")}, + }, + }, + }; + + for (const auto& case_data : cases) { + urabbitmq::Envelope envelope{ + "payload-" + case_data.name, + urabbitmq::MessageType::kTransient, + }; + envelope.reply_to = case_data.reply_to; + envelope.correlation_id = case_data.correlation_id; + envelope.headers = case_data.headers; + client->PublishReliable(client.GetExchange(), client.GetRoutingKey(), envelope, client.GetDeadline()); + } + + Consumer consumer{client.Get(), settings}; + consumer.ExpectConsume(cases.size()); + consumer.Start(); + consumer.Wait(); + auto consumed = consumer.GetMessagesWithMetadata(); + + ASSERT_EQ(consumed.size(), cases.size()); + std::unordered_map consumed_by_payload; + consumed_by_payload.reserve(consumed.size()); + for (const auto& msg : consumed) { + consumed_by_payload.emplace(msg.message, &msg); + } + + for (const auto& case_data : cases) { + const auto payload = "payload-" + case_data.name; + const auto it = consumed_by_payload.find(payload); + ASSERT_NE(it, consumed_by_payload.end()) << "Missing consumed payload: " << payload; + + const auto& msg = *it->second; + EXPECT_EQ(msg.message, payload); + EXPECT_EQ(msg.metadata.exchange, client.GetExchange().GetUnderlying()); + EXPECT_EQ(msg.metadata.routingKey, client.GetRoutingKey()); + EXPECT_EQ(msg.reply_to, case_data.reply_to); + EXPECT_EQ(msg.correlation_id, case_data.correlation_id); + + for (const auto& [header_key, header_value] : case_data.headers) { + ASSERT_EQ(msg.headers.count(header_key), 1) << "Missing header '" << header_key << "' in " << payload; + EXPECT_EQ(msg.headers.at(header_key), header_value) + << "Unexpected value for header '" << header_key << "' in " << payload; + } + + ASSERT_EQ(msg.headers.count("u-trace-id"), 1) << "Missing u-trace-id in " << payload; + ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) << "Missing u-parent-span-id in " << payload; + ASSERT_TRUE(msg.headers.at("u-trace-id").IsString()); + ASSERT_TRUE(msg.headers.at("u-parent-span-id").IsString()); + EXPECT_FALSE(msg.headers.at("u-trace-id").As().empty()); + EXPECT_FALSE(msg.headers.at("u-parent-span-id").As().empty()); + } +} + USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/client_settings.cpp b/rabbitmq/src/urabbitmq/client_settings.cpp index 85d25ba5adf0..ede927fb60e4 100644 --- a/rabbitmq/src/urabbitmq/client_settings.cpp +++ b/rabbitmq/src/urabbitmq/client_settings.cpp @@ -1,5 +1,7 @@ #include +#include +#include #include #include #include @@ -100,9 +102,16 @@ PoolSettings Parse(const yaml_config::YamlConfig& config, formats::parse::To(result.min_pool_size); result.max_pool_size = config["max_pool_size"].As(result.max_pool_size); result.max_in_flight_requests = config["max_in_flight_requests"].As(result.max_in_flight_requests); + result + .heartbeat_interval_seconds = config["heartbeat_interval_seconds"].As(result.heartbeat_interval_seconds + ); UINVARIANT(result.min_pool_size <= result.max_pool_size, "max_pool_size is less than min_pool_size"); UINVARIANT(result.max_pool_size > 0, "max_pool_size is set to zero"); + UINVARIANT( + result.heartbeat_interval_seconds <= std::numeric_limits::max(), + "heartbeat_interval_seconds is too large" + ); return result; } @@ -112,8 +121,7 @@ ClientSettings::ClientSettings() = default; ClientSettings::ClientSettings(const components::ComponentConfig& config, const RabbitEndpoints& rabbit_endpoints) : pool_settings{config.As()}, endpoints{rabbit_endpoints}, - use_secure_connection{config["use_secure_connection"].As(true)} -{} + use_secure_connection{config["use_secure_connection"].As(true)} {} RabbitEndpointsMulti::RabbitEndpointsMulti(const formats::json::Value& doc) { const auto rabbitmq_settings = doc["rabbitmq_settings"]; diff --git a/rabbitmq/src/urabbitmq/component.yaml b/rabbitmq/src/urabbitmq/component.yaml index cdc35b89df77..998116d3d320 100644 --- a/rabbitmq/src/urabbitmq/component.yaml +++ b/rabbitmq/src/urabbitmq/component.yaml @@ -23,6 +23,11 @@ properties: description: | per-connection limit for requests awaiting response from the broker default: 5 + heartbeat_interval_seconds: + type: integer + description: | + requested AMQP heartbeat interval in seconds; 0 disables heartbeats + default: 60 use_secure_connection: type: boolean description: whether to use TLS for connections diff --git a/rabbitmq/src/urabbitmq/connection.cpp b/rabbitmq/src/urabbitmq/connection.cpp index f54e752b5bf4..4d8fcac026ed 100644 --- a/rabbitmq/src/urabbitmq/connection.cpp +++ b/rabbitmq/src/urabbitmq/connection.cpp @@ -9,11 +9,20 @@ Connection::Connection( const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline ) - : handler_{resolver, endpoint, auth_settings, secure, stats, deadline}, + : handler_{ + resolver, + endpoint, + auth_settings, + heartbeat_interval_seconds, + secure, + stats, + deadline, + }, connection_{handler_, max_in_flight_requests, deadline}, channel_{connection_}, reliable_channel_{connection_} diff --git a/rabbitmq/src/urabbitmq/connection.hpp b/rabbitmq/src/urabbitmq/connection.hpp index 372596ffadb9..a53796b948a6 100644 --- a/rabbitmq/src/urabbitmq/connection.hpp +++ b/rabbitmq/src/urabbitmq/connection.hpp @@ -29,6 +29,7 @@ class Connection final { const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline diff --git a/rabbitmq/src/urabbitmq/connection_pool.cpp b/rabbitmq/src/urabbitmq/connection_pool.cpp index e1a9baf3a245..febb66583cc4 100644 --- a/rabbitmq/src/urabbitmq/connection_pool.cpp +++ b/rabbitmq/src/urabbitmq/connection_pool.cpp @@ -92,6 +92,7 @@ ConnectionPool::ConnectionUniquePtr ConnectionPool::DoCreateConnection(engine::D endpoint_info_, auth_settings_, pool_settings_.max_in_flight_requests, + pool_settings_.heartbeat_interval_seconds, use_secure_connection_, stats_, deadline diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 3bb22d1ea41b..dc40e20ae969 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -12,6 +12,7 @@ #include #include #include +#include USERVER_NAMESPACE_BEGIN @@ -28,8 +29,7 @@ ConsumerBaseImpl::ConsumerBaseImpl(ConnectionPtr&& connection, const ConsumerSet queue_name_{settings.queue.GetUnderlying()}, prefetch_count_{settings.prefetch_count}, connection_ptr_{std::move(connection)}, - channel_{connection_ptr_->GetChannel()} -{ + channel_{connection_ptr_->GetChannel()} { // We take ownership of the connection, because if it remains pooled // things get messy with lifetimes and callbacks connection_ptr_.Adopt(); @@ -93,9 +93,10 @@ void ConsumerBaseImpl::Stop() { bool ConsumerBaseImpl::IsBroken() const { return broken_ || !connection_ptr_.IsUsable(); } void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery_tag) { + const auto& headers = message.headers(); std::string span_name{fmt::format("consume_{}_{}", queue_name_, consumer_tag_.value_or("ctag:unknown"))}; - std::string trace_id = message.headers().get("u-trace-id"); - std::string parent_span_id = message.headers().get("u-parent-span-id"); + std::string trace_id = headers.get("u-trace-id"); + std::string parent_span_id = headers.get("u-parent-span-id"); ConsumedMessage consumed; consumed.message = std::string(message.body(), message.bodySize()); consumed.metadata.exchange = message.exchange(); @@ -107,6 +108,8 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery consumed.correlation_id = message.correlationID(); } + consumed.headers = impl::TableToHeaders(headers); + bts_.Detach(engine::AsyncNoSpan( dispatcher_, [this, diff --git a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp index c374dc41bd5d..69ce974ed119 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp @@ -7,6 +7,7 @@ #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -94,11 +95,18 @@ AMQP::Table CreateHeaders() { return headers; } +AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { + auto headers = CreateHeaders(); + if (envelope.headers.has_value()) { + AddHeadersToTable(headers, *envelope.headers); + } + + return headers; +} + } // namespace -AmqpChannel::AmqpChannel(AmqpConnection& conn) - : conn_{conn} -{} +AmqpChannel::AmqpChannel(AmqpConnection& conn) : conn_{conn} {} AmqpChannel::~AmqpChannel() = default; @@ -196,7 +204,7 @@ void AmqpChannel::Publish( ) { AMQP::Envelope native_envelope{envelope.message.data(), envelope.message.size()}; native_envelope.setPersistent(envelope.type == MessageType::kPersistent); - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); if (envelope.reply_to.has_value()) { native_envelope.setReplyTo(envelope.reply_to.value().c_str()); } @@ -262,9 +270,7 @@ void AmqpChannel::CancelConsumer(const std::optional& consumer_tag) void AmqpChannel::AccountMessageConsumed() { conn_.GetStatistics().AccountMessageConsumed(); } -AmqpReliableChannel::AmqpReliableChannel(AmqpConnection& conn) - : conn_{conn} -{} +AmqpReliableChannel::AmqpReliableChannel(AmqpConnection& conn) : conn_{conn} {} AmqpReliableChannel::~AmqpReliableChannel() = default; @@ -285,7 +291,7 @@ ResponseAwaiter AmqpReliableChannel::Publish( if (envelope.expiration.has_value()) { native_envelope.setExpiration(std::to_string(envelope.expiration.value().count())); } - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); auto awaiter = conn_.GetAwaiter(deadline); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index e0857e5510da..81975639bd33 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -1,5 +1,8 @@ #include "amqp_connection_handler.hpp" +#include +#include + #include #include @@ -11,6 +14,7 @@ #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -85,12 +89,17 @@ AMQP::Address ToAmqpAddress(const EndpointInfo& endpoint, const AuthSettings& se return {endpoint.host, endpoint.port, AMQP::Login{settings.login, settings.password}, settings.vhost, secure}; } +std::chrono::milliseconds HalfInterval(std::uint16_t interval_seconds) { + return std::chrono::duration_cast(std::chrono::seconds{interval_seconds} / 2.0); +} + } // namespace AmqpConnectionHandler::AmqpConnectionHandler( clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + std::size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -98,10 +107,14 @@ AmqpConnectionHandler::AmqpConnectionHandler( : address_{ToAmqpAddress(endpoint, auth_settings, secure)}, socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)}, reader_{*this, *socket_}, - stats_{stats} -{} + configured_heartbeat_seconds_{static_cast< + std::uint16_t>(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, + stats_{stats} {} -AmqpConnectionHandler::~AmqpConnectionHandler() { reader_.Stop(); } +AmqpConnectionHandler::~AmqpConnectionHandler() { + heartbeat_task_.Stop(); + reader_.Stop(); +} void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, AMQP::Table& client) { client["product"] = "uServer AMQP library"; @@ -109,7 +122,14 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, client["information"] = "https://userver.tech/dd/de2/rabbitmq_driver.html"; } -void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, size_t size) { +std::uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, std::uint16_t interval) { + const auto negotiated = std::min(interval, configured_heartbeat_seconds_); + negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); + LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; + return negotiated; +} + +void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, std::size_t size) { if (IsBroken()) { // No further actions can be done return; @@ -160,28 +180,43 @@ void AmqpConnectionHandler::onReady(AMQP::Connection*) { } void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engine::Deadline deadline) { + connection_ = connection; reader_.Start(connection); if (!connection_ready_event_.WaitForEventUntil(deadline)) { reader_.Stop(); + connection_ = nullptr; throw ConnectionSetupTimeout{"Failed to setup a connection within specified deadline"}; } if (error_.has_value()) { reader_.Stop(); + connection_ = nullptr; throw ConnectionSetupError{"Failed to setup a connection: " + *error_}; } + + const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed); + if (heartbeat_seconds > 0) { + heartbeat_task_ + .Start("amqp_heartbeat", {HalfInterval(heartbeat_seconds), utils::PeriodicTask::Flags::kNow}, [this] { + SendHeartbeat(); + }); + } } -void AmqpConnectionHandler::OnConnectionDestruction() { reader_.Stop(); } +void AmqpConnectionHandler::OnConnectionDestruction() { + heartbeat_task_.Stop(); + connection_ = nullptr; + reader_.Stop(); +} void AmqpConnectionHandler::Invalidate() { broken_ = true; } bool AmqpConnectionHandler::IsBroken() const { return broken_.load(); } -void AmqpConnectionHandler::AccountRead(size_t size) { stats_.AccountRead(size); } +void AmqpConnectionHandler::AccountRead(std::size_t size) { stats_.AccountRead(size); } -void AmqpConnectionHandler::AccountWrite(size_t size) { stats_.AccountWrite(size); } +void AmqpConnectionHandler::AccountWrite(std::size_t size) { stats_.AccountWrite(size); } void AmqpConnectionHandler::SetOperationDeadline(engine::Deadline deadline) { operation_deadline_ = deadline; } @@ -189,6 +224,26 @@ statistics::ConnectionStatistics& AmqpConnectionHandler::GetStatistics() { retur const AMQP::Address& AmqpConnectionHandler::GetAddress() const { return address_; } +void AmqpConnectionHandler::SendHeartbeat() { + if (IsBroken() || connection_ == nullptr) { + return; + } + + try { + const auto deadline = engine::Deadline::FromDuration(HalfInterval(configured_heartbeat_seconds_)); + auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); + connection_->SetOperationDeadline(deadline); + connection_->GetNative().heartbeat(); + } catch (const std::exception& ex) { + LOG_WARNING() << "Failed to send AMQP heartbeat: " << ex.what(); + Invalidate(); + if (connection_ != nullptr) { + auto lock = AmqpConnectionLocker{*connection_}.Lock({}); + connection_->GetNative().fail("Underlying connection broke."); + } + } +} + } // namespace urabbitmq::impl USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index 31b7a8a88806..9ad5738b51f2 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include #include @@ -7,6 +9,7 @@ #include #include +#include #include @@ -45,6 +48,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + std::size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -52,8 +56,9 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { ~AmqpConnectionHandler() override; void onProperties(AMQP::Connection* connection, const AMQP::Table& server, AMQP::Table& client) override; + std::uint16_t onNegotiate(AMQP::Connection* connection, std::uint16_t interval) override; - void onData(AMQP::Connection* connection, const char* buffer, size_t size) override; + void onData(AMQP::Connection* connection, const char* buffer, std::size_t size) override; void onError(AMQP::Connection* connection, const char* message) override; @@ -69,17 +74,23 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { void SetOperationDeadline(engine::Deadline deadline); - void AccountRead(size_t size); - void AccountWrite(size_t size); + void AccountRead(std::size_t size); + void AccountWrite(std::size_t size); statistics::ConnectionStatistics& GetStatistics(); const AMQP::Address& GetAddress() const; private: + void SendHeartbeat(); + AMQP::Address address_; std::unique_ptr socket_; io::SocketReader reader_; + utils::PeriodicTask heartbeat_task_; + AmqpConnection* connection_{nullptr}; + std::atomic negotiated_heartbeat_seconds_{0}; + std::uint16_t configured_heartbeat_seconds_{0}; engine::SingleConsumerEvent connection_ready_event_; std::atomic broken_{false}; diff --git a/rabbitmq/src/urabbitmq/impl/header_value.cpp b/rabbitmq/src/urabbitmq/impl/header_value.cpp new file mode 100644 index 000000000000..87d16b3075ac --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/header_value.cpp @@ -0,0 +1,166 @@ +#include "header_value.hpp" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +namespace { + +template +HeaderValue MakeHeaderValue(T&& value) { + return HeaderValue::Builder{std::forward(value)}.ExtractValue(); +} + +AMQP::Array ToAmqpArray(const HeaderValue& value); +AMQP::Table ToAmqpTable(const HeaderValue& value); + +template +decltype(auto) WithAmqpField(const HeaderValue& value, Func&& func) { + if (value.IsNull()) { + return std::forward(func)(AMQP::VoidField{}); + } + if (value.IsBool()) { + return std::forward(func)(AMQP::BooleanSet{value.As()}); + } + if (value.IsInt()) { + return std::forward(func)(AMQP::Long{value.As()}); + } + if (value.IsInt64()) { + return std::forward(func)(AMQP::LongLong{value.As()}); + } + if (value.IsUInt()) { + return std::forward(func)(AMQP::ULong{value.As()}); + } + if (value.IsUInt64()) { + return std::forward(func)(AMQP::ULongLong{value.As()}); + } + if (value.IsDouble()) { + return std::forward(func)(AMQP::Double{value.As()}); + } + if (value.IsString()) { + return std::forward(func)(AMQP::LongString{value.As()}); + } + if (value.IsArray()) { + auto array = ToAmqpArray(value); + return std::forward(func)(array); + } + if (value.IsObject()) { + auto table = ToAmqpTable(value); + return std::forward(func)(table); + } + + throw std::runtime_error{fmt::format("Unsupported RabbitMQ header value at '{}'", value.GetPath())}; +} + +HeaderValue ToHeaderValueFromArray(const AMQP::Array& array) { + HeaderValue::Builder builder{formats::common::Type::kArray}; + for (std::uint32_t index = 0; index < array.count(); ++index) { + builder.PushBack(impl::ToHeaderValue(array[index])); + } + + return builder.ExtractValue(); +} + +AMQP::Array ToAmqpArray(const HeaderValue& value) { + AMQP::Array array; + for (const auto& item : value) { + WithAmqpField(item, [&array](const AMQP::Field& field) { array.push_back(field); }); + } + + return array; +} + +AMQP::Table ToAmqpTable(const HeaderValue& value) { + AMQP::Table table; + for (const auto& [key, item] : formats::common::Items(value)) { + WithAmqpField(item, [&table, &key](const AMQP::Field& field) { table.set(key, field); }); + } + + return table; +} + +HeaderValue ToHeaderValueFromTable(const AMQP::Table& table) { + HeaderValue::Builder builder{formats::common::Type::kObject}; + for (const auto& key : table.keys()) { + builder.EmplaceNocheck(key, impl::ToHeaderValue(table[key])); + } + + return builder.ExtractValue(); +} + +} // namespace + +HeaderValue ToHeaderValue(const AMQP::Field& field) { + switch (field.typeID()) { + case 'S': + case 's': + return MakeHeaderValue(static_cast(field)); + case 't': + return MakeHeaderValue(static_cast(field).value() != 0); + case 'B': + return MakeHeaderValue(static_cast(field)); + case 'b': + return MakeHeaderValue(static_cast(field)); + case 'u': + return MakeHeaderValue(static_cast(field)); + case 'U': + return MakeHeaderValue(static_cast(field)); + case 'i': + return MakeHeaderValue(static_cast(field)); + case 'I': + return MakeHeaderValue(static_cast(field)); + case 'l': + case 'T': + return MakeHeaderValue(static_cast(field)); + case 'L': + return MakeHeaderValue(static_cast(field)); + case 'f': + return MakeHeaderValue(static_cast(static_cast(field))); + case 'd': + case 'D': + return MakeHeaderValue(static_cast(field)); + case 'A': + return ToHeaderValueFromArray(static_cast(field)); + case 'F': + return ToHeaderValueFromTable(static_cast(field)); + case 'V': + return HeaderValue::Builder{}.ExtractValue(); + } + + throw std::runtime_error{fmt::format("Unsupported AMQP header field type '{}'", field.typeID())}; +} + +std::unordered_map TableToHeaders(const AMQP::Table& table) { + const auto keys = table.keys(); + + std::unordered_map headers; + headers.reserve(keys.size()); + + for (const auto& key : keys) { + headers.emplace(key, ToHeaderValue(table[key])); + } + + return headers; +} + +void AddHeadersToTable(AMQP::Table& table, const std::unordered_map& headers) { + for (const auto& [key, value] : headers) { + WithAmqpField(value, [&table, &key](const AMQP::Field& field) { table.set(key, field); }); + } +} + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/header_value.hpp b/rabbitmq/src/urabbitmq/impl/header_value.hpp new file mode 100644 index 000000000000..9c0d0705f0a2 --- /dev/null +++ b/rabbitmq/src/urabbitmq/impl/header_value.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace urabbitmq::impl { + +HeaderValue ToHeaderValue(const AMQP::Field& field); + +std::unordered_map TableToHeaders(const AMQP::Table& table); + +void AddHeadersToTable(AMQP::Table& table, const std::unordered_map& headers); + +} // namespace urabbitmq::impl + +USERVER_NAMESPACE_END