diff --git a/lib/src/helpers.rs b/lib/src/helpers.rs index 9f89b51f..94340d09 100644 --- a/lib/src/helpers.rs +++ b/lib/src/helpers.rs @@ -61,20 +61,20 @@ pub fn has_ff(v: u64) -> bool { } #[inline(always)] -pub const fn devli(s: u8, value: u16) -> i16 { - let shifted = 1 << s; +pub const fn devli(s: u32, value: u32) -> i16 { + let shifted = (1 << s) - 1; - if value & (shifted >> 1) != 0 { + if (value << 1) > shifted { value as i16 } else { - value.wrapping_add(2).wrapping_add(!shifted) as i16 + value.wrapping_sub(shifted) as i16 } } /// check to make sure the behavior hasn't changed even with the optimization #[test] fn devli_test() { - for s in 0u8..15 { + for s in 0u32..15 { for value in 0..(1 << s) { assert_eq!( devli(s, value), diff --git a/lib/src/jpeg/bit_reader.rs b/lib/src/jpeg/bit_reader.rs index 26be2004..68198747 100644 --- a/lib/src/jpeg/bit_reader.rs +++ b/lib/src/jpeg/bit_reader.rs @@ -8,7 +8,7 @@ use std::io::BufRead; use super::jpeg_code; use crate::helpers::has_ff; -use crate::lepton_error::{ExitCode, err_exit_code}; +use crate::lepton_error::{ExitCode, Result, err_exit_code}; use crate::{LeptonError, StreamPosition}; // Implemenation of bit reader on top of JPEG data stream as read by a reader @@ -59,25 +59,21 @@ impl BitReader { impl BitReader { #[inline(always)] - pub fn read(&mut self, bits_to_read: u32) -> std::io::Result { - if bits_to_read == 0 { - return Ok(0); - } - + pub fn read(&mut self, bits_to_read: u32) -> Result { if self.bits_left < bits_to_read { - self.fill_register(bits_to_read)?; + self.fill_register_slow(bits_to_read)?; } let retval = - (self.bits >> (self.bits_left - bits_to_read) & ((1 << bits_to_read) - 1)) as u16; + ((self.bits >> (self.bits_left - bits_to_read)) as u32) & ((1 << bits_to_read) - 1); self.bits_left -= bits_to_read; return Ok(retval); } #[inline(always)] - pub fn peek(&self) -> (u8, u32) { + pub fn peek(&self) -> (u32, u32) { ( - ((self.bits.wrapping_shl(64 - self.bits_left)) >> 56) as u8, + ((self.bits.wrapping_shl(64 - self.bits_left)) >> 56) as u32, self.bits_left, ) } @@ -87,44 +83,95 @@ impl BitReader { self.bits_left -= bits; } + /// Fills the register as much as possible with the current buffer in + /// a non-destructive manner. This requires the BufRead implementation to + /// have enough space available. #[inline(always)] - pub fn fill_register(&mut self, bits_to_read: u32) -> Result<(), std::io::Error> { + pub fn optimistic_fill(&mut self) { + if self.bits_left < 8 { + self.optimistic_fill_slow(); + } + } + + /// Assuming that there are less than 8 bits in our buffer, fill the rest of the buffer + /// with as many bits as possible, leaving the possibility of unwinding via the undo_read_ahead + /// function in certain corner cases. + #[inline(never)] + #[cold] + fn optimistic_fill_slow(&mut self) { + // for correctness, we need to have less than 8 bits left, otherwise we can't + // consume the read_ahead_bytes and successfully undo the read_ahead if we have to. + debug_assert!(self.bits_left < 8); + // first consume the read_ahead bytes that we have now consumed // (otherwise we wouldn't have been called) self.inner.consume(self.read_ahead_bytes as usize); - let fb = self.inner.fill_buf()?; + if let Ok(fb) = self.inner.fill_buf() { + // if we have 8 bytes and there is no 0xff in them, then we can just read the bits directly as big endian + if fb.len() < 8 { + self.read_ahead_bytes = 0; + return; + } - // if we have 8 bytes and there is no 0xff in them, then we can just read the bits directly as big endian - let mut v; - if fb.len() < 8 || { - v = u64::from_le_bytes(fb[..8].try_into().unwrap()); - has_ff(v) - } { - self.read_ahead_bytes = 0; - return self.fill_register_slow(bits_to_read); - } + let mut v = u64::from_le_bytes(fb[..8].try_into().unwrap()); + if has_ff(v) { + // this is the expensive path, but rarer where there are 0xff bytes in the buffer + let mut bytes_left = 8; + self.read_ahead_bytes = 0; - v = v.to_be(); + while bytes_left >= 2 { + if v & 0xff == 0xff { + if v & 0xff00 != 0 { + // reset marker or end of scan, just exit the loop and let fill_register handle it + break; + } - // only fill 63 bits not 64 to avoid having to special case - // of self.bits << 64 which is a nop - let bytes_to_read = (63 - self.bits_left) / 8; + self.bits = (self.bits << 8) | 0xff; + self.bits_left += 8; + self.read_ahead_bytes += 2; + + v >>= 16; + bytes_left -= 2; + } else { + self.bits = (self.bits << 8) | (v & 0xff); + self.bits_left += 8; + self.read_ahead_bytes += 1; - self.bits = self.bits << (bytes_to_read * 8) | v >> (64 - bytes_to_read * 8); - self.bits_left += bytes_to_read * 8; - self.read_ahead_bytes = (self.bits_left - bits_to_read) / 8; + v >>= 8; + bytes_left -= 1; + } + } + } else { + // no 0xff bytes, just read the bits all at a time and reverse endian to get them in the right order. + v = v.to_be(); - self.inner - .consume((bytes_to_read - self.read_ahead_bytes) as usize); + // only fill 63 bits not 64 to avoid having to special case + // of self.bits << 64 which is a nop + let bytes_to_read = (63 - self.bits_left) / 8; - return Ok(()); + self.bits = self.bits << (bytes_to_read * 8) | v >> (64 - bytes_to_read * 8); + self.bits_left += bytes_to_read * 8; + self.read_ahead_bytes = self.bits_left / 8; + } + } } + /// Fills the register up to the number of bits requested, with the assumption that these + /// will be immediately consumed. + /// + /// This function ends up being called very infrequently since almost all of the time the optimistic_fill ensures + /// that there are enough bits to work with. Effectively this function ends up only being called in corner cases + /// where we are near the end of a BufRead block, at the end of the file or about to hit a reset marker. #[cold] - fn fill_register_slow(&mut self, bits_to_read: u32) -> Result<(), std::io::Error> { - loop { + #[inline(never)] + fn fill_register_slow(&mut self, bits_to_read: u32) -> Result<()> { + self.inner.consume(self.read_ahead_bytes as usize); + self.read_ahead_bytes = 0; + + while self.bits_left < bits_to_read { let fb = self.inner.fill_buf()?; + if let &[b, ..] = fb { self.inner.consume(1); @@ -176,10 +223,6 @@ impl BitReader { // continue since we still might need to read more 0 bits } - - if self.bits_left >= bits_to_read { - break; - } } Ok(()) } @@ -190,10 +233,7 @@ impl BitReader { /// used to verify whether this image is using 1s or 0s as fill bits. /// Returns whether the fill bit was 1 or so or unknown (None) - pub fn read_and_verify_fill_bits( - &mut self, - pad_bit: &mut Option, - ) -> Result<(), LeptonError> { + pub fn read_and_verify_fill_bits(&mut self, pad_bit: &mut Option) -> Result<()> { self.undo_read_ahead(); // if there are bits left, we need to see whether they @@ -222,7 +262,7 @@ impl BitReader { } Some(x) => { // if we already saw a padding, then it should match - let expected = u16::from(x) & all_one; + let expected = u32::from(x) & all_one; if actual != expected { return err_exit_code( ExitCode::InvalidPadding, @@ -239,7 +279,7 @@ impl BitReader { return Ok(()); } - pub fn verify_reset_code(&mut self) -> Result<(), LeptonError> { + pub fn verify_reset_code(&mut self) -> Result<()> { // we reached the end of a MCU, so we need to find a reset code and the huffman codes start get padded out, but the spec // doesn't specify whether the padding should be 1s or 0s, so we ensure that at least the file is consistant so that we // can recode it again just by remembering the pad bit. @@ -285,9 +325,15 @@ impl BitReader { /// the only bits that are left are part of the current byte. pub fn undo_read_ahead(&mut self) { while self.bits_left >= 8 && self.read_ahead_bytes > 0 { + // if it was an 0xff then rewind 2 bytes since this was an escape code + if self.bits & 0xff == 0xff { + self.read_ahead_bytes -= 2; + } else { + self.read_ahead_bytes -= 1; + } + self.bits_left -= 8; self.bits >>= 8; - self.read_ahead_bytes -= 1; } if self.read_ahead_bytes > 0 { diff --git a/lib/src/jpeg/bit_writer.rs b/lib/src/jpeg/bit_writer.rs index a1ec5eff..cff9ad01 100644 --- a/lib/src/jpeg/bit_writer.rs +++ b/lib/src/jpeg/bit_writer.rs @@ -192,7 +192,7 @@ mod tests { let mut r = BitReader::new(Cursor::new(&buf)); for i in 1..2048 { - assert_eq!(i, r.read(u32_bit_length(i as u32) as u32).unwrap()); + assert_eq!(i, r.read(u32_bit_length(i) as u32).unwrap()); } let mut pad = Some(0xff); @@ -231,7 +231,7 @@ mod tests { if rng.gen_range(0..100) == 0 { test_data.push(Action::Pad(0xff)); } else { - test_data.push(Action::Write(v as u16, bits as u8)); + test_data.push(Action::Write(v, bits)); } } test_data.push(Action::Pad(0xff)); @@ -264,16 +264,16 @@ mod tests { let (peekcode, peekbits) = r.peek(); let num_valid_bits = peekbits.min(8).min(u32::from(numbits)); - let mask = (0xff00 >> num_valid_bits) as u8; + let mask = 0xff00u32 >> num_valid_bits; assert_eq!( - expected_peek_byte & mask, + expected_peek_byte as u32 & mask, peekcode & mask, "peek unexpected result" ); assert_eq!( - code, + u32::from(code), r.read(numbits as u32).unwrap(), "read unexpected result" ); diff --git a/lib/src/jpeg/jpeg_read.rs b/lib/src/jpeg/jpeg_read.rs index 2a646f3a..946d0dae 100644 --- a/lib/src/jpeg/jpeg_read.rs +++ b/lib/src/jpeg/jpeg_read.rs @@ -718,12 +718,13 @@ pub(crate) fn decode_block_seq( } /// Reads and decodes next Huffman code from BitReader using the provided tree -#[inline(always)] +#[cold] +#[inline(never)] fn next_huff_code(bit_reader: &mut BitReader, ctree: &HuffTree) -> Result { let mut node: u16 = 0; while node < 256 { - node = ctree.node[usize::from(node)][usize::from(bit_reader.read(1)?)]; + node = ctree.node[usize::from(node)][bit_reader.read(1)? as usize]; } if node == 0xffff { @@ -733,6 +734,7 @@ fn next_huff_code(bit_reader: &mut BitReader, ctree: &HuffTree) - } } +#[inline(always)] fn read_dc(bit_reader: &mut BitReader, tree: &HuffTree) -> Result { let (z, coef) = read_coef(bit_reader, tree)?.unwrap_or((0, 0)); if z != 0 { @@ -750,45 +752,37 @@ fn read_coef( bit_reader: &mut BitReader, tree: &HuffTree, ) -> Result> { - // if the code we found is smaller or equal to the number of bits left, take the shortcut - let hc; - - loop { - // peek ahead to see if we can decode the symbol immediately - // given what has already been read into the bitreader - let (peek_value, peek_len) = bit_reader.peek(); - - // use lookup table to figure out the first code in this byte and how long it is - let (code, code_len) = tree.peek_code[peek_value as usize]; - - if u32::from(code_len) <= peek_len { - // found code directly, so advance by the number of bits immediately - hc = code; - bit_reader.advance(u32::from(code_len)); - break; - } else if peek_len < 8 { - // peek code works with up to 8 bits at a time. If we had less - // than this, then we need to read more bits into the bitreader - bit_reader.fill_register(8)?; - } else { - // take slow path since we have a code that is bigger than 8 bits (but pretty rare) - hc = next_huff_code(bit_reader, tree)?; - break; - } - } + let hc = read_code(bit_reader, tree)?; // analyse code if hc != 0 { let z = usize::from(lbits(hc, 4)); - let literal_bits = rbits(hc, 4); + let literal_bits = u32::from(rbits(hc, 4)); - let value = bit_reader.read(u32::from(literal_bits))?; + let value = bit_reader.read(literal_bits)?; Ok(Some((z, devli(literal_bits, value)))) } else { Ok(None) } } +#[inline(always)] +fn read_code(bit_reader: &mut BitReader, tree: &HuffTree) -> Result { + bit_reader.optimistic_fill(); + let (peek_value, peek_len) = bit_reader.peek(); + let (code, code_len) = tree.peek_code[peek_value as usize]; + + if u32::from(code_len) <= peek_len { + // found code directly, so advance by the number of bits immediately + bit_reader.advance(u32::from(code_len)); + + Ok(code) + } else { + // take slow path since we have a code that is bigger than 8 bits (but pretty rare) + next_huff_code(bit_reader, tree) + } +} + /// progressive AC decoding (first pass) fn decode_ac_prg_fs( bit_reader: &mut BitReader, @@ -804,7 +798,7 @@ fn decode_ac_prg_fs( let mut bpos = from; while bpos <= to { // decode next - let hc = next_huff_code(bit_reader, actree)?; + let hc = read_code(bit_reader, actree)?; let l = lbits(hc, 4); let r = rbits(hc, 4); @@ -813,8 +807,8 @@ fn decode_ac_prg_fs( if (l == 15) || (r > 0) { // decode run/level combination let mut z = l; - let s = r; - let n = bit_reader.read(u32::from(s))?; + let s = u32::from(r); + let n = bit_reader.read(s)?; if (z + bpos) > to { return err_exit_code(ExitCode::UnsupportedJpeg, "run is too long"); } @@ -831,7 +825,7 @@ fn decode_ac_prg_fs( // decode eobrun let s = l; let n = bit_reader.read(u32::from(s))?; - state.eobrun = decode_eobrun_bits(s, n); + state.eobrun = decode_eobrun_bits(s, n) as u16; state.eobrun -= 1; // decrement eobrun ( for this one ) @@ -860,7 +854,7 @@ fn decode_ac_prg_sa( // decode AC succesive approximation bits while bpos <= to { // decode next - let hc = next_huff_code(bit_reader, actree)?; + let hc = read_code(bit_reader, actree)?; let l = lbits(hc, 4); let r = rbits(hc, 4); @@ -909,7 +903,7 @@ fn decode_ac_prg_sa( eob = bpos; let s = l; let n = bit_reader.read(u32::from(s))?; - state.eobrun = decode_eobrun_bits(s, n); + state.eobrun = decode_eobrun_bits(s, n) as u16; // since we hit EOB, the rest can be done with the zero block decoder decode_eobrun_sa(bit_reader, block, state, bpos, to)?; @@ -946,7 +940,7 @@ fn decode_eobrun_sa( /// decoding for decoding eobrun lengths. The encoding chops off the most significant /// bit since it is always 1, so we need to add it back. -fn decode_eobrun_bits(s: u8, n: u16) -> u16 { +fn decode_eobrun_bits(s: u8, n: u32) -> u32 { n + (1 << s) }