diff --git a/Cargo.lock b/Cargo.lock index aa2294e..ac194de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,6 +416,7 @@ dependencies = [ "hyper-util", "indicatif", "libc", + "liblzma", "nix", "openssl", "rcgen", @@ -429,7 +430,6 @@ dependencies = [ "tokio", "tokio-rustls", "wiremock", - "xz2", ] [[package]] @@ -963,6 +963,27 @@ dependencies = [ "windows-link", ] +[[package]] +name = "liblzma" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" +dependencies = [ + "liblzma-sys", + "num_cpus", +] + +[[package]] +name = "liblzma-sys" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a60851d15cd8c5346eca4ab8babff585be2ae4bc8097c067291d3ffe2add3b6" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "libredox" version = "0.1.12" @@ -1007,17 +1028,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "memchr" version = "2.7.6" @@ -2345,15 +2355,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "xz2" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] - [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 361a7f9..9a5773f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ nix = { version = "0.29", features = ["ioctl"] } # Enable vendored OpenSSL for cross-compilation to musl targets # This ensures OpenSSL builds from source with musl compatibility openssl = { version = "0.10", features = ["vendored"] } -xz2 = "0.1" +liblzma = { version = "0.4", features = ["parallel"] } [dev-dependencies] http-body = "1.0.1" diff --git a/src/fls/decompress.rs b/src/fls/decompress.rs index 6084c35..0415c0a 100644 --- a/src/fls/decompress.rs +++ b/src/fls/decompress.rs @@ -1,3 +1,8 @@ +use crate::fls::byte_channel::ByteBoundedReceiver; +use crate::fls::compression::Compression; +use crate::fls::stream_utils::ChannelReader; +use bytes::Bytes; +use std::io::Read; use tokio::io::AsyncReadExt; use tokio::process::{Child, Command}; use tokio::sync::mpsc; @@ -75,6 +80,78 @@ fn spawn_decompressor( Ok((process, cmd)) } +pub(crate) fn get_compression_from_url(url: &str) -> Compression { + let path = url.split('?').next().unwrap_or(url); + let path = path.split('#').next().unwrap_or(path); + let extension = path.rsplit('.').next().unwrap_or("").to_lowercase(); + match extension.as_str() { + "gz" => Compression::Gzip, + "xz" => Compression::Xz, + "zst" | "zstd" => Compression::Zstd, + _ => Compression::None, + } +} + +type DecompressorResult = ( + mpsc::Receiver>, + std::thread::JoinHandle>, +); + +pub(crate) fn start_inprocess_decompressor( + buffer_rx: ByteBoundedReceiver, + compression: Compression, + consumed_progress_tx: mpsc::UnboundedSender, +) -> Result> { + let (decompressed_tx, decompressed_rx) = mpsc::channel::>(8); + + let handle = std::thread::Builder::new() + .name("decompressor".to_string()) + .spawn(move || { + let channel_reader = + ChannelReader::new_byte_bounded(buffer_rx).with_progress(consumed_progress_tx); + + let mut decoder: Box = match compression { + Compression::Xz => { + let num_threads = std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(2); + eprintln!("XZ decompression: using {} threads", num_threads); + let stream = liblzma::stream::MtStreamBuilder::new() + .threads(num_threads) + .memlimit_threading(u64::MAX) + .memlimit_stop(u64::MAX) + .decoder() + .map_err(|e| format!("Failed to create MT XZ decoder: {}", e))?; + Box::new(liblzma::read::XzDecoder::new_stream(channel_reader, stream)) + } + Compression::Gzip => Box::new(flate2::read::GzDecoder::new(channel_reader)), + Compression::None => Box::new(channel_reader), + Compression::Zstd => { + return Err("Zstd in-process decompression is not supported".to_string()); + } + }; + + let mut buf = vec![0u8; 8 * 1024 * 1024]; + loop { + let n = decoder + .read(&mut buf) + .map_err(|e| format!("Decompression error: {}", e))?; + if n == 0 { + break; + } + if decompressed_tx.blocking_send(buf[..n].to_vec()).is_err() { + return Err("Writer task closed, stopping decompression".to_string()); + } + } + Ok(()) + }) + .map_err(|e| -> Box { + format!("Failed to spawn decompressor thread: {}", e).into() + })?; + + Ok((decompressed_rx, handle)) +} + pub(crate) async fn spawn_stderr_reader( mut stderr: tokio::process::ChildStderr, error_tx: mpsc::UnboundedSender, diff --git a/src/fls/from_url.rs b/src/fls/from_url.rs index fc03f85..f6c726e 100644 --- a/src/fls/from_url.rs +++ b/src/fls/from_url.rs @@ -1,12 +1,12 @@ +use futures_util::StreamExt; use std::io; use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; use tokio::task::JoinHandle; use crate::fls::block_writer::AsyncBlockWriter; use crate::fls::byte_channel::byte_bounded_channel; -use crate::fls::decompress::{spawn_stderr_reader, start_decompressor_process}; +use crate::fls::decompress::{get_compression_from_url, start_inprocess_decompressor}; use crate::fls::download_error::DownloadError; use crate::fls::error_handling::process_error_messages; use crate::fls::format_detector::{DetectionResult, FileFormat, FormatDetector}; @@ -36,19 +36,6 @@ async fn get_writer_error(handle: JoinHandle>) -> Box>, -) -> Box { - match handle.await { - Ok(Ok(())) => "Decompressor stdin closed unexpectedly".into(), - Ok(Err(e)) => e.into(), - Err(e) => format!("Decompressor writer task panicked: {}", e).into(), - } -} - use crate::fls::download_error::handle_download_retry; /// Execute a sequence of write commands on the block writer @@ -179,12 +166,11 @@ pub async fn flash_from_url( let http_options: HttpClientOptions = (&options).into(); let client = setup_http_client(&http_options).await?; - let (mut decompressor, decompressor_name) = start_decompressor_process(url).await?; - - // Extract stdio handles - let mut decompressor_stdin = decompressor.stdin.take().unwrap(); - let decompressor_stdout = decompressor.stdout.take().unwrap(); - let decompressor_stderr = decompressor.stderr.take().unwrap(); + let compression = get_compression_from_url(url); + let is_compressed = compression != crate::fls::compression::Compression::None; + if is_compressed { + eprintln!("Using decompressor: {} (in-process)", compression); + } // Create channels let (decompressed_progress_tx, mut decompressed_progress_rx) = mpsc::unbounded_channel::(); @@ -205,24 +191,39 @@ pub async fn flash_from_url( options.common.write_buffer_size_mb, )?; - // Spawn background task to read from decompressor and write to block device + // Create byte-bounded download buffer + let buffer_size_mb = options.common.buffer_size_mb; + let max_buffer_bytes = buffer_size_mb * 1024 * 1024; + + println!( + "Using download buffer: {} MB (byte-bounded)", + buffer_size_mb + ); + + let (buffer_tx, buffer_rx) = byte_bounded_channel::(max_buffer_bytes, 4096); + + // Channel for tracking bytes consumed from buffer by decompressor + let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) = + mpsc::unbounded_channel::(); + + // Start in-process decompressor thread + let (mut decompressed_rx, decompressor_handle) = + start_inprocess_decompressor(buffer_rx, compression, decompressor_written_progress_tx)?; + + // Spawn background task to read decompressed data and write to block device let error_tx_clone = error_tx.clone(); let debug = options.common.debug; let writer_handle = { let writer = block_writer; tokio::spawn(async move { - let mut stdout = decompressor_stdout; - let mut buffer = vec![0u8; 8 * 1024 * 1024]; // 8MB buffer - - // Auto-detect sparse image format from initial data let mut detector = FormatDetector::new(); let mut parser: Option = None; let mut format_determined = false; loop { - let n = match stdout.read(&mut buffer).await { - Ok(0) => { - // EOF - check if we have incomplete format detection + let data = match decompressed_rx.recv().await { + Some(data) => data, + None => { if !format_determined { if let Some(buffered_data) = detector.finalize_at_eof() { if debug { @@ -236,20 +237,15 @@ pub async fn flash_from_url( } break; } - Ok(n) => n, - Err(e) => { - let _ = error_tx_clone - .send(format!("Error reading from decompressor stdout: {}", e)); - return Err(e); - } }; + let n = data.len(); if decompressed_progress_tx.send(n as u64).is_err() { break; } if !format_determined { - match detector.process(&buffer[..n]) { + match detector.process(&data) { DetectionResult::NeedMoreData => { if debug { eprintln!( @@ -263,7 +259,7 @@ pub async fn flash_from_url( consumed_bytes, consumed_from_input, } => { - let remaining = &buffer[consumed_from_input..n]; + let remaining = &data[consumed_from_input..n]; parser = handle_detected_format( format, consumed_bytes, @@ -281,81 +277,34 @@ pub async fn flash_from_url( } } } else if let Some(ref mut p) = parser { - process_sparse_data(p, &buffer[..n], &writer, &error_tx_clone, debug).await?; + process_sparse_data(p, &data, &writer, &error_tx_clone, debug).await?; } else { - write_regular_data(buffer[..n].to_vec(), &writer, &error_tx_clone).await?; + write_regular_data(data, &writer, &error_tx_clone).await?; } } - // Close writer and get final bytes written writer.close().await }) }; - tokio::spawn(spawn_stderr_reader( - decompressor_stderr, - error_tx.clone(), - decompressor_name, - )); - // Spawn message processors let error_processor = tokio::spawn(process_error_messages(error_rx)); // Main download loop with retry logic let mut progress = ProgressTracker::new(options.common.newline_progress, options.common.show_memory); - // Set whether we're actually decompressing (not using cat for uncompressed files) - progress.set_is_compressed(decompressor_name != "cat"); + progress.set_is_compressed(is_compressed); let update_interval = Duration::from_secs_f64(options.common.progress_interval_secs); let mut bytes_sent_to_decompressor: u64 = 0; let mut retry_count = 0; let debug = options.common.debug; - use futures_util::StreamExt; - - // Create byte-bounded download buffer (shared across all retry attempts) - let buffer_size_mb = options.common.buffer_size_mb; - let max_buffer_bytes = buffer_size_mb * 1024 * 1024; - - println!( - "Using download buffer: {} MB (byte-bounded)", - buffer_size_mb - ); - - // Create persistent byte-bounded channel for download buffering (lives across retries) - // max_items=4096 prevents unbounded item queuing; byte budget is the real bound - let (buffer_tx, mut buffer_rx) = byte_bounded_channel::(max_buffer_bytes, 4096); - - // Channels for tracking bytes actually written to decompressor - let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) = - mpsc::unbounded_channel::(); - - // Spawn persistent task to write buffered chunks to decompressor - let decompressor_writer_handle = tokio::spawn(async move { - while let Some(chunk) = buffer_rx.recv().await { - let chunk_len = chunk.len() as u64; - if let Err(e) = decompressor_stdin.write_all(&chunk).await { - return Err(format!("Error writing to decompressor stdin: {}", e)); - } - // Notify that bytes were written to decompressor - let _ = decompressor_written_progress_tx.send(chunk_len); - } - // Close decompressor stdin when channel is closed - Ok::<(), String>(()) - }); - loop { - // Check if writer or decompressor has failed before attempting download/retry if writer_handle.is_finished() { eprintln!(); eprintln!("Writer task has terminated, stopping download"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!("Decompressor writer task has terminated, stopping download"); - return Err(get_decompressor_error(decompressor_writer_handle).await); - } // Resume from the HTTP download position, not the decompressor write position // The buffer may contain data that's been downloaded but not yet written to decompressor @@ -414,21 +363,11 @@ pub async fn flash_from_url( // Send to buffer - detect if it's blocking let send_start = std::time::Instant::now(); if buffer_tx.send(chunk).await.is_err() { - // Check if writer or decompressor has failed if writer_handle.is_finished() { eprintln!(); eprintln!("Writer task has terminated unexpectedly"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!( - "Decompressor writer task has terminated unexpectedly" - ); - return Err( - get_decompressor_error(decompressor_writer_handle).await - ); - } connection_error = Some(DownloadError::Other("Buffer channel closed".to_string())); connection_broken = true; @@ -500,17 +439,11 @@ pub async fn flash_from_url( // If connection broke, retry if connection_broken { - // First check if writer or decompressor has failed - if so, don't retry if writer_handle.is_finished() { eprintln!(); eprintln!("Connection interrupted and writer task has terminated"); return Err(get_writer_error(writer_handle).await); } - if decompressor_writer_handle.is_finished() { - eprintln!(); - eprintln!("Connection interrupted and decompressor writer task has terminated"); - return Err(get_decompressor_error(decompressor_writer_handle).await); - } if let Some(e) = connection_error { eprintln!("\nConnection interrupted: {}", e.format_error()); @@ -554,9 +487,8 @@ pub async fn flash_from_url( // Close buffer channel to signal end of download drop(buffer_tx); - // Poll for decompressor writer completion while showing progress + // Wait for decompressor thread to finish while showing progress loop { - // Update progress from all channels let mut updated = false; while let Ok(written_len) = decompressor_written_progress_rx.try_recv() { @@ -578,98 +510,29 @@ pub async fn flash_from_url( let _ = progress.update_progress(None, update_interval, false); } - // Check if decompressor writer task is done - if decompressor_writer_handle.is_finished() { + if decompressor_handle.is_finished() { break; } - // Small sleep to avoid busy waiting tokio::time::sleep(Duration::from_millis(100)).await; } - // Get the result from the decompressor writer task - if let Err(e) = decompressor_writer_handle - .await - .map_err(|e| e.to_string()) - .and_then(|r| r) - { - eprintln!(); - return Err(e.into()); - } - - // Update any remaining progress - while let Ok(byte_count) = decompressed_progress_rx.try_recv() { - progress.bytes_decompressed += byte_count; - } - - // Check if decompressor has already finished - let decompressor_already_done = match decompressor.try_wait() { - Ok(Some(status)) => { - if !status.success() { - eprintln!(); - return Err(format!( - "{} process failed with status: {:?}", - decompressor_name, - status.code() - ) - .into()); - } - true - } - Ok(None) => false, - Err(e) => { + match decompressor_handle.join() { + Ok(Ok(())) => {} + Ok(Err(e)) => { eprintln!(); return Err(e.into()); } - }; - - // Only wait if decompressor is not already done - if !decompressor_already_done { - // Poll for decompressor completion while showing progress - loop { - // Update progress from channels - let mut updated = false; - - while let Ok(byte_count) = decompressed_progress_rx.try_recv() { - progress.bytes_decompressed += byte_count; - updated = true; - } - - while let Ok(written_bytes) = written_progress_rx.try_recv() { - progress.bytes_written = written_bytes; - updated = true; - } - - if updated { - let _ = progress.update_progress(None, update_interval, false); - } - - // Check if decompressor is done (non-blocking check) - match decompressor.try_wait() { - Ok(Some(status)) => { - if !status.success() { - eprintln!(); - return Err(format!( - "{} process failed with status: {:?}", - decompressor_name, - status.code() - ) - .into()); - } - break; - } - Ok(None) => { - // Still running, sleep briefly - tokio::time::sleep(Duration::from_millis(100)).await; - } - Err(e) => { - eprintln!(); - return Err(e.into()); - } - } + Err(_) => { + eprintln!(); + return Err("Decompressor thread panicked".into()); } } + while let Ok(byte_count) = decompressed_progress_rx.try_recv() { + progress.bytes_decompressed += byte_count; + } + // Capture the decompression rate and duration at completion let elapsed = progress.start_time.elapsed(); progress.decompress_duration = Some(elapsed); diff --git a/src/fls/magic_bytes.rs b/src/fls/magic_bytes.rs index dc2d411..fb20894 100644 --- a/src/fls/magic_bytes.rs +++ b/src/fls/magic_bytes.rs @@ -143,8 +143,8 @@ fn decompress_gzip_sample(data: &[u8]) -> Result, String> { /// Decompress a sample of XZ data to analyze content type fn decompress_xz_sample(data: &[u8]) -> Result, String> { + use liblzma::read::XzDecoder; use std::io::Read; - use xz2::read::XzDecoder; let mut decoder = XzDecoder::new(data); let mut buffer = vec![0u8; 8192]; // Decompress up to 8KB diff --git a/src/fls/oci/from_oci.rs b/src/fls/oci/from_oci.rs index 4ab101d..571bbe3 100644 --- a/src/fls/oci/from_oci.rs +++ b/src/fls/oci/from_oci.rs @@ -10,10 +10,10 @@ use std::time::Duration; use bytes::Bytes; use flate2::read::GzDecoder; use futures_util::StreamExt; +use liblzma::read::XzDecoder; use reqwest::StatusCode; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc; -use xz2::read::XzDecoder; use crate::fls::byte_channel::{byte_bounded_channel, ByteBoundedReceiver, ByteBoundedSender}; use crate::fls::decompress::start_decompressor_for_compression; diff --git a/src/fls/stream_utils.rs b/src/fls/stream_utils.rs index e6807ab..0193cc5 100644 --- a/src/fls/stream_utils.rs +++ b/src/fls/stream_utils.rs @@ -1,36 +1,34 @@ use crate::fls::byte_channel::ByteBoundedReceiver; -/// Stream utilities for async/sync bridging and download handling -/// -/// Provides reusable components for streaming data between async HTTP -/// downloads and sync processing (like tar extraction or decompression). use bytes::Bytes; use std::io::Read; +use tokio::sync::mpsc; -/// Reader that pulls bytes from a byte-bounded channel -/// -/// This bridges async HTTP streaming with synchronous readers -/// like tar::Archive or flate2::GzDecoder. pub struct ChannelReader { rx: ByteBoundedReceiver, current: Option, offset: usize, + progress_tx: Option>, } impl ChannelReader { - /// Create a new ChannelReader from a byte-bounded receiver pub fn new_byte_bounded(rx: ByteBoundedReceiver) -> Self { Self { rx, current: None, offset: 0, + progress_tx: None, } } + + pub fn with_progress(mut self, tx: mpsc::UnboundedSender) -> Self { + self.progress_tx = Some(tx); + self + } } impl Read for ChannelReader { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { loop { - // If we have current data, use it if let Some(ref data) = self.current { let remaining = &data[self.offset..]; if !remaining.is_empty() { @@ -41,14 +39,15 @@ impl Read for ChannelReader { } } - // Need more data - blocking receive match self.rx.blocking_recv() { Some(data) => { + if let Some(ref tx) = self.progress_tx { + let _ = tx.send(data.len() as u64); + } self.current = Some(data); self.offset = 0; } None => { - // Channel closed - EOF return Ok(0); } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 5d62adb..b8813f1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,8 +1,8 @@ // Shared test utilities use flate2::write::GzEncoder; use flate2::Compression; +use liblzma::write::XzEncoder; use std::io::Write; -use xz2::write::XzEncoder; /// Generate deterministic test data of a given size #[allow(dead_code)]