diff --git a/apps/benchmark/src/benchmark.cpp b/apps/benchmark/src/benchmark.cpp index b85c2e40..9b98ac7e 100644 --- a/apps/benchmark/src/benchmark.cpp +++ b/apps/benchmark/src/benchmark.cpp @@ -10,7 +10,7 @@ #include "rtbot/Output.h" #include "rtbot/Program.h" #include "rtbot/bindings.h" -#include "rtbot/std/MathSyncBinaryOp.h" +#include "rtbot/std/ArithmeticSync.h" #include "rtbot/std/MovingAverage.h" #include "rtbot/std/PeakDetector.h" #include "tools.h" @@ -157,7 +157,7 @@ class PPGPipelineBenchmark { p.input = std::make_shared("i1", std::vector{PortType::NUMBER}); p.ma_short = std::make_shared("ma1", short_window_); p.ma_long = std::make_shared("ma2", long_window_); - p.minus = std::make_shared("diff"); + p.minus = std::make_shared("diff", 2); p.peak = std::make_shared("peak", 2 * short_window_ + 1); p.join = std::make_shared("join", std::vector{PortType::NUMBER, PortType::NUMBER}); p.output = std::make_shared("o1", std::vector{PortType::NUMBER}); @@ -252,8 +252,8 @@ class BollingerBandsPureBenchmark { p.ma = std::make_shared("ma", 14); p.sd = std::make_shared("sd", 14); p.scale = std::make_shared("scale", 2.0); - p.upper = std::make_shared("upper"); - p.lower = std::make_shared("lower"); + p.upper = std::make_shared("upper", 2); + p.lower = std::make_shared("lower", 2); p.output = std::make_shared("output", std::vector{PortType::NUMBER, PortType::NUMBER, PortType::NUMBER}); diff --git a/libs/core/include/rtbot/Buffer.h b/libs/core/include/rtbot/Buffer.h index dc62420c..a28da934 100644 --- a/libs/core/include/rtbot/Buffer.h +++ b/libs/core/include/rtbot/Buffer.h @@ -141,11 +141,13 @@ class Buffer : public Operator { Bytes msg_bytes(it, it + msg_size); // Deserialize message and cast to derived type - buffer_.push_back( - std::unique_ptr>( - dynamic_cast*>(BaseMessage::deserialize(msg_bytes).release()) - ) - ); + auto base_msg = BaseMessage::deserialize(msg_bytes); + auto* typed_msg = dynamic_cast*>(base_msg.get()); + if (!typed_msg) { + throw std::runtime_error("Failed to cast message during Buffer restore"); + } + base_msg.release(); // Safe: cast validated above + buffer_.push_back(std::unique_ptr>(typed_msg)); it += msg_size; } @@ -199,7 +201,13 @@ class Buffer : public Operator { } // Add new message to buffer - buffer_.push_back(std::unique_ptr>(dynamic_cast*>(input_queue.front()->clone().release()))); + auto cloned = input_queue.front()->clone(); + auto* typed_clone = dynamic_cast*>(cloned.get()); + if (!typed_clone) { + throw std::runtime_error("Failed to cast cloned message in Buffer"); + } + cloned.release(); // Safe: cast validated above + buffer_.push_back(std::unique_ptr>(typed_clone)); // Update statistics with added and removed values update_statistics(msg->data.value, removed_value); diff --git a/libs/core/include/rtbot/Message.h b/libs/core/include/rtbot/Message.h index cdfdacb9..ee280ece 100644 --- a/libs/core/include/rtbot/Message.h +++ b/libs/core/include/rtbot/Message.h @@ -358,7 +358,8 @@ inline std::unique_ptr> BaseMessage::deserialize_as(const Bytes& byte if (!typed_msg) { throw std::runtime_error("Failed to cast message to requested type"); } - return std::unique_ptr>(static_cast*>(base_msg.release())); + base_msg.release(); // Safe: cast validated above + return std::unique_ptr>(typed_msg); } // Helper functions diff --git a/libs/core/include/rtbot/Operator.h b/libs/core/include/rtbot/Operator.h index 5c5de0cd..548e25ff 100644 --- a/libs/core/include/rtbot/Operator.h +++ b/libs/core/include/rtbot/Operator.h @@ -30,6 +30,18 @@ struct PortInfo { MessageQueue queue; std::type_index type; timestamp_t last_timestamp{std::numeric_limits::min()}; + + // Constructor + PortInfo(MessageQueue q, std::type_index t) + : queue(std::move(q)), type(t) {} + + // Delete copy operations (MessageQueue contains unique_ptr) + PortInfo(const PortInfo&) = delete; + PortInfo& operator=(const PortInfo&) = delete; + + // Explicitly default move operations with noexcept for vector reallocation + PortInfo(PortInfo&&) noexcept = default; + PortInfo& operator=(PortInfo&&) noexcept = default; }; enum class PortKind { DATA, CONTROL }; @@ -510,7 +522,12 @@ class Operator { } // Send the messages to the connected operators - for (auto& conn : connections_) { + for (auto& conn : connections_) { + auto child = conn.child.lock(); // Lock weak_ptr to get shared_ptr + if (!child) { + continue; // Child has been destroyed + } + auto& output_queue = output_ports_[conn.output_port].queue; if (output_queue.empty()) { continue; @@ -519,17 +536,17 @@ class Operator { for (size_t i = 0; i < output_queue.size(); i++) { auto msg_copy = output_queue[i]->clone(); #ifdef RTBOT_INSTRUMENTATION - RTBOT_RECORD_MESSAGE_SENT(id_, type_name(), std::to_string(i), conn.child->id(), conn.child->type_name(), + RTBOT_RECORD_MESSAGE_SENT(id_, type_name(), std::to_string(i), child->id(), child->type_name(), std::to_string(conn.child_input_port), conn.child_port_kind == PortKind::DATA ? "" : "[c]", output_queue[i]->clone()); #endif // Route message based on connection port kind if (conn.child_port_kind == PortKind::DATA) { - conn.child->receive_data(std::move(msg_copy), conn.child_input_port); + child->receive_data(std::move(msg_copy), conn.child_input_port); } else { - conn.child->receive_control(std::move(msg_copy), conn.child_input_port); + child->receive_control(std::move(msg_copy), conn.child_input_port); } - propagated_outputs.insert(conn.output_port); + propagated_outputs.insert(conn.output_port); } } @@ -550,16 +567,18 @@ class Operator { // Then execute connected operators - for (auto& conn : connections_) { - if (conn.child != nullptr && propagated_outputs.find(conn.output_port) != propagated_outputs.end()) - conn.child->execute(debug); + for (auto& conn : connections_) { + auto child = conn.child.lock(); + if (child && propagated_outputs.find(conn.output_port) != propagated_outputs.end()) { + child->execute(debug); + } } } struct Connection { - std::shared_ptr child; + std::weak_ptr child; // Use weak_ptr to avoid circular references size_t output_port; - size_t child_input_port; + size_t child_input_port; PortKind child_port_kind{PortKind::DATA}; }; diff --git a/libs/finance/include/rtbot/finance/RelativeStrengthIndex.h b/libs/finance/include/rtbot/finance/RelativeStrengthIndex.h index 8f2b8575..b2de45d4 100644 --- a/libs/finance/include/rtbot/finance/RelativeStrengthIndex.h +++ b/libs/finance/include/rtbot/finance/RelativeStrengthIndex.h @@ -3,92 +3,116 @@ #include #include +#include #include +#include "rtbot/Buffer.h" +#include "rtbot/Message.h" #include "rtbot/Operator.h" +#include "rtbot/PortType.h" namespace rtbot { -template -struct RelativeStrengthIndex : public Operator { - RelativeStrengthIndex() = default; +// RSI needs to track sum for computing averages +struct RSIFeatures { + static constexpr bool TRACK_SUM = true; + static constexpr bool TRACK_VARIANCE = false; +}; - RelativeStrengthIndex(string const& id, size_t n) : Operator(id), initialized(false) { - this->addDataInput("i1", n + 1); - this->addOutput("o1"); +class RelativeStrengthIndex : public Buffer { + public: + RelativeStrengthIndex(std::string id, size_t n) + : Buffer(std::move(id), n + 1), + initialized_(false), + average_gain_(0.0), + average_loss_(0.0), + prev_average_gain_(0.0), + prev_average_loss_(0.0) {} + + std::string type_name() const override { return "RelativeStrengthIndex"; } + + void reset() override { + Buffer::reset(); + initialized_ = false; + average_gain_ = 0.0; + average_loss_ = 0.0; + prev_average_gain_ = 0.0; + prev_average_loss_ = 0.0; } - string typeName() const override { return "RelativeStrengthIndex"; } - - OperatorMessage processData() override { - string inputPort; - auto in = this->getDataInputs(); - if (in.size() == 1) - inputPort = in.at(0); - else - throw runtime_error(typeName() + " : more than 1 input port found"); - Message out; - size_t n = this->getDataInputSize(inputPort); - V diff, rs, rsi, gain, loss; - - if (!initialized) { - averageGain = 0; - averageLoss = 0; + protected: + std::vector>> process_message(const Message* msg) override { + // Only compute RSI when buffer is full + if (!buffer_full()) { + return {}; + } + + size_t n = buffer_size(); + double diff, rs, rsi, gain, loss; + + if (!initialized_) { + average_gain_ = 0.0; + average_loss_ = 0.0; + + // Calculate initial average gain/loss from buffer for (size_t i = 1; i < n; i++) { - diff = this->getDataInputMessage(inputPort, i).value - this->getDataInputMessage(inputPort, i - 1).value; - if (diff > 0) - averageGain = averageGain + diff; - else if (diff < 0) - averageLoss = averageLoss - diff; + diff = buffer()[i]->data.value - buffer()[i - 1]->data.value; + if (diff > 0) { + average_gain_ += diff; + } else if (diff < 0) { + average_loss_ -= diff; // Make positive + } } - averageGain = averageGain / (n - 1); - averageLoss = averageLoss / (n - 1); + average_gain_ /= (n - 1); + average_loss_ /= (n - 1); - initialized = true; + initialized_ = true; } else { - diff = this->getDataInputMessage(inputPort, n - 1).value - this->getDataInputMessage(inputPort, n - 2).value; + // Use smoothed average for subsequent calculations + diff = buffer()[n - 1]->data.value - buffer()[n - 2]->data.value; if (diff > 0) { gain = diff; - loss = 0; + loss = 0.0; } else if (diff < 0) { loss = -diff; - gain = 0; + gain = 0.0; } else { - loss = 0; - gain = 0; + loss = 0.0; + gain = 0.0; } - averageGain = (prevAverageGain * (n - 2) + gain) / (n - 1); - averageLoss = (prevAverageLoss * (n - 2) + loss) / (n - 1); + average_gain_ = (prev_average_gain_ * (n - 2) + gain) / (n - 1); + average_loss_ = (prev_average_loss_ * (n - 2) + loss) / (n - 1); } - prevAverageGain = averageGain; - prevAverageLoss = averageLoss; + prev_average_gain_ = average_gain_; + prev_average_loss_ = average_loss_; - if (averageLoss > 0) { - rs = averageGain / averageLoss; - - rsi = 100.0 - (100.0 / (1 + rs)); - } else + // Calculate RSI + if (average_loss_ > 0) { + rs = average_gain_ / average_loss_; + rsi = 100.0 - (100.0 / (1.0 + rs)); + } else { rsi = 100.0; - out.value = rsi; - - out.time = this->getDataInputLastMessage(inputPort).time; + } - OperatorMessage outputMsgs; - PortMessage v; - v.push_back(out); - outputMsgs.emplace("o1", v); - return outputMsgs; + // Create output message + std::vector>> result; + result.push_back(create_message(msg->time, NumberData{rsi})); + return result; } private: - V averageGain; - V averageLoss; - V prevAverageGain; - V prevAverageLoss; - bool initialized = false; + bool initialized_; + double average_gain_; + double average_loss_; + double prev_average_gain_; + double prev_average_loss_; }; +inline std::shared_ptr make_relative_strength_index(std::string id, size_t n) { + return std::make_shared(std::move(id), n); +} + } // namespace rtbot #endif // RELATIVESTRENGTHINDEX_H diff --git a/libs/finance/test/test_relative_strength_index.cpp b/libs/finance/test/test_relative_strength_index.cpp index 468398d8..ee72e084 100644 --- a/libs/finance/test/test_relative_strength_index.cpp +++ b/libs/finance/test/test_relative_strength_index.cpp @@ -1,35 +1,146 @@ #define CATCH_CONFIG_MAIN -#include - #include +#include +#include #include "rtbot/finance/RelativeStrengthIndex.h" using namespace rtbot; -using namespace std; - -TEST_CASE("Relative Strength Index") { - auto rsi = RelativeStrengthIndex("rsi", 14); - vector values = {54.8, 56.8, 57.85, 59.85, 60.57, 61.1, 62.17, 60.6, 62.35, 62.15, 62.35, 61.45, 62.8, - 61.37, 62.5, 62.57, 60.8, 59.37, 60.35, 62.35, 62.17, 62.55, 64.55, 64.37, 65.3, 64.42, - 62.9, 61.6, 62.05, 60.05, 59.7, 60.9, 60.25, 58.27, 58.7, 57.72, 58.1, 58.2}; - - vector rsis = {0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 74.21384, 74.33552, - 65.87129, 59.93370, 62.43288, 66.96205, 66.18862, 67.05377, 71.22679, 70.36299, - 72.23644, 67.86486, 60.99822, 55.79821, 57.15964, 49.81579, 48.63810, 52.76154, - 50.40119, 43.95111, 45.57992, 42.54534, 44.09946, 44.52472}; - - SECTION("emits right values") { - for (int i = 0; i < values.size(); i++) { - rsi.receiveData(Message(i + 1, values.at(i))); - ProgramMessage emitted = rsi.executeData(); - - if (i < 14) { - REQUIRE(emitted.empty()); - } else { - REQUIRE(abs(emitted.find("rsi")->second.find("o1")->second.at(0).value - rsis.at(i)) <= 0.00001); + +SCENARIO("RelativeStrengthIndex operator computes correct RSI values", "[rsi]") { + GIVEN("A RelativeStrengthIndex operator with period 14") { + auto rsi = RelativeStrengthIndex("rsi", 14); + + // Test data based on typical RSI calculation examples + std::vector values = { + 54.8, 56.8, 57.85, 59.85, 60.57, 61.1, 62.17, 60.6, 62.35, 62.15, + 62.35, 61.45, 62.8, 61.37, 62.5, 62.57, 60.8, 59.37, 60.35, 62.35, + 62.17, 62.55, 64.55, 64.37, 65.3, 64.42, 62.9, 61.6, 62.05, 60.05, + 59.7, 60.9, 60.25, 58.27, 58.7, 57.72, 58.1, 58.2}; + + // Expected RSI values (0 means no output expected) + std::vector expected_rsis = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 74.21384, 74.33552, + 65.87129, 59.93370, 62.43288, 66.96205, 66.18862, 67.05377, 71.22679, 70.36299, + 72.23644, 67.86486, 60.99822, 55.79821, 57.15964, 49.81579, 48.63810, 52.76154, + 50.40119, 43.95111, 45.57992, 42.54534, 44.09946, 44.52472}; + + WHEN("Processing the input sequence") { + for (size_t i = 0; i < values.size(); i++) { + rsi.receive_data(create_message(i + 1, NumberData{values[i]}), 0); + rsi.execute(); + + const auto& output = rsi.get_output_queue(0); + + if (i < 14) { + THEN("No output is produced for first 14 values at index " + std::to_string(i)) { + REQUIRE(output.empty()); + } + } else { + THEN("Correct RSI is produced at index " + std::to_string(i)) { + REQUIRE(output.size() == 1); + + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg != nullptr); + REQUIRE(std::abs(msg->data.value - expected_rsis[i]) <= 0.00001); + } + } + + rsi.clear_all_output_ports(); + } + } + } +} + +SCENARIO("RelativeStrengthIndex operator handles edge cases", "[rsi]") { + SECTION("Small period") { + auto rsi = RelativeStrengthIndex("rsi", 2); + + // Send 3 values (period + 1 needed for first output) + rsi.receive_data(create_message(1, NumberData{10.0}), 0); + rsi.execute(); + REQUIRE(rsi.get_output_queue(0).empty()); + + rsi.receive_data(create_message(2, NumberData{12.0}), 0); + rsi.execute(); + REQUIRE(rsi.get_output_queue(0).empty()); + + rsi.receive_data(create_message(3, NumberData{14.0}), 0); + rsi.execute(); + + // Should have output now + const auto& output = rsi.get_output_queue(0); + REQUIRE(output.size() == 1); + + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg != nullptr); + // All gains, no losses -> RSI should be 100 + REQUIRE(msg->data.value == 100.0); + } + + SECTION("All losses") { + auto rsi = RelativeStrengthIndex("rsi", 2); + + // Send decreasing values + rsi.receive_data(create_message(1, NumberData{20.0}), 0); + rsi.execute(); + rsi.receive_data(create_message(2, NumberData{15.0}), 0); + rsi.execute(); + rsi.receive_data(create_message(3, NumberData{10.0}), 0); + rsi.execute(); + + const auto& output = rsi.get_output_queue(0); + REQUIRE(output.size() == 1); + + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg != nullptr); + // All losses, no gains -> RSI should be 0 + REQUIRE(msg->data.value == 0.0); + } + + SECTION("No change in values") { + auto rsi = RelativeStrengthIndex("rsi", 2); + + // Send same value multiple times + rsi.receive_data(create_message(1, NumberData{50.0}), 0); + rsi.execute(); + rsi.receive_data(create_message(2, NumberData{50.0}), 0); + rsi.execute(); + rsi.receive_data(create_message(3, NumberData{50.0}), 0); + rsi.execute(); + + const auto& output = rsi.get_output_queue(0); + REQUIRE(output.size() == 1); + + const auto* msg = dynamic_cast*>(output.front().get()); + REQUIRE(msg != nullptr); + // No gains, no losses, average_loss is 0 -> RSI should be 100 + REQUIRE(msg->data.value == 100.0); + } +} + +SCENARIO("RelativeStrengthIndex reset functionality", "[rsi]") { + GIVEN("An RSI operator with some state") { + auto rsi = RelativeStrengthIndex("rsi", 3); + + // Add some data + rsi.receive_data(create_message(1, NumberData{10.0}), 0); + rsi.execute(); + rsi.receive_data(create_message(2, NumberData{12.0}), 0); + rsi.execute(); + + WHEN("Reset is called") { + rsi.reset(); + + THEN("Operator behaves as if freshly constructed") { + // Add same data again + rsi.receive_data(create_message(1, NumberData{10.0}), 0); + rsi.execute(); + + // Should have no output yet (buffer not full) + REQUIRE(rsi.get_output_queue(0).empty()); } } } -} \ No newline at end of file +} diff --git a/libs/std/include/rtbot/std/ResamplerConstant.h b/libs/std/include/rtbot/std/ResamplerConstant.h index b1b25db3..3aae99f7 100644 --- a/libs/std/include/rtbot/std/ResamplerConstant.h +++ b/libs/std/include/rtbot/std/ResamplerConstant.h @@ -1,6 +1,8 @@ #ifndef RESAMPLER_CONSTANT_H #define RESAMPLER_CONSTANT_H +#include + #include "rtbot/Message.h" #include "rtbot/Operator.h" #include "rtbot/PortType.h"