Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions protocol/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,16 @@ cc_library(
hdrs = ["progress.h"],
)

cc_test(
name = "progress_test",
srcs = ["progress_test.cc"],
deps = [
":progress",
"@googletest//:gtest",
"@googletest//:gtest_main",
],
)

cc_library(
name = "spi_proxy",
srcs = ["spi_proxy.c"],
Expand Down
77 changes: 50 additions & 27 deletions protocol/progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

#include "progress.h"

#include <inttypes.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

static struct timespec ts_now() {
// This function is defined as weak to allow unit tests to override it
// with a mock implementation for deterministic time control.
__attribute__((weak)) struct timespec libhoth_progress_get_time(void) {
struct timespec result;
int rv = clock_gettime(CLOCK_MONOTONIC, &result);
if (rv != 0) {
Expand All @@ -29,7 +33,7 @@ static struct timespec ts_now() {
}

static struct timespec ts_subtract(struct timespec a, struct timespec b) {
if (a.tv_nsec > b.tv_nsec) {
if (a.tv_nsec >= b.tv_nsec) {
return (struct timespec){
.tv_sec = (a.tv_sec - b.tv_sec),
.tv_nsec = (a.tv_nsec - b.tv_nsec),
Expand All @@ -45,38 +49,57 @@ static uint64_t ts_milliseconds(struct timespec ts) {
return ((uint64_t)ts.tv_sec) * 1000 + ts.tv_nsec / 1000000;
}

struct stderr_progress {
struct libhoth_progress progress;
struct timespec start_time;
const char* action_title;
};
static void libhoth_progress_stderr_func(void* param, const uint64_t current,
const uint64_t total) {
struct libhoth_progress_stderr* const self =
(struct libhoth_progress_stderr*)param;

static void libhoth_progress_stderr_func(void* param, uint64_t numerator,
uint64_t denominator) {
struct stderr_progress* self = (struct stderr_progress*)param;
if (isatty(STDERR_FILENO)) {
uint64_t duration_ms =
ts_milliseconds(ts_subtract(ts_now(), self->start_time));
if (duration_ms == 0) {
// avoid divide-by-zero
duration_ms = 1;
}
fprintf(
stderr,
"%s: % 3.0f%% - %lldkB / %lldkB %lld kB/sec; %.1f s remaining %s",
self->action_title, ((double)numerator / (double)denominator) * 100.0,
(long long)(numerator / 1000), (long long)(denominator / 1000),
(long long)(numerator / duration_ms),
(double)(denominator - numerator) * (double)duration_ms * 0.001 /
(double)numerator,
numerator == denominator ? "\n" : "\r");
if (!self->is_tty) {
return;
}

// Calculate 1% of the total size as the minimum increment for reporting.
const uint64_t one_percent_threshold = (total < 100) ? 1 : (total / 100);

const bool is_start = (current == 0);
const bool is_end = (current == total);
const bool has_sufficient_progress =
(current >= self->last_reported_val + one_percent_threshold);

if (!is_start && !is_end && !has_sufficient_progress) {
return;
}

self->last_reported_val = current;

struct timespec now = libhoth_progress_get_time();

uint64_t duration_ms = ts_milliseconds(ts_subtract(now, self->start_time));
if (duration_ms == 0) {
// avoid divide-by-zero
duration_ms = 1;
}

const double progress_pct = total > 0 ? (100.0 * current) / total : 100.0;
const double speed_kib_s = (current / 1024.0) / (duration_ms / 1000.0);
double remaining_s = 0;
if (speed_kib_s > 0) {
remaining_s = ((total - current) / 1024.0) / speed_kib_s;
}

fprintf(stderr,
"%s: %3.0f%% - %" PRIu64 "KiB / %" PRIu64
"KiB %.0f KiB/sec; %.0f s remaining%s",
self->action_title, progress_pct, current / 1024, total / 1024,
speed_kib_s, remaining_s, is_end ? "\033[K\n" : "\033[K\r");
}

void libhoth_progress_stderr_init(struct libhoth_progress_stderr* progress,
const char* action_title) {
progress->progress.param = progress;
progress->progress.func = libhoth_progress_stderr_func;
progress->start_time = ts_now();
progress->start_time = libhoth_progress_get_time();
progress->action_title = action_title;
progress->last_reported_val = 0;
progress->is_tty = isatty(STDERR_FILENO);
}
5 changes: 4 additions & 1 deletion protocol/progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@
extern "C" {
#endif

#include <stdbool.h>
#include <stdint.h>
#include <time.h>

struct libhoth_progress {
void (*func)(void*, uint64_t numerator, uint64_t denominator);
void (*func)(void*, uint64_t current, uint64_t total);
void* param;
};

struct libhoth_progress_stderr {
struct libhoth_progress progress;
struct timespec start_time;
const char* action_title;
uint64_t last_reported_val;
bool is_tty;
};

void libhoth_progress_stderr_init(struct libhoth_progress_stderr* progress,
Expand Down
226 changes: 226 additions & 0 deletions protocol/progress_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "progress.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <string>

namespace {

using ::testing::HasSubstr;

static struct timespec mock_time;

extern "C" struct timespec libhoth_progress_get_time(void) { return mock_time; }

class ProgressTest : public ::testing::Test {
protected:
void SetUp() override {
mock_time.tv_sec = 0;
mock_time.tv_nsec = 0;
libhoth_progress_stderr_init(&progress_, "Test Action");
// Force TTY to true so we can test the output even if running in
// non-interactive mode.
progress_.is_tty = true;
}

struct libhoth_progress_stderr progress_;
};

TEST_F(ProgressTest, Init) {
EXPECT_EQ(progress_.last_reported_val, 0);
EXPECT_EQ(progress_.action_title, "Test Action");
EXPECT_TRUE(progress_.is_tty);
EXPECT_NE(progress_.progress.func, nullptr);
EXPECT_EQ(progress_.progress.param, &progress_);
}

TEST_F(ProgressTest, StartAlwaysPrints) {
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 0, 100);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr("Test Action: 0%"));
EXPECT_THAT(output, HasSubstr("0KiB / 0KiB"));
}

TEST_F(ProgressTest, EndAlwaysPrints) {
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 100, 100);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr("Test Action: 100%"));
EXPECT_THAT(output, HasSubstr("0KiB / 0KiB"));
// Should end with newline at 100%
EXPECT_THAT(output, HasSubstr("\n"));
}

TEST_F(ProgressTest, UpdatesThrottled) {
// First call (0%)
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 0, 1000);
std::string output = testing::internal::GetCapturedStderr();
EXPECT_THAT(output, HasSubstr("0%"));

// Small update (< 1%) -> Should NOT print
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 5, 1000); // 0.5%
output = testing::internal::GetCapturedStderr();
EXPECT_TRUE(output.empty());

// Threshold update (== 1%) -> Should print
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 10, 1000); // 1.0%
output = testing::internal::GetCapturedStderr();
EXPECT_THAT(output, HasSubstr("1%"));

// Large update (>= 1%) -> Should print
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 20, 1000); // 2.0%
output = testing::internal::GetCapturedStderr();
EXPECT_THAT(output, HasSubstr("2%"));
}

TEST_F(ProgressTest, ThroughputCalculation) {
// 1 second passed
mock_time.tv_sec = 1;
mock_time.tv_nsec = 0;

testing::internal::CaptureStderr();
// 50 KiB transferred in 1 second
progress_.progress.func(progress_.progress.param, 50 * 1024, 100 * 1024);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr("50KiB"));
EXPECT_THAT(output, HasSubstr("50 KiB/sec"));
// 50 KiB remaining at 50 KiB/s -> 1 second remaining
EXPECT_THAT(output, HasSubstr(" 1 s remaining"));
}

TEST_F(ProgressTest, ThroughputCalculationSlow) {
// 10 seconds passed
mock_time.tv_sec = 10;
mock_time.tv_nsec = 0;

testing::internal::CaptureStderr();
// 10 KiB transferred in 10 seconds = 1 KiB/s
progress_.progress.func(progress_.progress.param, 10 * 1024, 110 * 1024);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr(" 1 KiB/sec"));
// 100 KiB remaining at 1 KiB/s -> 100 seconds remaining
EXPECT_THAT(output, HasSubstr(" 100 s remaining"));
}

TEST_F(ProgressTest, ThroughputCalculationFast) {
// 0.5 seconds passed
mock_time.tv_sec = 0;
mock_time.tv_nsec = 500000000;

testing::internal::CaptureStderr();
// 1024 KiB transferred in 0.5 seconds = 2048 KiB/s
progress_.progress.func(progress_.progress.param, 1024 * 1024, 2048 * 1024);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr(" 2048 KiB/sec"));
// 1024 KiB remaining at 2048 KiB/s -> 0.5 seconds remaining
EXPECT_THAT(output, HasSubstr(" 0 s remaining") /* 0.5 rounded down */);
}

TEST_F(ProgressTest, ZeroTotal) {
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 0, 0);
std::string output = testing::internal::GetCapturedStderr();

// 0/0 is treated as 100% in the code
EXPECT_THAT(output, HasSubstr("100%"));
EXPECT_THAT(output, HasSubstr("\n"));
}

TEST_F(ProgressTest, ZeroDuration) {
// Explicitly set time to match start time (0,0) to force 0 duration
mock_time.tv_sec = 0;
mock_time.tv_nsec = 0;
progress_.start_time = mock_time;

testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 50, 100);
std::string output = testing::internal::GetCapturedStderr();

// Should not crash and should use 1ms fallback
EXPECT_THAT(output, HasSubstr("50%"));
EXPECT_THAT(output, HasSubstr("KiB/sec"));
}

TEST_F(ProgressTest, SmallTotalThreshold) {
testing::internal::CaptureStderr();
// total < 100, threshold should be 1 byte
progress_.progress.func(progress_.progress.param, 1, 50);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr("2%"));

// Calling again with same value should NOT print (threshold check)
testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 1, 50);
output = testing::internal::GetCapturedStderr();

EXPECT_TRUE(output.empty());
}

TEST_F(ProgressTest, VeryLargeTransfer) {
// 8 GiB total
uint64_t total = 8ULL * 1024 * 1024 * 1024;
uint64_t current = 4ULL * 1024 * 1024 * 1024;

mock_time.tv_sec = 10; // 10 seconds for 4GiB = 400MiB/s approx

testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, current, total);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr("50%"));
EXPECT_THAT(output, HasSubstr("4194304KiB / 8388608KiB"));
// 4GiB / 10s = 4096MiB / 10s = 409.6 MiB/s = 419430.4 KiB/s approx in double
EXPECT_THAT(output, HasSubstr("419430 KiB/sec"));
}

TEST_F(ProgressTest, ZeroSpeed) {
// Tests behavior when time advances but no bytes have been transferred yet.
// This verifies that we don't divide by zero when calculating remaining time
// and that we print sensible "0" values.
mock_time.tv_sec = 10;

testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 0, 100);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_THAT(output, HasSubstr(" 0 KiB/sec"));
EXPECT_THAT(output, HasSubstr(" 0 s remaining"));
}

TEST_F(ProgressTest, NoTtyNoOutput) {
progress_.is_tty = false;

testing::internal::CaptureStderr();
progress_.progress.func(progress_.progress.param, 0, 100);
std::string output = testing::internal::GetCapturedStderr();

EXPECT_TRUE(output.empty());
}

} // namespace
Loading