diff --git a/Cargo.lock b/Cargo.lock index 3054b0c8..f48028ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -274,6 +274,18 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-compression" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" +dependencies = [ + "compression-codecs", + "compression-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-executor" version = "1.13.3" @@ -945,6 +957,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" +dependencies = [ + "compression-core", + "flate2", + "memchr", +] + +[[package]] +name = "compression-core" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -1733,9 +1762,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.1" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1812,9 +1841,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -1827,9 +1856,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -1837,15 +1866,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -1854,9 +1883,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-lite" @@ -1873,9 +1902,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -1884,21 +1913,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = 
"c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -1908,7 +1937,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -4132,12 +4160,6 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - [[package]] name = "piper" version = "0.2.4" @@ -4674,6 +4696,7 @@ checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" name = "rayhunter" version = "0.10.2" dependencies = [ + "async-compression", "bytes", "chrono", "crc", diff --git a/check/src/main.rs b/check/src/main.rs index 87105341..cf7c02a2 100644 --- a/check/src/main.rs +++ b/check/src/main.rs @@ -113,12 +113,8 @@ async fn analyze_pcap(pcap_path: &str, show_skipped: bool) { async fn analyze_qmdl(qmdl_path: &str, show_skipped: bool) { let mut harness = Harness::new_with_config(&AnalyzerConfig::default()); let qmdl_file = &mut File::open(&qmdl_path).await.expect("failed to open file"); - let file_size = qmdl_file - .metadata() - .await - .expect("failed to get QMDL file metadata") - .len(); - let mut 
qmdl_reader = QmdlReader::new(qmdl_file, Some(file_size as usize)); + let compressed = qmdl_path.ends_with(".gz"); + let qmdl_reader = QmdlReader::new(qmdl_file, compressed, None); let mut qmdl_stream = pin!( qmdl_reader .as_stream() @@ -141,8 +137,9 @@ async fn pcapify(qmdl_path: &PathBuf) { let qmdl_file = &mut File::open(&qmdl_path) .await .expect("failed to open qmdl file"); + let compressed = qmdl_path.extension().is_some_and(|ext| ext == "gz"); let qmdl_file_size = qmdl_file.metadata().await.unwrap().len(); - let mut qmdl_reader = QmdlReader::new(qmdl_file, Some(qmdl_file_size as usize)); + let mut qmdl_reader = QmdlReader::new(qmdl_file, compressed, Some(qmdl_file_size as usize)); let mut pcap_path = qmdl_path.clone(); pcap_path.set_extension("pcapng"); let pcap_file = &mut File::create(&pcap_path) @@ -197,9 +194,7 @@ async fn main() { let name_str = name.to_str().unwrap(); let path = entry.path(); let path_str = path.to_str().unwrap(); - // instead of relying on the QMDL extension, can we check if a file is - // QMDL by inspecting the contents? 
- if name_str.ends_with(".qmdl") { + if name_str.ends_with(".qmdl") || name_str.ends_with(".qmdl.gz") { info!("**** Beginning analysis of {name_str}"); analyze_qmdl(path_str, args.show_skipped).await; if args.pcapify { diff --git a/daemon/Cargo.toml b/daemon/Cargo.toml index dd487545..ab95ceac 100644 --- a/daemon/Cargo.toml +++ b/daemon/Cargo.toml @@ -33,7 +33,7 @@ futures-macro = "0.3.30" include_dir = "0.7.3" chrono = { version = "0.4.31", features = ["serde"] } tokio-stream = { version = "0.1.14", default-features = false, features = ["io-util"] } -futures = { version = "0.3.30", default-features = false } +futures = { version = "0.3.32", default-features = false, features = ["std"] } serde_json = "1.0.114" image = { version = "0.25.1", default-features = false, features = ["png", "gif"] } tempfile = "3.10.2" diff --git a/daemon/src/analysis.rs b/daemon/src/analysis.rs index 48c29b94..e7c9a78e 100644 --- a/daemon/src/analysis.rs +++ b/daemon/src/analysis.rs @@ -10,7 +10,6 @@ use futures::TryStreamExt; use log::{error, info}; use rayhunter::analysis::analyzer::{AnalyzerConfig, EventType, Harness}; use rayhunter::diag::{DataType, MessagesContainer}; -use rayhunter::qmdl::QmdlReader; use serde::Serialize; use tokio::fs::File; use tokio::io::{AsyncWriteExt, BufWriter}; @@ -135,7 +134,7 @@ async fn perform_analysis( analyzer_config: &AnalyzerConfig, ) -> Result<(), String> { info!("Opening QMDL and analysis file for {name}..."); - let (analysis_file, qmdl_file) = { + let (analysis_file, qmdl_reader) = { let mut qmdl_store = qmdl_store_lock.write().await; let (entry_index, _) = qmdl_store .entry_for_name(name) @@ -144,23 +143,17 @@ async fn perform_analysis( .clear_and_open_entry_analysis(entry_index) .await .map_err(|e| format!("{e:?}"))?; - let qmdl_file = qmdl_store + let qmdl_reader = qmdl_store .open_entry_qmdl(entry_index) .await .map_err(|e| format!("{e:?}"))?; - (analysis_file, qmdl_file) + (analysis_file, qmdl_reader) }; let mut analysis_writer = 
AnalysisWriter::new(analysis_file, analyzer_config) .await .map_err(|e| format!("{e:?}"))?; - let file_size = qmdl_file - .metadata() - .await - .expect("failed to get QMDL file metadata") - .len(); - let mut qmdl_reader = QmdlReader::new(qmdl_file, Some(file_size as usize)); let mut qmdl_stream = pin::pin!( qmdl_reader .as_stream() diff --git a/daemon/src/diag.rs b/daemon/src/diag.rs index 3453c9c0..e5da9dbc 100644 --- a/daemon/src/diag.rs +++ b/daemon/src/diag.rs @@ -63,7 +63,7 @@ pub struct DiagTask { enum DiagState { Recording { - qmdl_writer: QmdlWriter, + qmdl_writer: Box>, analysis_writer: Box, }, Stopped, @@ -143,7 +143,7 @@ impl DiagTask { DiskSpaceCheck::Failed => {} } - let (qmdl_file, analysis_file) = match qmdl_store.new_entry().await { + let (qmdl_gz_file, analysis_file) = match qmdl_store.new_entry().await { Ok(files) => files, Err(e) => { let msg = format!("failed creating QMDL file entry: {e}"); @@ -152,7 +152,7 @@ impl DiagTask { } }; self.stop_current_recording().await; - let qmdl_writer = QmdlWriter::new(qmdl_file); + let qmdl_writer = Box::new(QmdlWriter::new(qmdl_gz_file)); let analysis_writer = match AnalysisWriter::new(analysis_file, &self.analyzer_config).await { Ok(writer) => Box::new(writer), @@ -237,13 +237,23 @@ impl DiagTask { let mut state = DiagState::Stopped; std::mem::swap(&mut self.state, &mut state); if let DiagState::Recording { - analysis_writer, .. + qmdl_writer, + analysis_writer, + .. 
} = state { - analysis_writer - .close() - .await - .expect("failed to close analysis writer"); + match (qmdl_writer.close().await, analysis_writer.close().await) { + (Ok(()), Ok(())) => {} + (qmdl_result, analysis_result) => { + if let Err(err) = qmdl_result { + error!("failed to close QmdlWriter: {:?}", err); + } + if let Err(err) = analysis_result { + error!("failed to close AnalysisWriter: {:?}", err); + } + panic!(); + } + } } } @@ -315,13 +325,13 @@ impl DiagTask { } debug!( "total QMDL bytes written: {}, updating manifest...", - qmdl_writer.total_written + qmdl_writer.total_uncompressed_bytes ); let index = qmdl_store .current_entry .expect("DiagDevice had qmdl_writer, but QmdlStore didn't have current entry???"); if let Err(e) = qmdl_store - .update_entry_qmdl_size(index, qmdl_writer.total_written) + .update_entry_qmdl_size(index, qmdl_writer.total_uncompressed_bytes) .await { let reason = format!("failed to update manifest (disk full?): {e}"); diff --git a/daemon/src/pcap.rs b/daemon/src/pcap.rs index fce37d64..93b29ccc 100644 --- a/daemon/src/pcap.rs +++ b/daemon/src/pcap.rs @@ -45,23 +45,20 @@ pub async fn get_pcap( StatusCode::NOT_FOUND, format!("couldn't find manifest entry with name {qmdl_name}"), ))?; - if entry.qmdl_size_bytes == 0 { + if entry.uncompressed_qmdl_size_bytes == 0 { return Err(( StatusCode::SERVICE_UNAVAILABLE, "QMDL file is empty, try again in a bit!".to_string(), )); } - let qmdl_size_bytes = entry.qmdl_size_bytes; - let qmdl_file = qmdl_store + let qmdl_reader = qmdl_store .open_entry_qmdl(entry_index) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")))?; - // the QMDL reader should stop at the last successfully written data chunk - // (entry.size_bytes) let (reader, writer) = duplex(1024); tokio::spawn(async move { - if let Err(e) = generate_pcap_data(writer, qmdl_file, qmdl_size_bytes).await { + if let Err(e) = generate_pcap_data(writer, qmdl_reader).await { error!("failed to generate PCAP: {e:?}"); } }); @@ 
-71,11 +68,7 @@ pub async fn get_pcap( Ok((headers, body).into_response()) } -pub async fn generate_pcap_data( - writer: W, - qmdl_file: R, - qmdl_size_bytes: usize, -) -> Result<(), Error> +pub async fn generate_pcap_data(writer: W, mut reader: QmdlReader) -> Result<(), Error> where W: AsyncWrite + Unpin + Send, R: AsyncRead + Unpin, @@ -83,7 +76,6 @@ where let mut pcap_writer = GsmtapPcapWriter::new(writer).await?; pcap_writer.write_iface_header().await?; - let mut reader = QmdlReader::new(qmdl_file, Some(qmdl_size_bytes)); while let Some(container) = reader.get_next_messages_container().await? { if container.data_type != DataType::UserSpace { continue; diff --git a/daemon/src/qmdl_store.rs b/daemon/src/qmdl_store.rs index acc48263..ac521826 100644 --- a/daemon/src/qmdl_store.rs +++ b/daemon/src/qmdl_store.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use chrono::{DateTime, Local}; use log::{info, warn}; +use rayhunter::qmdl::QmdlReader; use rayhunter::util::RuntimeMetadata; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -57,8 +58,10 @@ pub struct ManifestEntry { /// The system time when the last message was recorded to the file #[cfg_attr(feature = "apidocs", schema(value_type = String))] pub last_message_time: Option>, - /// The size of the QMDL file in bytes - pub qmdl_size_bytes: usize, + /// The size of the uncompressed QMDL data in bytes. Previously this was + /// called `qmdl_size_bytes`, so alias it for backwards compatibility. 
+ #[serde(alias = "qmdl_size_bytes")] + pub uncompressed_qmdl_size_bytes: usize, /// The rayhunter daemon version which generated the file pub rayhunter_version: Option, /// The OS which created the file @@ -67,6 +70,8 @@ pub struct ManifestEntry { pub arch: Option, #[serde(default)] pub stop_reason: Option, + #[serde(default)] + pub compressed: bool, } impl ManifestEntry { @@ -77,17 +82,22 @@ impl ManifestEntry { name: format!("{}", now.timestamp()), start_time: now, last_message_time: None, - qmdl_size_bytes: 0, + uncompressed_qmdl_size_bytes: 0, rayhunter_version: Some(metadata.rayhunter_version), system_os: Some(metadata.system_os), arch: Some(metadata.arch), stop_reason: None, + compressed: true, } } pub fn get_qmdl_filepath>(&self, path: P) -> PathBuf { let mut filepath = path.as_ref().join(&self.name); - filepath.set_extension("qmdl"); + if self.compressed { + filepath.set_extension("qmdl.gz"); + } else { + filepath.set_extension("qmdl"); + } filepath } @@ -153,8 +163,9 @@ impl RecordingStore { } // Does a best-effort attempt to recover the manifest from a directory of - // QMDL files. We expect these files to be named like ".qmdl", - // and skip any files which don't match that pattern. + // QMDL files. We expect these files to be named like ".qmdl" + // or ".qmdl.gz", and skip any files which don't match that + // pattern. pub async fn recover

(path: P) -> Result where P: AsRef, @@ -174,11 +185,14 @@ impl RecordingStore { continue; }; - if !filename.ends_with(".qmdl") { + let (stem, compressed) = if filename.ends_with(".qmdl") { + (filename.trim_end_matches(".qmdl"), false) + } else if filename.ends_with(".qmdl.gz") { + (filename.trim_end_matches(".qmdl.gz"), true) + } else { continue; - } + }; - let stem = filename.trim_end_matches(".qmdl"); let Ok(start_timestamp) = stem.parse::() else { warn!("QMDL file has invalid name {os_filename:?}, skipping"); continue; @@ -205,9 +219,10 @@ impl RecordingStore { info!("successfully recovered QMDL entry {os_filename:?}!"); manifest_entries.push(ManifestEntry { name: stem.to_string(), + compressed, start_time: start_time.into(), last_message_time: Some(last_message_time.into()), - qmdl_size_bytes: metadata.size() as usize, + uncompressed_qmdl_size_bytes: metadata.size() as usize, rayhunter_version: None, system_os: None, arch: None, @@ -265,11 +280,19 @@ impl RecordingStore { } // Returns the corresponding QMDL file for a given entry - pub async fn open_entry_qmdl(&self, entry_index: usize) -> Result { + pub async fn open_entry_qmdl( + &self, + entry_index: usize, + ) -> Result, RecordingStoreError> { let entry = &self.manifest.entries[entry_index]; - File::open(entry.get_qmdl_filepath(&self.path)) + let file = File::open(entry.get_qmdl_filepath(&self.path)) .await - .map_err(RecordingStoreError::ReadFileError) + .map_err(RecordingStoreError::ReadFileError)?; + Ok(QmdlReader::new( + file, + entry.compressed, + Some(entry.uncompressed_qmdl_size_bytes), + )) } // Returns the corresponding QMDL file for a given entry @@ -314,7 +337,7 @@ impl RecordingStore { entry_index: usize, size_bytes: usize, ) -> Result<(), RecordingStoreError> { - self.manifest.entries[entry_index].qmdl_size_bytes = size_bytes; + self.manifest.entries[entry_index].uncompressed_qmdl_size_bytes = size_bytes; self.manifest.entries[entry_index].last_message_time = 
Some(rayhunter::clock::get_adjusted_now()); self.write_manifest().await @@ -490,7 +513,10 @@ mod tests { .entry_for_name(&store.manifest.entries[entry_index].name) .unwrap(); assert!(entry.last_message_time.is_some()); - assert_eq!(store.manifest.entries[entry_index].qmdl_size_bytes, 1000); + assert_eq!( + store.manifest.entries[entry_index].uncompressed_qmdl_size_bytes, + 1000 + ); assert_eq!( RecordingStore::read_manifest(dir.path()).await.unwrap(), store.manifest diff --git a/daemon/src/server.rs b/daemon/src/server.rs index 82696db2..41768fdd 100644 --- a/daemon/src/server.rs +++ b/daemon/src/server.rs @@ -14,7 +14,8 @@ use log::{error, warn}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::fs::write; -use tokio::io::{AsyncReadExt, copy, duplex}; +use tokio::io::copy; +use tokio::io::duplex; use tokio::sync::RwLock; use tokio::sync::mpsc::Sender; use tokio_util::compat::FuturesAsyncWriteCompatExt; @@ -64,7 +65,7 @@ pub async fn get_qmdl( StatusCode::NOT_FOUND, format!("couldn't find qmdl file with name {qmdl_idx}"), ))?; - let qmdl_file = qmdl_store + let qmdl_reader = qmdl_store .open_entry_qmdl(entry_index) .await .map_err(|err| { @@ -73,14 +74,15 @@ pub async fn get_qmdl( format!("error opening QMDL file: {err}"), ) })?; - let limited_qmdl_file = qmdl_file.take(entry.qmdl_size_bytes as u64); - let qmdl_stream = ReaderStream::new(limited_qmdl_file); let headers = [ (CONTENT_TYPE, "application/octet-stream"), - (CONTENT_LENGTH, &entry.qmdl_size_bytes.to_string()), + ( + CONTENT_LENGTH, + &entry.uncompressed_qmdl_size_bytes.to_string(), + ), ]; - let body = Body::from_stream(qmdl_stream); + let body = Body::from_stream(qmdl_reader.as_stream()); Ok((headers, body).into_response()) } @@ -308,21 +310,21 @@ pub async fn get_zip( Path(entry_name): Path, ) -> Result { let qmdl_idx = entry_name.trim_end_matches(".zip").to_owned(); - let (entry_index, qmdl_size_bytes) = { + let (entry_index, compressed) = { let qmdl_store = 
state.qmdl_store_lock.read().await; let (entry_index, entry) = qmdl_store.entry_for_name(&qmdl_idx).ok_or(( StatusCode::NOT_FOUND, format!("couldn't find entry with name {qmdl_idx}"), ))?; - if entry.qmdl_size_bytes == 0 { + if entry.uncompressed_qmdl_size_bytes == 0 { return Err(( StatusCode::SERVICE_UNAVAILABLE, "QMDL file is empty, try again in a bit!".to_string(), )); } - (entry_index, entry.qmdl_size_bytes) + (entry_index, entry.compressed) }; let qmdl_store_lock = state.qmdl_store_lock.clone(); @@ -335,22 +337,18 @@ pub async fn get_zip( // Add QMDL file { - let entry = - ZipEntryBuilder::new(format!("{qmdl_idx}.qmdl").into(), Compression::Stored); + let extension = if compressed { "qmdl.gz" } else { "qmdl" }; + let entry = ZipEntryBuilder::new( + format!("{qmdl_idx}.{extension}").into(), + Compression::Stored, + ); // FuturesAsyncWriteCompatExt::compat_write because async-zip's entrystream does // not impl tokio's AsyncWrite, but only future's AsyncWrite. This can be removed // once https://github.com/Majored/rs-async-zip/pull/160 is released. let mut entry_writer = zip.write_entry_stream(entry).await?.compat_write(); - - let mut qmdl_file = { - let qmdl_store = qmdl_store_lock.read().await; - qmdl_store - .open_entry_qmdl(entry_index) - .await? - .take(qmdl_size_bytes as u64) - }; - - copy(&mut qmdl_file, &mut entry_writer).await?; + let qmdl_store = qmdl_store_lock.read().await; + let mut qmdl_reader = qmdl_store.open_entry_qmdl(entry_index).await?; + copy(&mut qmdl_reader, &mut entry_writer).await?; entry_writer.into_inner().close().await?; } @@ -360,17 +358,10 @@ pub async fn get_zip( ZipEntryBuilder::new(format!("{qmdl_idx}.pcapng").into(), Compression::Stored); let mut entry_writer = zip.write_entry_stream(entry).await?.compat_write(); - let qmdl_file_for_pcap = { - let qmdl_store = qmdl_store_lock.read().await; - qmdl_store - .open_entry_qmdl(entry_index) - .await? 
- .take(qmdl_size_bytes as u64) - }; - - if let Err(e) = - generate_pcap_data(&mut entry_writer, qmdl_file_for_pcap, qmdl_size_bytes).await - { + let qmdl_store = qmdl_store_lock.read().await; + let qmdl_reader = qmdl_store.open_entry_qmdl(entry_index).await?; + + if let Err(e) = generate_pcap_data(&mut entry_writer, qmdl_reader).await { // if we fail to generate the PCAP file, we should still continue and give the // user the QMDL. error!("Failed to generate PCAP: {e:?}"); @@ -530,7 +521,10 @@ mod tests { assert_eq!( filenames, - vec![format!("{entry_name}.qmdl"), format!("{entry_name}.pcapng"),] + vec![ + format!("{entry_name}.qmdl.gz"), + format!("{entry_name}.pcapng"), + ] ); } } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index a945336f..f5cacfe8 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -30,5 +30,6 @@ serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0" num_enum = "0.7.4" utoipa = { version = "5.4.0", optional = true } +async-compression = { version = "0.4.41", features = ["tokio", "gzip"] } [dev-dependencies] diff --git a/lib/src/diag.rs b/lib/src/diag.rs index fcdb33b6..2cd2e0b9 100644 --- a/lib/src/diag.rs +++ b/lib/src/diag.rs @@ -1,5 +1,6 @@ //! Diag protocol serialization/deserialization +use bytes::Bytes; use chrono::{DateTime, FixedOffset}; use crc::{Algorithm, Crc}; use deku::prelude::*; @@ -113,6 +114,12 @@ impl MessagesContainer { } } +impl From for Bytes { + fn from(value: MessagesContainer) -> Self { + value.to_bytes().unwrap().into() + } +} + #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] pub struct HdlcEncapsulatedMessage { pub len: u32, diff --git a/lib/src/qmdl.rs b/lib/src/qmdl.rs index 9a2f0c6c..78c41fef 100644 --- a/lib/src/qmdl.rs +++ b/lib/src/qmdl.rs @@ -3,8 +3,14 @@ //! QmdlReader and QmdlWriter can read and write MessagesContainers to and from //! QMDL files. 
+use std::io::{Cursor, ErrorKind}; +use std::pin::Pin; +use std::task::Poll; + use crate::diag::{DataType, HdlcEncapsulatedMessage, MESSAGE_TERMINATOR, MessagesContainer}; +use async_compression::tokio::bufread::GzipDecoder; +use async_compression::tokio::write::GzipEncoder; use futures::TryStream; use log::error; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; @@ -13,8 +19,8 @@ pub struct QmdlWriter where T: AsyncWrite + Unpin, { - writer: T, - pub total_written: usize, + writer: GzipEncoder, + pub total_uncompressed_bytes: usize, } impl QmdlWriter @@ -22,50 +28,160 @@ where T: AsyncWrite + Unpin, { pub fn new(writer: T) -> Self { - QmdlWriter::new_with_existing_size(writer, 0) - } - - pub fn new_with_existing_size(writer: T, existing_size: usize) -> Self { + let gzip_writer = GzipEncoder::new(writer); QmdlWriter { - writer, - total_written: existing_size, + writer: gzip_writer, + total_uncompressed_bytes: 0, } } pub async fn write_container(&mut self, container: &MessagesContainer) -> std::io::Result<()> { for msg in &container.messages { - self.writer.write_all(&msg.data).await?; - self.total_written += msg.data.len(); + // for a gzipped file, we can't use `msg.data.len()` to + // determine the number of bytes written, so we have to + // manually do a `write_all()` type loop + let mut buf = Cursor::new(&msg.data); + loop { + let bytes_written = self.writer.write_buf(&mut buf).await?; + self.writer.flush().await?; + if bytes_written == 0 { + break; + } + self.total_uncompressed_bytes += bytes_written; + } } Ok(()) } + + pub async fn close(mut self) -> std::io::Result<()> { + self.writer.shutdown().await?; + Ok(()) + } +} + +#[derive(Debug)] +enum QmdlReaderSource { + Compressed { + reader: GzipDecoder>, + eof: bool, + }, + Uncompressed { + reader: T, + }, +} + +#[derive(Debug)] +struct QmdlAsyncReader { + source: QmdlReaderSource, + uncompressed_bytes_read: usize, + max_uncompressed_bytes: Option, +} + +impl QmdlAsyncReader 
+where + T: AsyncRead, +{ + pub fn new(reader: T, compressed: bool, max_uncompressed_bytes: Option) -> Self { + let source = if compressed { + QmdlReaderSource::Compressed { + reader: GzipDecoder::new(BufReader::new(reader)), + eof: false, + } + } else { + QmdlReaderSource::Uncompressed { reader } + }; + Self { + source, + uncompressed_bytes_read: 0, + max_uncompressed_bytes, + } + } } +impl AsyncRead for QmdlAsyncReader +where + T: AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + // if we've already read beyond the byte limit, return without reading + // into the buffer, essentially signalling EOF + if let Some(max_bytes) = self.max_uncompressed_bytes + && self.uncompressed_bytes_read >= max_bytes + { + if self.uncompressed_bytes_read > max_bytes { + error!( + "warning: {} bytes read, but max_bytes was {}", + self.uncompressed_bytes_read, max_bytes + ); + } + return Poll::Ready(Ok(())); + } + + let before = buf.filled().len(); + let this = self.get_mut(); + let res = match &mut this.source { + QmdlReaderSource::Compressed { reader, eof } => { + // if we already determined we've reached the Gzip EOF, don't read more + if *eof { + return Poll::Ready(Ok(())); + } + + match Pin::new(reader).poll_read(cx, buf) { + // if we hit an unexpected EOF in a Gzip file, it shouldn't + // be considered fatal, just a truncated file. 
mark that + // we're done and return the result as usual + Poll::Ready(Err(err)) if err.kind() == ErrorKind::UnexpectedEof => { + *eof = true; + Poll::Ready(Ok(())) + } + res => res, + } + } + QmdlReaderSource::Uncompressed { reader } => Pin::new(reader).poll_read(cx, buf), + }; + + // if we read more bytes than is allowed, cap the buffer by + // our max bytes + let after = buf.filled().len(); + let read = after - before; + if let Some(max_bytes) = this.max_uncompressed_bytes + && this.uncompressed_bytes_read + read > max_bytes + { + let overread = this.uncompressed_bytes_read + read - max_bytes; + buf.set_filled(after - overread); + } + res + } +} + +#[derive(Debug)] pub struct QmdlReader where T: AsyncRead, { - reader: BufReader, - bytes_read: usize, - max_bytes: Option, + buf_reader: BufReader>, } impl QmdlReader where T: AsyncRead + Unpin, { - pub fn new(reader: T, max_bytes: Option) -> Self { + pub fn new(reader: T, compressed: bool, max_uncompressed_bytes: Option) -> Self { QmdlReader { - reader: BufReader::new(reader), - bytes_read: 0, - max_bytes, + buf_reader: BufReader::new(QmdlAsyncReader::new( + reader, + compressed, + max_uncompressed_bytes, + )), } } - pub fn as_stream( - &mut self, - ) -> impl TryStream + '_ { - futures::stream::try_unfold(self, |reader| async { + pub fn as_stream(self) -> impl TryStream { + futures::stream::try_unfold(self, |mut reader| async { let maybe_container = reader.get_next_messages_container().await?; match maybe_container { Some(container) => Ok(Some((container, reader))), @@ -77,22 +193,16 @@ where pub async fn get_next_messages_container( &mut self, ) -> Result, std::io::Error> { - if let Some(max_bytes) = self.max_bytes - && self.bytes_read >= max_bytes + let mut buf = Vec::new(); + if self + .buf_reader + .read_until(MESSAGE_TERMINATOR, &mut buf) + .await? 
+ == 0 { - if self.bytes_read > max_bytes { - error!( - "warning: {} bytes read, but max_bytes was {}", - self.bytes_read, max_bytes - ); - } return Ok(None); } - let mut buf = Vec::new(); - let bytes_read = self.reader.read_until(MESSAGE_TERMINATOR, &mut buf).await?; - self.bytes_read += bytes_read; - // Since QMDL is just a flat list of messages, we can't actually // reproduce the container structure they came from in the original // read. So we'll just pretend that all containers had exactly one @@ -102,13 +212,26 @@ where data_type: DataType::UserSpace, num_messages: 1, messages: vec![HdlcEncapsulatedMessage { - len: bytes_read as u32, + len: buf.len() as u32, data: buf, }], })) } } +impl AsyncRead for QmdlReader +where + T: AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().buf_reader).poll_read(cx, buf) + } +} + #[cfg(test)] mod test { use std::io::Cursor; @@ -160,7 +283,7 @@ mod test { #[tokio::test] async fn test_unbounded_qmdl_reader() { let mut buf = Cursor::new(get_test_message_bytes()); - let mut reader = QmdlReader::new(&mut buf, None); + let mut reader = QmdlReader::new(&mut buf, false, None); let expected_messages = get_test_messages(); for message in expected_messages { let expected_container = MessagesContainer { @@ -183,7 +306,7 @@ mod test { let mut expected_messages = get_test_messages(); let limit = expected_messages[0].len + expected_messages[1].len; - let mut reader = QmdlReader::new(&mut buf, Some(limit as usize)); + let mut reader = QmdlReader::new(&mut buf, false, Some(limit as usize)); for message in expected_messages.drain(0..2) { let expected_container = MessagesContainer { data_type: DataType::UserSpace, @@ -201,29 +324,22 @@ mod test { )); } - #[tokio::test] - async fn test_qmdl_writer() { - let mut buf = Vec::new(); - let mut writer = QmdlWriter::new(&mut buf); - let expected_containers = 
get_test_containers(); - for container in &expected_containers { - writer.write_container(container).await.unwrap(); - } - assert_eq!(writer.total_written, buf.len()); - assert_eq!(buf, get_test_message_bytes()); - } - - #[tokio::test] - async fn test_writing_and_reading() { + /// Writes the test containers to a QmdlWriter, optionally finishing the + /// gzip stream with a footer. Then, attempts to decompress the buffer with + /// a QmdlReader, asserting that the containers match what's expected. + async fn run_compressed_reading_and_writing_tests(do_close: bool) { + let containers = get_test_containers(); let mut buf = Vec::new(); - let mut writer = QmdlWriter::new(&mut buf); - let expected_containers = get_test_containers(); - for container in &expected_containers { - writer.write_container(container).await.unwrap(); + { + let mut writer = QmdlWriter::new(&mut buf); + for container in &containers { + writer.write_container(&container).await.unwrap(); + } + if do_close { + writer.close().await.unwrap(); + } } - - let limit = Some(buf.len()); - let mut reader = QmdlReader::new(Cursor::new(&mut buf), limit); + let mut reader = QmdlReader::new(Cursor::new(buf), true, None); let expected_messages = get_test_messages(); for message in expected_messages { let expected_container = MessagesContainer { data_type: DataType::UserSpace, @@ -241,4 +357,10 @@ mod test { Ok(None) )); } + + #[tokio::test] + async fn test_compressed_reading_and_writing() { + run_compressed_reading_and_writing_tests(true).await; + run_compressed_reading_and_writing_tests(false).await; + } }