diff --git a/libwebauthn/src/proto/ctap2/protocol.rs b/libwebauthn/src/proto/ctap2/protocol.rs index 3376adb..3630a3e 100644 --- a/libwebauthn/src/proto/ctap2/protocol.rs +++ b/libwebauthn/src/proto/ctap2/protocol.rs @@ -1,10 +1,12 @@ use std::time::Duration; use async_trait::async_trait; +use tokio::time::timeout as tokio_timeout; use tracing::{debug, instrument, trace, warn}; -use crate::proto::ctap2::cbor::{self, CborRequest}; +use crate::proto::ctap2::cbor::{self, CborRequest, CborResponse}; use crate::proto::ctap2::{Ctap2BioEnrollmentResponse, Ctap2CommandCode}; +use crate::transport::error::TransportError; use crate::transport::Channel; use crate::unwrap_field; use crate::webauthn::error::{CtapError, Error, PlatformError}; @@ -19,6 +21,21 @@ use super::{ const TIMEOUT_GET_INFO: Duration = Duration::from_millis(250); +/// CBOR send + recv with a wall-clock timeout over the pair. Mirrors +/// `send_apdu_request_wait_uv` in the CTAP1 module. +async fn cbor_send_recv( + channel: &mut C, + request: &CborRequest, + timeout: Duration, +) -> Result { + tokio_timeout(timeout, async { + channel.cbor_send(request, timeout).await?; + channel.cbor_recv(timeout).await + }) + .await + .map_err(|_| Error::Transport(TransportError::Timeout))? +} + macro_rules! parse_cbor { ($type:ty, $data:expr) => {{ match cbor::from_slice::<$type>($data) { @@ -83,8 +100,7 @@ where #[instrument(skip_all)] async fn ctap2_get_info(&mut self) -> Result { let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorGetInfo); - self.cbor_send(&cbor_request, TIMEOUT_GET_INFO).await?; - let cbor_response = self.cbor_recv(TIMEOUT_GET_INFO).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, TIMEOUT_GET_INFO).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -103,8 +119,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -124,8 +139,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -145,8 +159,7 @@ where ) -> Result { debug!("CTAP2 GetNextAssertion request"); let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorGetNextAssertion); - self.cbor_send(&cbor_request, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, timeout).await?; let data = unwrap_field!(cbor_response.data); let ctap_response = parse_cbor!(Ctap2GetAssertionResponse, &data); debug!("CTAP2 GetNextAssertion successful"); @@ -159,8 +172,7 @@ where debug!("CTAP2 Authenticator Selection request"); let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorSelection); - self.cbor_send(&cbor_request, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, timeout).await?; match cbor_response.status_code { CtapError::Ok => { return Ok(()); @@ -179,8 +191,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -205,8 +216,7 @@ where timeout: Duration, ) -> Result<(), Error> { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => { return Ok(()); @@ -228,8 +238,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -254,8 +263,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), diff --git a/libwebauthn/src/transport/hid/channel.rs b/libwebauthn/src/transport/hid/channel.rs index 36e2c74..9b3ba0d 100644 --- a/libwebauthn/src/transport/hid/channel.rs +++ b/libwebauthn/src/transport/hid/channel.rs @@ -2,16 +2,15 @@ use std::convert::TryFrom; use std::fmt::{Debug, Display, Formatter}; use std::io::{Cursor as IOCursor, Seek, SeekFrom}; use std::ops::DerefMut; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; -use std::time::Duration; +use std::time::{Duration, Instant}; use async_trait::async_trait; use byteorder::{BigEndian, ReadBytesExt}; use hidapi::HidDevice as HidApiDevice; use rand::{thread_rng, Rng}; -use tokio::sync::broadcast; -use tokio::sync::mpsc::error::TryRecvError; -use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::sync::{broadcast, Notify}; use tokio::time::sleep; use tracing::{debug, info, instrument, trace, warn, Level}; @@ -44,25 +43,58 @@ const INIT_TIMEOUT: Duration = Duration::from_millis(200); const PACKET_SIZE: usize = 64; const REPORT_ID: u8 = 0x00; +// Per-iteration cap on hidapi::read_timeout. `read_timeout` returns as soon +// as the device delivers a report, so this does NOT add latency to normal +// responses; it only bounds how quickly the loop wakes up to re-check the +// wall-clock deadline and the cancel flag. 100ms is a small fraction of any +// realistic CTAP timeout, gives ~10 wakeups/sec per active channel (cheap +// even on battery), and is short enough that user-perceived cancel latency +// stays well under the round-trip a click already costs. +const HID_READ_POLL_INTERVAL: Duration = Duration::from_millis(100); + // Some devices fail when sending a WINK command followed immediately // by a CBOR command, so we want to ensure we wait some time after winking. const WINK_MIN_WAIT: Duration = Duration::from_secs(2); -pub type CancelHidOperation = (); enum OpenHidDevice { - HidApiDevice(Arc)>>), + HidApiDevice(Arc>), #[cfg(test)] VirtualDevice(Arc>), } +/// Shared cancel state. The atomic flag is checked by the blocking hidapi +/// reader between poll iterations; the notify wakes the async caller so a +/// cancel observed from another task is seen without waiting out the poll +/// interval. +#[derive(Debug, Default)] +struct CancelState { + flag: AtomicBool, + notify: Notify, +} + +impl CancelState { + fn signal(&self) { + self.flag.store(true, Ordering::SeqCst); + self.notify.notify_waiters(); + } + + fn is_cancelled(&self) -> bool { + self.flag.load(Ordering::SeqCst) + } + + fn reset(&self) { + self.flag.store(false, Ordering::SeqCst); + } +} + #[derive(Debug, Clone)] pub struct HidChannelHandle { - tx: Sender, + cancel: Arc, } impl HidChannelHandle { pub async fn cancel_ongoing_operation(&self) { - let _ = self.tx.send(()).await; + self.cancel.signal(); } } @@ -74,6 +106,7 @@ pub struct HidChannel<'d> { auth_token_data: Option, ux_update_sender: broadcast::Sender, handle: HidChannelHandle, + cancel: Arc, #[cfg(test)] pin_protocol_override: Option, } @@ -81,8 +114,10 @@ pub struct HidChannel<'d> { impl<'d> HidChannel<'d> { pub async fn new(device: &'d HidDevice) -> Result, Error> { let (ux_update_sender, _) = broadcast::channel(16); - let (handle_tx, handle_rx) = mpsc::channel(1); - let handle = HidChannelHandle { tx: handle_tx }; + let cancel = Arc::new(CancelState::default()); + let handle = HidChannelHandle { + cancel: cancel.clone(), + }; let mut channel = Self { status: ChannelStatus::Ready, @@ -90,7 +125,7 @@ impl<'d> HidChannel<'d> { open_device: match device.backend { HidBackendDevice::HidApiDevice(_) => { let hidapi_device = Self::hid_open(device)?; - OpenHidDevice::HidApiDevice(Arc::new(Mutex::new((hidapi_device, handle_rx)))) + OpenHidDevice::HidApiDevice(Arc::new(Mutex::new(hidapi_device))) } #[cfg(test)] HidBackendDevice::VirtualDevice => { @@ -101,6 +136,7 @@ impl<'d> HidChannel<'d> { auth_token_data: None, ux_update_sender, handle, + cancel, #[cfg(test)] pin_protocol_override: None, }; @@ -293,15 +329,14 @@ impl<'d> HidChannel<'d> { warn!("Poisoned lock on HID API device"); return Err(Error::Transport(TransportError::ConnectionLost)); }; - let (device, cancel_rx) = guard.deref_mut(); - let response = Self::hid_send_hidapi(device, cancel_rx, msg); + let device = guard.deref_mut(); + let response = Self::hid_send_hidapi(device, &self.cancel, msg); if matches!(response, Err(Error::Platform(PlatformError::Cancelled))) { - // Using hid_send_hidapi directly, instead of hid_cancel, to avoid recursion - let _ = Self::hid_send_hidapi( - device, - cancel_rx, - &HidMessage::new(self.init.cid, HidCommand::Cancel, &[]), - ); + // CTAPHID_CANCEL must still reach the device even though + // the cancel flag is set; bypass the flag check via + // write_packets (also avoids recursing into hid_cancel). + let cancel_msg = HidMessage::new(self.init.cid, HidCommand::Cancel, &[]); + let _ = Self::write_packets(device, &cancel_msg); } response } @@ -318,14 +353,14 @@ impl<'d> HidChannel<'d> { fn hid_send_hidapi( device: &hidapi::HidDevice, - cancel_rx: &mut Receiver, + cancel: &CancelState, msg: &HidMessage, ) -> Result<(), Error> { let packets = msg .packets(PACKET_SIZE) .or(Err(Error::Transport(TransportError::InvalidFraming)))?; for (i, packet) in packets.iter().enumerate() { - if !matches!(cancel_rx.try_recv(), Err(TryRecvError::Empty)) { + if cancel.is_cancelled() { return Err(Error::Platform(PlatformError::Cancelled)); } @@ -341,30 +376,74 @@ impl<'d> HidChannel<'d> { Ok(()) } + /// Send a message without consulting the cancel flag. Used when emitting + /// CTAPHID_CANCEL after a cancellation has already been observed. + fn write_packets(device: &hidapi::HidDevice, msg: &HidMessage) -> Result<(), Error> { + let packets = msg + .packets(PACKET_SIZE) + .or(Err(Error::Transport(TransportError::InvalidFraming)))?; + for packet in &packets { + let mut report: Vec = vec![REPORT_ID]; + report.extend(packet); + report.extend(vec![0; PACKET_SIZE - packet.len()]); + device + .write(&report) + .or(Err(Error::Transport(TransportError::ConnectionLost)))?; + } + Ok(()) + } + #[instrument(skip_all)] pub async fn hid_recv(&self, timeout: Duration) -> Result { + // Reset the cancel flag so a prior cancellation does not short-circuit + // a fresh receive. The drop guard signals the flag if this future is + // dropped before completing, so the blocking reader self-terminates + // within one HID_READ_POLL_INTERVAL. + self.cancel.reset(); + let mut drop_guard = CancelOnDrop::new(&self.cancel); + + let result = self.hid_recv_inner(timeout).await; + drop_guard.disarm(); + result + } + + async fn hid_recv_inner(&self, timeout: Duration) -> Result { loop { let response = match &self.open_device { OpenHidDevice::HidApiDevice(hidapi_device) => { let device = Arc::clone(hidapi_device); + let cancel = Arc::clone(&self.cancel); // The HID device will block when waiting for a user to // interact with the device, so mark the task as blocking to // allow other tasks to complete. // Note that we're just using spawn_blocking() on hid_recv(), not on hid_send(), // since implementing this on hid_send and would cause unnecessary copies/locking. - tokio::task::spawn_blocking(move || { + let read = tokio::task::spawn_blocking(move || { let Ok(mut guard) = device.lock() else { warn!("Poisoned lock on HID API device"); return Err(Error::Transport(TransportError::ConnectionLost)); }; - let (device, cancel_rx) = guard.deref_mut(); - Self::hid_recv_hidapi(device, cancel_rx, timeout) - }) - .await - .map_err(|e| { - warn!(?e, "HID read task failed"); - Error::Transport(TransportError::ConnectionLost) - })? + let device = guard.deref_mut(); + Self::hid_recv_hidapi(device, &cancel, timeout) + }); + tokio::pin!(read); + // Race the blocking read against cancel notifications. + // spawn_blocking cannot be aborted, so the flag store + // here is observed by the reader on its next poll + // (bounded by HID_READ_POLL_INTERVAL). + loop { + tokio::select! { + res = &mut read => { + break res.map_err(|e| { + warn!(?e, "HID read task failed"); + Error::Transport(TransportError::ConnectionLost) + })?; + } + _ = self.cancel.notify.notified() => { + self.cancel.flag.store(true, Ordering::SeqCst); + } + } + } } #[cfg(test)] OpenHidDevice::VirtualDevice(virt_device) => { @@ -384,7 +463,14 @@ impl<'d> HidChannel<'d> { debug!("Ignoring HID keep-alive"); continue; } - Err(Error::Platform(PlatformError::Cancelled)) => { + Err(Error::Platform(PlatformError::Cancelled)) + | Err(Error::Transport(TransportError::Timeout)) => { + // CTAP 2.2 §11.2.9.1.5: send CTAPHID_CANCEL when the + // platform gives up (caller cancelled or wall-clock + // budget exhausted). The blocking reader has released + // the device mutex by now; reset the flag so the send + // itself is not short-circuited. + self.cancel.reset(); let _ = self.hid_cancel().await; break response; } @@ -395,20 +481,41 @@ impl<'d> HidChannel<'d> { fn hid_recv_hidapi( device: &hidapi::HidDevice, - cancel_rx: &mut Receiver, + cancel: &CancelState, timeout: Duration, ) -> Result { let mut parser = HidMessageParser::new(); + let deadline = Instant::now().checked_add(timeout); loop { - if !matches!(cancel_rx.try_recv(), Err(TryRecvError::Empty)) { + if cancel.is_cancelled() { return Err(Error::Platform(PlatformError::Cancelled)); } + // Cap each read at HID_READ_POLL_INTERVAL so we re-check the + // cancel flag and remaining budget; a stalled device cannot + // hang the caller past `timeout`. + let remaining = match deadline { + Some(d) => d.saturating_duration_since(Instant::now()), + None => timeout, + }; + if remaining.is_zero() { + warn!("HID receive timed out before any data was read"); + return Err(Error::Transport(TransportError::Timeout)); + } + let read_for = remaining.min(HID_READ_POLL_INTERVAL); + let mut report = [0; PACKET_SIZE]; - device - .read_timeout(&mut report, timeout.as_millis() as i32) + let bytes_read = device + .read_timeout(&mut report, read_for.as_millis() as i32) .or(Err(Error::Transport(TransportError::ConnectionLost)))?; - debug!({ len = report.len() }, "Received HID report"); + if bytes_read == 0 { + // hidapi signals per-iteration timeout as Ok(0); retry + // against the remaining budget rather than passing the + // zero-initialised buffer to the parser. + trace!("hidapi read_timeout returned 0 bytes, continuing"); + continue; + } + debug!({ len = bytes_read }, "Received HID report"); trace!(?report); if let HidMessageParserState::Done = parser .update(&report) @@ -427,6 +534,34 @@ impl<'d> HidChannel<'d> { } } +/// Signals the cancel flag if its scope exits via panic or future-drop. +/// Call `disarm()` on the normal-return path. +struct CancelOnDrop<'a> { + cancel: &'a CancelState, + armed: bool, +} + +impl<'a> CancelOnDrop<'a> { + fn new(cancel: &'a CancelState) -> Self { + Self { + cancel, + armed: true, + } + } + + fn disarm(&mut self) { + self.armed = false; + } +} + +impl Drop for CancelOnDrop<'_> { + fn drop(&mut self) { + if self.armed { + self.cancel.signal(); + } + } +} + impl Drop for HidChannel<'_> { #[instrument(level = Level::DEBUG, skip_all, fields(dev = %self.device))] fn drop(&mut self) { @@ -435,11 +570,28 @@ impl Drop for HidChannel<'_> { return; } - if let Err(err) = futures::executor::block_on(self.hid_cancel()) { - warn!( - ?err, - "Failed to send hid_cancel on the channel being dropped" - ) + // Lock-free: signal any in-flight blocking read to abort on its + // next poll iteration. Then best-effort emit CTAPHID_CANCEL via + // try_lock; if the device mutex is contended (reader still active) + // we skip — the reader is about to release it and the device's + // own transaction timeout will reclaim the channel. + self.cancel.signal(); + + match &self.open_device { + OpenHidDevice::HidApiDevice(hidapi_device) => match hidapi_device.try_lock() { + Ok(mut guard) => { + let device = guard.deref_mut(); + let cancel_msg = HidMessage::new(self.init.cid, HidCommand::Cancel, &[]); + if let Err(err) = Self::write_packets(device, &cancel_msg) { + debug!(?err, "Best-effort CTAPHID_CANCEL on channel drop failed"); + } + } + Err(_) => { + debug!("Device mutex contended on drop, skipping CTAPHID_CANCEL packet"); + } + }, + #[cfg(test)] + OpenHidDevice::VirtualDevice(_) => {} } } } diff --git a/libwebauthn/src/transport/hid/framing.rs b/libwebauthn/src/transport/hid/framing.rs index 9785100..edd1c08 100644 --- a/libwebauthn/src/transport/hid/framing.rs +++ b/libwebauthn/src/transport/hid/framing.rs @@ -3,12 +3,13 @@ use std::io::{Cursor as IOCursor, Error as IOError, ErrorKind as IOErrorKind}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use num_enum::{IntoPrimitive, TryFromPrimitive}; -use tracing::{debug, error}; +use tracing::error; const BROADCAST_CID: u32 = 0xFFFFFFFF; const PACKET_INITIAL_HEADER_SIZE: usize = 7; const PACKET_INITIAL_CMD_MASK: u8 = 0x80; const PACKET_CONT_HEADER_SIZE: usize = 5; +const PACKET_CONT_SEQ_MAX: u8 = 0x7F; #[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone, PartialEq)] #[repr(u8)] @@ -121,17 +122,69 @@ impl HidMessageParser { if (self.packets.is_empty() && packet.len() < PACKET_INITIAL_HEADER_SIZE) || packet.len() < PACKET_CONT_HEADER_SIZE + 1 { - error!("Packet length in invalid"); + error!("Packet length is invalid"); return Err(IOError::new( IOErrorKind::InvalidInput, "Packet length is invalid", )); } + + // CID 0x00000000 is reserved (CTAP 2.2 §11.2.4); reject all-zero frames. if packet.iter().all(|&b| b == 0) { - debug!("Received unexpected packet of all zeroes, ignoring"); // ?! + error!("Received all-zero packet, rejecting"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "All-zero packet is not a valid CTAPHID frame", + )); + } + + if self.packets.is_empty() { + // First packet must be an initialization packet: high bit of + // byte 4 set (CTAP 2.2 §11.2.4). + if packet[4] & PACKET_INITIAL_CMD_MASK == 0 { + error!("First packet is not an initialization packet"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "First packet must be an initialization packet", + )); + } } else { - self.packets.push(Vec::from(packet)); + // Continuation packets: same CID as the initial packet, SEQ has + // high bit cleared, SEQ starts at 0 and increments monotonically. + let initial = &self.packets[0]; + if packet[..4] != initial[..4] { + error!("Continuation packet CID does not match initial packet"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Continuation packet CID mismatch", + )); + } + let seq = packet[4]; + if seq & PACKET_INITIAL_CMD_MASK != 0 { + error!(seq, "Unexpected init packet during continuation"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Unexpected initialization packet during continuation", + )); + } + let expected_seq = (self.packets.len() - 1) as u8; + if expected_seq > PACKET_CONT_SEQ_MAX { + error!(seq, "Continuation count exceeds maximum SEQ"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Too many continuation packets", + )); + } + if seq != expected_seq { + error!(seq, expected_seq, "Out-of-order continuation SEQ"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Out-of-order continuation SEQ", + )); + } } + + self.packets.push(Vec::from(packet)); if self.more_packets_needed() { Ok(HidMessageParserState::MorePacketsExpected) } else { @@ -292,4 +345,91 @@ mod tests { assert_eq!(msg.cmd, HidCommand::Msg); assert_eq!(msg.payload, vec![0x0A, 0x0B, 0x0C, 0x0D, 0x0E]); } + + #[test] + fn parse_continuation_with_wrong_cid_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Continuation from a different channel. + let err = parser + .update(&[0xD0, 0xD1, 0xD2, 0xD3, 0x00, 0x0B, 0x0C]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_continuation_with_non_zero_first_seq_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // First continuation must have SEQ=0. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x01, 0x0B, 0x0C]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_continuation_with_non_monotonic_seq_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x07, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x00, 0x0B, 0x0C]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Skipping SEQ=1 and jumping to SEQ=2 is not allowed. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x02, 0x0D, 0x0E]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_init_packet_after_init_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Another init packet (high bit set on byte 4) for a new transaction. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0B]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_all_zero_packet_is_rejected() { + let mut parser = HidMessageParser::new(); + let err = parser.update(&[0u8; 64]).unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_first_packet_must_be_init_packet() { + // High bit of byte 4 cleared means continuation; not allowed first. + let mut parser = HidMessageParser::new(); + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x00, 0x00, 0x05, 0x0A]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } }