diff --git a/protocol/BUILD b/protocol/BUILD index c3e8761..fa19fff 100644 --- a/protocol/BUILD +++ b/protocol/BUILD @@ -264,6 +264,16 @@ cc_library( hdrs = ["progress.h"], ) +cc_test( + name = "progress_test", + srcs = ["progress_test.cc"], + deps = [ + ":progress", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + cc_library( name = "spi_proxy", srcs = ["spi_proxy.c"], diff --git a/protocol/progress.c b/protocol/progress.c index ba8da02..230f50a 100644 --- a/protocol/progress.c +++ b/protocol/progress.c @@ -14,11 +14,15 @@ #include "progress.h" +#include +#include #include #include #include -static struct timespec ts_now() { +// This function is defined as weak to allow unit tests to override it +// with a mock implementation for deterministic time control. +__attribute__((weak)) struct timespec libhoth_progress_get_time(void) { struct timespec result; int rv = clock_gettime(CLOCK_MONOTONIC, &result); if (rv != 0) { @@ -29,7 +33,7 @@ static struct timespec ts_now() { } static struct timespec ts_subtract(struct timespec a, struct timespec b) { - if (a.tv_nsec > b.tv_nsec) { + if (a.tv_nsec >= b.tv_nsec) { return (struct timespec){ .tv_sec = (a.tv_sec - b.tv_sec), .tv_nsec = (a.tv_nsec - b.tv_nsec), @@ -45,38 +49,57 @@ static uint64_t ts_milliseconds(struct timespec ts) { return ((uint64_t)ts.tv_sec) * 1000 + ts.tv_nsec / 1000000; } -struct stderr_progress { - struct libhoth_progress progress; - struct timespec start_time; - const char* action_title; -}; +static void libhoth_progress_stderr_func(void* param, const uint64_t current, + const uint64_t total) { + struct libhoth_progress_stderr* const self = + (struct libhoth_progress_stderr*)param; -static void libhoth_progress_stderr_func(void* param, uint64_t numerator, - uint64_t denominator) { - struct stderr_progress* self = (struct stderr_progress*)param; - if (isatty(STDERR_FILENO)) { - uint64_t duration_ms = - ts_milliseconds(ts_subtract(ts_now(), self->start_time)); - if (duration_ms == 0) { - // avoid divide-by-zero - duration_ms = 1; - } - fprintf( - stderr, - "%s: % 3.0f%% - %lldkB / %lldkB %lld kB/sec; %.1f s remaining %s", - self->action_title, ((double)numerator / (double)denominator) * 100.0, - (long long)(numerator / 1000), (long long)(denominator / 1000), - (long long)(numerator / duration_ms), - (double)(denominator - numerator) * (double)duration_ms * 0.001 / - (double)numerator, - numerator == denominator ? "\n" : "\r"); + if (!self->is_tty) { + return; } + + // Calculate 1% of the total size as the minimum increment for reporting. + const uint64_t one_percent_threshold = (total < 100) ? 1 : (total / 100); + + const bool is_start = (current == 0); + const bool is_end = (current == total); + const bool has_sufficient_progress = + (current >= self->last_reported_val + one_percent_threshold); + + if (!is_start && !is_end && !has_sufficient_progress) { + return; + } + + self->last_reported_val = current; + + struct timespec now = libhoth_progress_get_time(); + + uint64_t duration_ms = ts_milliseconds(ts_subtract(now, self->start_time)); + if (duration_ms == 0) { + // avoid divide-by-zero + duration_ms = 1; + } + + const double progress_pct = total > 0 ? (100.0 * current) / total : 100.0; + const double speed_kib_s = (current / 1024.0) / (duration_ms / 1000.0); + double remaining_s = 0; + if (speed_kib_s > 0) { + remaining_s = ((total - current) / 1024.0) / speed_kib_s; + } + + fprintf(stderr, + "%s: %3.0f%% - %" PRIu64 "KiB / %" PRIu64 + "KiB %.0f KiB/sec; %.0f s remaining%s", + self->action_title, progress_pct, current / 1024, total / 1024, + speed_kib_s, remaining_s, is_end ? "\033[K\n" : "\033[K\r"); } void libhoth_progress_stderr_init(struct libhoth_progress_stderr* progress, const char* action_title) { progress->progress.param = progress; progress->progress.func = libhoth_progress_stderr_func; - progress->start_time = ts_now(); + progress->start_time = libhoth_progress_get_time(); progress->action_title = action_title; + progress->last_reported_val = 0; + progress->is_tty = isatty(STDERR_FILENO); } diff --git a/protocol/progress.h b/protocol/progress.h index d69f909..019422f 100644 --- a/protocol/progress.h +++ b/protocol/progress.h @@ -19,11 +19,12 @@ extern "C" { #endif +#include #include #include struct libhoth_progress { - void (*func)(void*, uint64_t numerator, uint64_t denominator); + void (*func)(void*, uint64_t current, uint64_t total); void* param; }; @@ -31,6 +32,8 @@ struct libhoth_progress_stderr { struct libhoth_progress progress; struct timespec start_time; const char* action_title; + uint64_t last_reported_val; + bool is_tty; }; void libhoth_progress_stderr_init(struct libhoth_progress_stderr* progress, diff --git a/protocol/progress_test.cc b/protocol/progress_test.cc new file mode 100644 index 0000000..9631ba1 --- /dev/null +++ b/protocol/progress_test.cc @@ -0,0 +1,226 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "progress.h" + +#include +#include + +#include + +namespace { + +using ::testing::HasSubstr; + +static struct timespec mock_time; + +extern "C" struct timespec libhoth_progress_get_time(void) { return mock_time; } + +class ProgressTest : public ::testing::Test { + protected: + void SetUp() override { + mock_time.tv_sec = 0; + mock_time.tv_nsec = 0; + libhoth_progress_stderr_init(&progress_, "Test Action"); + // Force TTY to true so we can test the output even if running in + // non-interactive mode. + progress_.is_tty = true; + } + + struct libhoth_progress_stderr progress_; +}; + +TEST_F(ProgressTest, Init) { + EXPECT_EQ(progress_.last_reported_val, 0); + EXPECT_EQ(progress_.action_title, "Test Action"); + EXPECT_TRUE(progress_.is_tty); + EXPECT_NE(progress_.progress.func, nullptr); + EXPECT_EQ(progress_.progress.param, &progress_); +} + +TEST_F(ProgressTest, StartAlwaysPrints) { + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 0, 100); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr("Test Action: 0%")); + EXPECT_THAT(output, HasSubstr("0KiB / 0KiB")); +} + +TEST_F(ProgressTest, EndAlwaysPrints) { + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 100, 100); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr("Test Action: 100%")); + EXPECT_THAT(output, HasSubstr("0KiB / 0KiB")); + // Should end with newline at 100% + EXPECT_THAT(output, HasSubstr("\n")); +} + +TEST_F(ProgressTest, UpdatesThrottled) { + // First call (0%) + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 0, 1000); + std::string output = testing::internal::GetCapturedStderr(); + EXPECT_THAT(output, HasSubstr("0%")); + + // Small update (< 1%) -> Should NOT print + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 5, 1000); // 0.5% + output = testing::internal::GetCapturedStderr(); + EXPECT_TRUE(output.empty()); + + // Threshold update (== 1%) -> Should print + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 10, 1000); // 1.0% + output = testing::internal::GetCapturedStderr(); + EXPECT_THAT(output, HasSubstr("1%")); + + // Large update (>= 1%) -> Should print + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 20, 1000); // 2.0% + output = testing::internal::GetCapturedStderr(); + EXPECT_THAT(output, HasSubstr("2%")); +} + +TEST_F(ProgressTest, ThroughputCalculation) { + // 1 second passed + mock_time.tv_sec = 1; + mock_time.tv_nsec = 0; + + testing::internal::CaptureStderr(); + // 50 KiB transferred in 1 second + progress_.progress.func(progress_.progress.param, 50 * 1024, 100 * 1024); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr("50KiB")); + EXPECT_THAT(output, HasSubstr("50 KiB/sec")); + // 50 KiB remaining at 50 KiB/s -> 1 second remaining + EXPECT_THAT(output, HasSubstr(" 1 s remaining")); +} + +TEST_F(ProgressTest, ThroughputCalculationSlow) { + // 10 seconds passed + mock_time.tv_sec = 10; + mock_time.tv_nsec = 0; + + testing::internal::CaptureStderr(); + // 10 KiB transferred in 10 seconds = 1 KiB/s + progress_.progress.func(progress_.progress.param, 10 * 1024, 110 * 1024); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr(" 1 KiB/sec")); + // 100 KiB remaining at 1 KiB/s -> 100 seconds remaining + EXPECT_THAT(output, HasSubstr(" 100 s remaining")); +} + +TEST_F(ProgressTest, ThroughputCalculationFast) { + // 0.5 seconds passed + mock_time.tv_sec = 0; + mock_time.tv_nsec = 500000000; + + testing::internal::CaptureStderr(); + // 1024 KiB transferred in 0.5 seconds = 2048 KiB/s + progress_.progress.func(progress_.progress.param, 1024 * 1024, 2048 * 1024); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr(" 2048 KiB/sec")); + // 1024 KiB remaining at 2048 KiB/s -> 0.5 seconds remaining + EXPECT_THAT(output, HasSubstr(" 0 s remaining") /* 0.5 rounded down */); +} + +TEST_F(ProgressTest, ZeroTotal) { + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 0, 0); + std::string output = testing::internal::GetCapturedStderr(); + + // 0/0 is treated as 100% in the code + EXPECT_THAT(output, HasSubstr("100%")); + EXPECT_THAT(output, HasSubstr("\n")); +} + +TEST_F(ProgressTest, ZeroDuration) { + // Explicitly set time to match start time (0,0) to force 0 duration + mock_time.tv_sec = 0; + mock_time.tv_nsec = 0; + progress_.start_time = mock_time; + + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 50, 100); + std::string output = testing::internal::GetCapturedStderr(); + + // Should not crash and should use 1ms fallback + EXPECT_THAT(output, HasSubstr("50%")); + EXPECT_THAT(output, HasSubstr("KiB/sec")); +} + +TEST_F(ProgressTest, SmallTotalThreshold) { + testing::internal::CaptureStderr(); + // total < 100, threshold should be 1 byte + progress_.progress.func(progress_.progress.param, 1, 50); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr("2%")); + + // Calling again with same value should NOT print (threshold check) + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 1, 50); + output = testing::internal::GetCapturedStderr(); + + EXPECT_TRUE(output.empty()); +} + +TEST_F(ProgressTest, VeryLargeTransfer) { + // 8 GiB total + uint64_t total = 8ULL * 1024 * 1024 * 1024; + uint64_t current = 4ULL * 1024 * 1024 * 1024; + + mock_time.tv_sec = 10; // 10 seconds for 4GiB = 400MiB/s approx + + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, current, total); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr("50%")); + EXPECT_THAT(output, HasSubstr("4194304KiB / 8388608KiB")); + // 4GiB / 10s = 4096MiB / 10s = 409.6 MiB/s = 419430.4 KiB/s approx in double + EXPECT_THAT(output, HasSubstr("419430 KiB/sec")); +} + +TEST_F(ProgressTest, ZeroSpeed) { + // Tests behavior when time advances but no bytes have been transferred yet. + // This verifies that we don't divide by zero when calculating remaining time + // and that we print sensible "0" values. + mock_time.tv_sec = 10; + + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 0, 100); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_THAT(output, HasSubstr(" 0 KiB/sec")); + EXPECT_THAT(output, HasSubstr(" 0 s remaining")); +} + +TEST_F(ProgressTest, NoTtyNoOutput) { + progress_.is_tty = false; + + testing::internal::CaptureStderr(); + progress_.progress.func(progress_.progress.param, 0, 100); + std::string output = testing::internal::GetCapturedStderr(); + + EXPECT_TRUE(output.empty()); +} + +} // namespace