diff --git a/src/wire.rs b/src/wire.rs index 8e839fe..9a26234 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -1,4 +1,6 @@ -use anyhow::Result; +use std::io::Read; + +use anyhow::{Context, Result, bail}; use prost::Message; use crate::config::Config; @@ -7,6 +9,12 @@ use crate::proto::patch::Patch; /// Zstd frame magic number (little-endian). const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD]; +/// Upper bound on the decompressed size of a patch. A zstd frame can claim a +/// tiny compressed size while expanding to gigabytes (a "decompression bomb"). +/// Patches decoded here may arrive from an untrusted peer, so refuse to +/// allocate more than this; the ceiling is far above any realistic patch. +const MAX_DECOMPRESSED_PATCH_SIZE: u64 = 1 << 30; // 1 GiB + /// Encode a Patch to protobuf, optionally compressing with zstd. pub fn encode_patch(config: &Config, patch: &Patch) -> Result> { let mut buf = Vec::new(); @@ -40,7 +48,7 @@ pub fn encode_patch(config: &Config, patch: &Patch) -> Result> { /// first. Otherwise, it is treated as raw protobuf. pub fn decode_patch(data: &[u8]) -> Result { let bytes = if data.starts_with(&ZSTD_MAGIC) { - zstd::decode_all(data)? + decompress_bounded(data, MAX_DECOMPRESSED_PATCH_SIZE)? } else { data.to_vec() }; @@ -48,6 +56,24 @@ pub fn decode_patch(data: &[u8]) -> Result { Ok(patch) } +/// Decompress a zstd frame, refusing to produce more than `max` bytes of +/// output so a malicious frame cannot exhaust memory. +fn decompress_bounded(data: &[u8], max: u64) -> Result> { + let decoder = + zstd::stream::read::Decoder::new(data).context("failed to initialize zstd decoder")?; + let mut bytes = Vec::new(); + // Read one byte past the limit so output that exactly fills `max` is still + // accepted while anything larger is detected and rejected. + decoder + .take(max + 1) + .read_to_end(&mut bytes) + .context("failed to decompress patch")?; + if bytes.len() as u64 > max { + bail!("decompressed patch exceeds the maximum allowed size of {max} bytes"); + } + Ok(bytes) +} + #[cfg(test)] mod tests { use super::*; @@ -79,4 +105,25 @@ mod tests { assert!(patch.deltas.is_empty()); assert!(patch.states.is_empty()); } + + #[test] + fn test_decompress_bounded_rejects_oversized_output() { + // A small frame that expands past the cap must be rejected rather than + // allocated in full. + let original = vec![0u8; 1_000_000]; + let compressed = zstd::encode_all(original.as_slice(), 0).unwrap(); + assert!(compressed.len() < 1_000_000, "expected high compression"); + + let err = decompress_bounded(&compressed, 1024).err().unwrap(); + let msg = format!("{:#}", err); + assert!(msg.contains("maximum allowed size"), "got: {msg}"); + } + + #[test] + fn test_decompress_bounded_accepts_output_within_limit() { + let original = vec![7u8; 1000]; + let compressed = zstd::encode_all(original.as_slice(), 0).unwrap(); + let out = decompress_bounded(&compressed, 1_000_000).unwrap(); + assert_eq!(out, original); + } }