diff --git a/build.zig b/build.zig index 66c7bc2..68a1e54 100644 --- a/build.zig +++ b/build.zig @@ -4,7 +4,7 @@ pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); - const dep_zli = b.dependency("zli", .{ .target = target }); + const dep_clap = b.dependency("clap", .{ .target = target, .optimize = optimize }); const dep_mvzr = b.dependency("mvzr", .{ .target = target, .optimize = optimize }); const build_options = b.addOptions(); @@ -26,7 +26,7 @@ pub fn build(b: *std.Build) void { }), }); - exe.root_module.addImport("zli", dep_zli.module("zli")); + exe.root_module.addImport("clap", dep_clap.module("clap")); exe.root_module.addImport("mvzr", dep_mvzr.module("mvzr")); exe.root_module.addImport("build_options", build_options.createModule()); diff --git a/build.zig.zon b/build.zig.zon index 80ca222..73a22e8 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -8,14 +8,14 @@ .minimum_zig_version = "0.15.1", .dependencies = .{ - .zli = .{ - .url = "https://github.com/xcaeser/zli/archive/v4.1.1.tar.gz", - .hash = "zli-4.1.1-LeUjpljfAAAak_E3L4NPowuzPs_FUF9-jYyxuTSNSthM", - }, .mvzr = .{ .url = "https://github.com/mnemnion/mvzr/archive/refs/tags/v0.3.7.tar.gz", .hash = "mvzr-0.3.7-ZSOky5FtAQB2VrFQPNbXHQCFJxWTMAYEK7ljYEaMR6jt", }, + .clap = .{ + .url = "https://github.com/Hejsil/zig-clap/archive/refs/tags/0.11.0.tar.gz", + .hash = "clap-0.11.0-oBajB-HnAQDPCKYzwF7rO3qDFwRcD39Q0DALlTSz5H7e", + }, }, .paths = .{ "build.zig", diff --git a/src/cli/args.zig b/src/cli/args.zig new file mode 100644 index 0000000..81133d1 --- /dev/null +++ b/src/cli/args.zig @@ -0,0 +1,76 @@ +const std = @import("std"); +const clap = @import("clap"); +const Allocator = std.mem.Allocator; +const build_options = @import("build_options"); + +// ANSI formatting codes +const BOLD = "\x1b[1m"; +const YELLOW = "\x1b[33m"; +const RESET = "\x1b[0m"; + +pub const Args = struct { + https: bool, + upload: bool, + json: bool, + duration: u32, + help: bool, + + allocator: Allocator, + clap_result: ?clap.Result(clap.Help, ¶ms, parsers) = null, + + pub fn deinit(self: *Args) void { + if (self.clap_result) |*res| { + res.deinit(); + } + } +}; + +const params = clap.parseParamsComptime( + \\-h, --help Display this help and exit. + \\ --https Use HTTPS when connecting to fast.com (default) + \\ --no-https Use HTTP instead of HTTPS + \\-u, --upload Check upload speed as well + \\-j, --json Output results in JSON format + \\-d, --duration Maximum test duration in seconds (default: 30) + \\ +); + +const parsers = .{ + .usize = clap.parsers.int(u32, 10), +}; + +pub fn parse(allocator: Allocator) !Args { + var diag = clap.Diagnostic{}; + const res = clap.parse(clap.Help, ¶ms, parsers, .{ + .diagnostic = &diag, + .allocator = allocator, + }) catch |err| { + var stderr_buffer: [4096]u8 = undefined; + var stderr_writer = std.fs.File.stderr().writer(&stderr_buffer); + const stderr = &stderr_writer.interface; + try diag.report(stderr, err); + return err; + }; + + return .{ + .https = if (res.args.@"no-https" != 0) false else true, + .upload = res.args.upload != 0, + .json = res.args.json != 0, + .duration = res.args.duration orelse 30, + .help = res.args.help != 0, + .allocator = allocator, + .clap_result = res, + }; +} + +pub fn printHelp() !void { + var stderr_buffer: [4096]u8 = undefined; + var stderr_writer = std.fs.File.stderr().writerStreaming(&stderr_buffer); + const stderr = &stderr_writer.interface; + try stderr.print(BOLD ++ "fast-cli" ++ RESET ++ " v{s} - Estimate connection speed using fast.com\n\n", .{build_options.version}); + try stderr.writeAll(YELLOW ++ "USAGE:\n" ++ RESET); + try stderr.writeAll(" fast-cli [OPTIONS]\n\n"); + try stderr.writeAll(YELLOW ++ "OPTIONS:\n" ++ RESET); + try clap.help(stderr, clap.Help, ¶ms, .{ .spacing_between_parameters = 0, .description_on_new_line = false }); + try stderr.flush(); +} diff --git a/src/cli/root.zig b/src/cli/root.zig index 91a5bfb..1f5369c 100644 --- a/src/cli/root.zig +++ b/src/cli/root.zig @@ -1,173 +1,134 @@ const std = @import("std"); -const zli = @import("zli"); -const builtin = @import("builtin"); -const Writer = std.Io.Writer; +const Args = @import("args.zig"); +const Spinner = @import("../lib/spinner/spinner.zig"); const log = std.log.scoped(.cli); const Fast = @import("../lib/fast.zig").Fast; const HTTPSpeedTester = @import("../lib/http_speed_tester_v2.zig").HTTPSpeedTester; - const StabilityCriteria = @import("../lib/http_speed_tester_v2.zig").StabilityCriteria; const SpeedTestResult = @import("../lib/http_speed_tester_v2.zig").SpeedTestResult; -const BandwidthMeter = @import("../lib/bandwidth.zig"); const SpeedMeasurement = @import("../lib/bandwidth.zig").SpeedMeasurement; const progress = @import("../lib/progress.zig"); const HttpLatencyTester = @import("../lib/http_latency_tester.zig").HttpLatencyTester; -const https_flag = zli.Flag{ - .name = "https", - .description = "Use https when connecting to fast.com", - .type = .Bool, - .default_value = .{ .Bool = true }, -}; - -const check_upload_flag = zli.Flag{ - .name = "upload", - .description = "Check upload speed as well", - .shortcut = "u", - .type = .Bool, - .default_value = .{ .Bool = false }, -}; - -const json_output_flag = zli.Flag{ - .name = "json", - .description = "Output results in JSON format", - .shortcut = "j", - .type = .Bool, - .default_value = .{ .Bool = false }, -}; - -const max_duration_flag = zli.Flag{ - .name = "duration", - .description = "Maximum test duration in seconds (uses CoV stability detection by default)", - .shortcut = "d", - .type = .Int, - .default_value = .{ .Int = 30 }, -}; - -pub fn build(writer: *Writer, allocator: std.mem.Allocator) !*zli.Command { - const root = try zli.Command.init(writer, allocator, .{ - .name = "fast-cli", - .description = "Estimate connection speed using fast.com", - .version = null, - }, run); - - try root.addFlag(https_flag); - try root.addFlag(check_upload_flag); - try root.addFlag(json_output_flag); - try root.addFlag(max_duration_flag); - - return root; -} - -fn run(ctx: zli.CommandContext) !void { - const use_https = ctx.flag("https", bool); - const check_upload = ctx.flag("upload", bool); - const json_output = ctx.flag("json", bool); - const max_duration = ctx.flag("duration", i64); +pub fn run(allocator: std.mem.Allocator) !void { + var args = try Args.parse(allocator); + defer args.deinit(); - const spinner = ctx.spinner; + if (args.help) { + try Args.printHelp(); + return; + } log.info("Config: https={}, upload={}, json={}, max_duration={}s", .{ - use_https, check_upload, json_output, max_duration, + args.https, + args.upload, + args.json, + args.duration, }); - var fast = Fast.init(std.heap.smp_allocator, use_https); + var spinner = Spinner.init(allocator, .{}); + defer spinner.deinit(); + + var fast = Fast.init(std.heap.smp_allocator, args.https); defer fast.deinit(); const urls = fast.get_urls(5) catch |err| { - if (!json_output) { + if (!args.json) { try spinner.fail("Failed to get URLs: {}", .{err}); } else { const error_msg = switch (err) { error.ConnectionTimeout => "Failed to contact fast.com servers", else => "Failed to get URLs", }; - try outputJson(ctx.writer, null, null, null, error_msg); + try outputJson(null, null, null, error_msg); } return; }; - log.info("Got {} URLs\n", .{urls.len}); + log.info("Got {} URLs", .{urls.len}); for (urls) |url| { - log.info("URL: {s}\n", .{url}); + log.info("URL: {s}", .{url}); } - // Measure latency first + // Measure latency var latency_tester = HttpLatencyTester.init(std.heap.smp_allocator); defer latency_tester.deinit(); - const latency_ms = if (!json_output) blk: { + const latency_ms = if (!args.json) blk: { try spinner.start("Measuring latency...", .{}); const result = latency_tester.measureLatency(urls) catch |err| { try spinner.fail("Latency test failed: {}", .{err}); break :blk null; }; + spinner.stop(); break :blk result; } else blk: { break :blk latency_tester.measureLatency(urls) catch null; }; - if (!json_output) { + if (!args.json) { log.info("Measuring download speed...", .{}); + try spinner.start("Measuring download speed...", .{}); } // Initialize speed tester var speed_tester = HTTPSpeedTester.init(std.heap.smp_allocator); defer speed_tester.deinit(); - // Use Fast.com-style stability detection by default const criteria = StabilityCriteria{ .ramp_up_duration_seconds = 4, - .max_duration_seconds = @as(u32, @intCast(@max(25, max_duration))), + .max_duration_seconds = @as(u32, @intCast(@max(25, args.duration))), .measurement_interval_ms = 750, .sliding_window_size = 6, .stability_threshold_cov = 0.15, .stable_checks_required = 2, }; - const download_result = if (json_output) blk: { - // JSON mode: clean output only + const download_result = if (args.json) blk: { break :blk speed_tester.measure_download_speed_stability(urls, criteria) catch |err| { try spinner.fail("Download test failed: {}", .{err}); - try outputJson(ctx.writer, null, null, null, "Download test failed"); + try outputJson(null, null, null, "Download test failed"); return; }; } else blk: { - // Interactive mode with spinner updates - const progressCallback = progress.createCallback(spinner, updateSpinnerText); - break :blk speed_tester.measureDownloadSpeedWithStabilityProgress(urls, criteria, progressCallback) catch |err| { + const progressCallback = progress.createCallback(&spinner, updateSpinnerText); + const result = speed_tester.measureDownloadSpeedWithStabilityProgress(urls, criteria, progressCallback) catch |err| { try spinner.fail("Download test failed: {}", .{err}); return; }; + spinner.stop(); + break :blk result; }; var upload_result: ?SpeedTestResult = null; - if (check_upload) { - if (!json_output) { + if (args.upload) { + if (!args.json) { + spinner.stop(); log.info("Measuring upload speed...", .{}); + try spinner.start("Measuring upload speed...", .{}); } - upload_result = if (json_output) blk: { - // JSON mode: clean output only + upload_result = if (args.json) blk: { break :blk speed_tester.measure_upload_speed_stability(urls, criteria) catch |err| { try spinner.fail("Upload test failed: {}", .{err}); - try outputJson(ctx.writer, download_result.speed.value, latency_ms, null, "Upload test failed"); + try outputJson(download_result.speed.value, latency_ms, null, "Upload test failed"); return; }; } else blk: { - // Interactive mode with spinner updates - const uploadProgressCallback = progress.createCallback(spinner, updateUploadSpinnerText); - break :blk speed_tester.measureUploadSpeedWithStabilityProgress(urls, criteria, uploadProgressCallback) catch |err| { + const uploadProgressCallback = progress.createCallback(&spinner, updateUploadSpinnerText); + const result = speed_tester.measureUploadSpeedWithStabilityProgress(urls, criteria, uploadProgressCallback) catch |err| { try spinner.fail("Upload test failed: {}", .{err}); return; }; + spinner.stop(); + break :blk result; }; } // Output results - if (!json_output) { + if (!args.json) { if (latency_ms) |ping| { if (upload_result) |up| { try spinner.succeed("🏓 {d:.0}ms | ⬇️ Download: {d:.1} {s} | ⬆️ Upload: {d:.1} {s}", .{ ping, download_result.speed.value, download_result.speed.unit.toString(), up.speed.value, up.speed.unit.toString() }); @@ -183,21 +144,23 @@ fn run(ctx: zli.CommandContext) !void { } } else { const upload_speed = if (upload_result) |up| up.speed.value else null; - try outputJson(ctx.writer, download_result.speed.value, latency_ms, upload_speed, null); + try outputJson(download_result.speed.value, latency_ms, upload_speed, null); } } -/// Update spinner text with current speed measurement -fn updateSpinnerText(spinner: anytype, measurement: SpeedMeasurement) void { +fn updateSpinnerText(spinner: *Spinner, measurement: SpeedMeasurement) void { spinner.updateMessage("⬇️ {d:.1} {s}", .{ measurement.value, measurement.unit.toString() }) catch {}; } -/// Update spinner text with current upload speed measurement -fn updateUploadSpinnerText(spinner: anytype, measurement: SpeedMeasurement) void { +fn updateUploadSpinnerText(spinner: *Spinner, measurement: SpeedMeasurement) void { spinner.updateMessage("⬆️ {d:.1} {s}", .{ measurement.value, measurement.unit.toString() }) catch {}; } -fn outputJson(writer: *Writer, download_mbps: ?f64, ping_ms: ?f64, upload_mbps: ?f64, error_message: ?[]const u8) !void { +fn outputJson(download_mbps: ?f64, ping_ms: ?f64, upload_mbps: ?f64, error_message: ?[]const u8) !void { + var stdout_buffer: [4096]u8 = undefined; + var stdout_writer = std.fs.File.stdout().writerStreaming(&stdout_buffer); + const stdout = &stdout_writer.interface; + var download_buf: [32]u8 = undefined; var ping_buf: [32]u8 = undefined; var upload_buf: [32]u8 = undefined; @@ -208,5 +171,6 @@ fn outputJson(writer: *Writer, download_mbps: ?f64, ping_ms: ?f64, upload_mbps: const upload_str = if (upload_mbps) |u| try std.fmt.bufPrint(&upload_buf, "{d:.1}", .{u}) else "null"; const error_str = if (error_message) |e| try std.fmt.bufPrint(&error_buf, "\"{s}\"", .{e}) else "null"; - try writer.print("{{\"download_mbps\": {s}, \"ping_ms\": {s}, \"upload_mbps\": {s}, \"error\": {s}}}\n", .{ download_str, ping_str, upload_str, error_str }); + try stdout.print("{{\"download_mbps\": {s}, \"ping_ms\": {s}, \"upload_mbps\": {s}, \"error\": {s}}}\n", .{ download_str, ping_str, upload_str, error_str }); + try stdout.flush(); } diff --git a/src/lib/spinner/spinner.zig b/src/lib/spinner/spinner.zig new file mode 100644 index 0000000..64d04ed --- /dev/null +++ b/src/lib/spinner/spinner.zig @@ -0,0 +1,326 @@ +const std = @import("std"); +const Thread = std.Thread; +const Allocator = std.mem.Allocator; + +const Spinner = @This(); + +const frames = [_][]const u8{ "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" }; + +const WriterType = union(enum) { + file: std.fs.File.Writer, + test_writer: std.Io.Writer, +}; + +pub const Options = struct { + refresh_rate_ms: u64 = 80, + writer: ?WriterType = null, +}; + +// ANSI escape codes +const HIDE_CURSOR = "\x1b[?25l"; +const SHOW_CURSOR = "\x1b[?25h"; +const CLEAR_LINE = "\r\x1b[K"; +const GREEN = "\x1b[32m"; +const RED = "\x1b[31m"; +const RESET = "\x1b[0m"; + +allocator: Allocator, +message: []u8 = &.{}, +writer_buffer: [4096]u8, +writer: WriterType, +thread: ?Thread = null, +mutex: Thread.Mutex = .{}, +should_stop: std.atomic.Value(bool) = std.atomic.Value(bool).init(true), +refresh_rate_ms: u64, + +pub fn init(allocator: Allocator, options: Options) Spinner { + var spinner: Spinner = undefined; + spinner.allocator = allocator; + spinner.refresh_rate_ms = options.refresh_rate_ms; + spinner.message = &.{}; + spinner.thread = null; + spinner.mutex = .{}; + spinner.should_stop = std.atomic.Value(bool).init(true); + + if (options.writer) |w| { + spinner.writer_buffer = undefined; + spinner.writer = w; + } else { + spinner.writer = .{ .file = std.fs.File.stderr().writer(&spinner.writer_buffer) }; + } + + return spinner; +} + +pub fn deinit(self: *Spinner) void { + self.stop(); + self.mutex.lock(); + defer self.mutex.unlock(); + if (self.message.len > 0) { + self.allocator.free(self.message); + self.message = &.{}; + } +} + +pub fn start(self: *Spinner, comptime fmt: []const u8, args: anytype) !void { + self.stop(); + + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.message.len > 0) { + self.allocator.free(self.message); + } + self.message = try std.fmt.allocPrint(self.allocator, fmt, args); + + switch (self.writer) { + .file => |*w| { + w.interface.writeAll(HIDE_CURSOR) catch {}; + w.interface.flush() catch {}; + }, + .test_writer => |*w| { + w.writeAll(HIDE_CURSOR) catch {}; + }, + } + + self.should_stop.store(false, .release); + self.thread = try Thread.spawn(.{}, spinLoop, .{self}); +} + +pub fn stop(self: *Spinner) void { + if (self.should_stop.load(.acquire)) return; + + self.should_stop.store(true, .release); + if (self.thread) |t| { + t.join(); + self.thread = null; + } + + switch (self.writer) { + .file => |*w| { + w.interface.writeAll(CLEAR_LINE ++ SHOW_CURSOR) catch {}; + w.interface.flush() catch {}; + }, + .test_writer => |*w| { + w.writeAll(CLEAR_LINE ++ SHOW_CURSOR) catch {}; + }, + } +} + +pub fn updateMessage(self: *Spinner, comptime fmt: []const u8, args: anytype) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.message.len > 0) { + self.allocator.free(self.message); + } + self.message = try std.fmt.allocPrint(self.allocator, fmt, args); +} + +pub fn succeed(self: *Spinner, comptime fmt: []const u8, args: anytype) !void { + self.stop(); + + self.mutex.lock(); + const msg = try std.fmt.allocPrint(self.allocator, fmt, args); + defer self.allocator.free(msg); + self.mutex.unlock(); + + switch (self.writer) { + .file => |*w| { + w.interface.writeAll(SHOW_CURSOR) catch {}; + try w.interface.print(GREEN ++ "✔" ++ RESET ++ " {s}\n", .{msg}); + try w.interface.flush(); + }, + .test_writer => |*w| { + w.writeAll(SHOW_CURSOR) catch {}; + try w.print(GREEN ++ "✔" ++ RESET ++ " {s}\n", .{msg}); + }, + } +} + +pub fn fail(self: *Spinner, comptime fmt: []const u8, args: anytype) !void { + self.stop(); + + self.mutex.lock(); + const msg = try std.fmt.allocPrint(self.allocator, fmt, args); + defer self.allocator.free(msg); + self.mutex.unlock(); + + switch (self.writer) { + .file => |*w| { + w.interface.writeAll(SHOW_CURSOR) catch {}; + try w.interface.print(RED ++ "✖" ++ RESET ++ " {s}\n", .{msg}); + try w.interface.flush(); + }, + .test_writer => |*w| { + w.writeAll(SHOW_CURSOR) catch {}; + try w.print(RED ++ "✖" ++ RESET ++ " {s}\n", .{msg}); + }, + } +} + +fn spinLoop(self: *Spinner) void { + var frame_idx: usize = 0; + + while (!self.should_stop.load(.acquire)) { + self.mutex.lock(); + const msg = self.message; + switch (self.writer) { + .file => |*w| { + w.interface.print(CLEAR_LINE ++ "{s} {s}", .{ frames[frame_idx], msg }) catch {}; + w.interface.flush() catch {}; + }, + .test_writer => |*w| { + w.print(CLEAR_LINE ++ "{s} {s}", .{ frames[frame_idx], msg }) catch {}; + }, + } + self.mutex.unlock(); + + frame_idx = (frame_idx + 1) % frames.len; + Thread.sleep(self.refresh_rate_ms * std.time.ns_per_ms); + } +} + +test "spinner outputs hide cursor on start" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer } }); + defer spinner.deinit(); + + try spinner.start("Processing", .{}); + Thread.sleep(50 * std.time.ns_per_ms); + spinner.stop(); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, HIDE_CURSOR) != null); +} + +test "spinner outputs show cursor on stop" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer } }); + defer spinner.deinit(); + + try spinner.start("Loading", .{}); + Thread.sleep(50 * std.time.ns_per_ms); + spinner.stop(); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, SHOW_CURSOR) != null); +} + +test "spinner outputs message and frames" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer }, .refresh_rate_ms = 30 }); + defer spinner.deinit(); + + try spinner.start("Loading {s}", .{"data"}); + Thread.sleep(150 * std.time.ns_per_ms); + spinner.stop(); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, "Loading data") != null); + try testing.expect(std.mem.indexOf(u8, output, CLEAR_LINE) != null); +} + +test "spinner succeed outputs green checkmark" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer } }); + defer spinner.deinit(); + + try spinner.start("Working", .{}); + Thread.sleep(50 * std.time.ns_per_ms); + try spinner.succeed("Done", .{}); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, "✔") != null); + try testing.expect(std.mem.indexOf(u8, output, GREEN) != null); + try testing.expect(std.mem.indexOf(u8, output, "Done") != null); +} + +test "spinner fail outputs red cross" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer } }); + defer spinner.deinit(); + + try spinner.start("Working", .{}); + Thread.sleep(50 * std.time.ns_per_ms); + try spinner.fail("Error occurred", .{}); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, "✖") != null); + try testing.expect(std.mem.indexOf(u8, output, RED) != null); + try testing.expect(std.mem.indexOf(u8, output, "Error occurred") != null); +} + +test "spinner updateMessage changes displayed text" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer }, .refresh_rate_ms = 30 }); + defer spinner.deinit(); + + try spinner.start("Step 1", .{}); + Thread.sleep(100 * std.time.ns_per_ms); + try spinner.updateMessage("Step 2", .{}); + Thread.sleep(100 * std.time.ns_per_ms); + spinner.stop(); + + const output = getTestOutput(&spinner); + try testing.expect(std.mem.indexOf(u8, output, "Step 1") != null); + try testing.expect(std.mem.indexOf(u8, output, "Step 2") != null); +} + +test "spinner can stop without starting" { + const testing = std.testing; + var spinner = Spinner.init(testing.allocator, .{}); + defer spinner.deinit(); + + spinner.stop(); + try testing.expect(spinner.should_stop.load(.acquire)); +} + +test "spinner multiple start/stop cycles work" { + const testing = std.testing; + + var buffer: [4096]u8 = undefined; + const test_writer = std.Io.Writer.fixed(&buffer); + + var spinner = Spinner.init(testing.allocator, .{ .writer = .{ .test_writer = test_writer } }); + defer spinner.deinit(); + + for (0..3) |i| { + try spinner.start("Cycle {d}", .{i}); + Thread.sleep(50 * std.time.ns_per_ms); + spinner.stop(); + } + + try testing.expect(spinner.thread == null); +} + +fn getTestOutput(spinner: *Spinner) []const u8 { + return switch (spinner.writer) { + .test_writer => |*w| w.buffer[0..w.end], + else => &.{}, + }; +} diff --git a/src/main.zig b/src/main.zig index da43f91..e0fdabf 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,9 +1,7 @@ const std = @import("std"); - const cli = @import("cli/root.zig"); pub const std_options: std.Options = .{ - // Set log level based on build mode .log_level = switch (@import("builtin").mode) { .Debug => .debug, .ReleaseSafe, .ReleaseFast, .ReleaseSmall => .warn, @@ -11,13 +9,14 @@ pub const std_options: std.Options = .{ }; pub fn main() !void { - const allocator = std.heap.smp_allocator; + var dbg = std.heap.DebugAllocator(.{}).init; - const file = std.fs.File.stdout(); - var writer = file.writerStreaming(&.{}).interface; + const allocator = switch (@import("builtin").mode) { + .Debug => dbg.allocator(), + .ReleaseFast, .ReleaseSafe, .ReleaseSmall => std.heap.smp_allocator, + }; - const root = try cli.build(&writer, allocator); - defer root.deinit(); + defer if (@import("builtin").mode == .Debug) std.debug.assert(dbg.deinit() == .ok); - try root.execute(.{}); + try cli.run(allocator); }