diff --git a/libwebauthn/src/proto/ctap2/protocol.rs b/libwebauthn/src/proto/ctap2/protocol.rs index 3376adb..3630a3e 100644 --- a/libwebauthn/src/proto/ctap2/protocol.rs +++ b/libwebauthn/src/proto/ctap2/protocol.rs @@ -1,10 +1,12 @@ use std::time::Duration; use async_trait::async_trait; +use tokio::time::timeout as tokio_timeout; use tracing::{debug, instrument, trace, warn}; -use crate::proto::ctap2::cbor::{self, CborRequest}; +use crate::proto::ctap2::cbor::{self, CborRequest, CborResponse}; use crate::proto::ctap2::{Ctap2BioEnrollmentResponse, Ctap2CommandCode}; +use crate::transport::error::TransportError; use crate::transport::Channel; use crate::unwrap_field; use crate::webauthn::error::{CtapError, Error, PlatformError}; @@ -19,6 +21,21 @@ use super::{ const TIMEOUT_GET_INFO: Duration = Duration::from_millis(250); +/// CBOR send + recv with a wall-clock timeout over the pair. Mirrors +/// `send_apdu_request_wait_uv` in the CTAP1 module. +async fn cbor_send_recv( + channel: &mut C, + request: &CborRequest, + timeout: Duration, +) -> Result { + tokio_timeout(timeout, async { + channel.cbor_send(request, timeout).await?; + channel.cbor_recv(timeout).await + }) + .await + .map_err(|_| Error::Transport(TransportError::Timeout))? +} + macro_rules! parse_cbor { ($type:ty, $data:expr) => {{ match cbor::from_slice::<$type>($data) { @@ -83,8 +100,7 @@ where #[instrument(skip_all)] async fn ctap2_get_info(&mut self) -> Result { let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorGetInfo); - self.cbor_send(&cbor_request, TIMEOUT_GET_INFO).await?; - let cbor_response = self.cbor_recv(TIMEOUT_GET_INFO).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, TIMEOUT_GET_INFO).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -103,8 +119,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -124,8 +139,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -145,8 +159,7 @@ where ) -> Result { debug!("CTAP2 GetNextAssertion request"); let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorGetNextAssertion); - self.cbor_send(&cbor_request, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, timeout).await?; let data = unwrap_field!(cbor_response.data); let ctap_response = parse_cbor!(Ctap2GetAssertionResponse, &data); debug!("CTAP2 GetNextAssertion successful"); @@ -159,8 +172,7 @@ where debug!("CTAP2 Authenticator Selection request"); let cbor_request = CborRequest::new(Ctap2CommandCode::AuthenticatorSelection); - self.cbor_send(&cbor_request, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &cbor_request, timeout).await?; match cbor_response.status_code { CtapError::Ok => { return Ok(()); @@ -179,8 +191,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -205,8 +216,7 @@ where timeout: Duration, ) -> Result<(), Error> { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => { return Ok(()); @@ -228,8 +238,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), @@ -254,8 +263,7 @@ where timeout: Duration, ) -> Result { trace!(?request); - self.cbor_send(&request.try_into()?, timeout).await?; - let cbor_response = self.cbor_recv(timeout).await?; + let cbor_response = cbor_send_recv(self, &request.try_into()?, timeout).await?; match cbor_response.status_code { CtapError::Ok => (), error => return Err(Error::Ctap(error)), diff --git a/libwebauthn/src/transport/hid/channel.rs b/libwebauthn/src/transport/hid/channel.rs index e56f7aa..240f4bf 100644 --- a/libwebauthn/src/transport/hid/channel.rs +++ b/libwebauthn/src/transport/hid/channel.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::io::{Cursor as IOCursor, Seek, SeekFrom}; use std::ops::DerefMut; use std::sync::{Arc, Mutex}; -use std::time::Duration; +use std::time::{Duration, Instant}; use async_trait::async_trait; use byteorder::{BigEndian, ReadBytesExt}; @@ -44,6 +44,13 @@ const INIT_TIMEOUT: Duration = Duration::from_millis(200); const PACKET_SIZE: usize = 64; const REPORT_ID: u8 = 0x00; +// Per-iteration cap on hidapi::read_timeout. `read_timeout` returns as soon +// as the device delivers a report, so this does NOT add latency to normal +// responses; it only bounds how quickly the loop wakes up to re-check the +// wall-clock deadline and the cancel signal. 100ms is a small fraction of +// any realistic CTAP timeout and gives ~10 wakeups/sec per active channel. +const HID_READ_POLL_INTERVAL: Duration = Duration::from_millis(100); + // Some devices fail when sending a WINK command followed immediately // by a CBOR command, so we want to ensure we wait some time after winking. const WINK_MIN_WAIT: Duration = Duration::from_secs(2); @@ -383,7 +390,11 @@ impl<'d> HidChannel<'d> { debug!("Ignoring HID keep-alive"); continue; } - Err(Error::Platform(PlatformError::Cancelled)) => { + Err(Error::Platform(PlatformError::Cancelled)) + | Err(Error::Transport(TransportError::Timeout)) => { + // CTAP 2.2 §11.2.9.1.5: send CTAPHID_CANCEL when the + // platform gives up (caller cancelled or wall-clock + // budget exhausted). let _ = self.hid_cancel().await; break response; } @@ -398,16 +409,37 @@ impl<'d> HidChannel<'d> { timeout: Duration, ) -> Result { let mut parser = HidMessageParser::new(); + let deadline = Instant::now().checked_add(timeout); loop { if !matches!(cancel_rx.try_recv(), Err(TryRecvError::Empty)) { return Err(Error::Platform(PlatformError::Cancelled)); } + // Cap each read at HID_READ_POLL_INTERVAL so we re-check the + // cancel channel and remaining budget; a stalled device cannot + // hang the caller past `timeout`. + let remaining = match deadline { + Some(d) => d.saturating_duration_since(Instant::now()), + None => timeout, + }; + if remaining.is_zero() { + warn!("HID receive timed out before any data was read"); + return Err(Error::Transport(TransportError::Timeout)); + } + let read_for = remaining.min(HID_READ_POLL_INTERVAL); + let mut report = [0; PACKET_SIZE]; - device - .read_timeout(&mut report, timeout.as_millis() as i32) + let bytes_read = device + .read_timeout(&mut report, read_for.as_millis() as i32) .or(Err(Error::Transport(TransportError::ConnectionLost)))?; - debug!({ len = report.len() }, "Received HID report"); + if bytes_read == 0 { + // hidapi signals per-iteration timeout as Ok(0); retry + // against the remaining budget rather than passing the + // zero-initialised buffer to the parser. + trace!("hidapi read_timeout returned 0 bytes, continuing"); + continue; + } + debug!({ len = bytes_read }, "Received HID report"); trace!(?report); if let HidMessageParserState::Done = parser .update(&report) diff --git a/libwebauthn/src/transport/hid/framing.rs b/libwebauthn/src/transport/hid/framing.rs index 9785100..edd1c08 100644 --- a/libwebauthn/src/transport/hid/framing.rs +++ b/libwebauthn/src/transport/hid/framing.rs @@ -3,12 +3,13 @@ use std::io::{Cursor as IOCursor, Error as IOError, ErrorKind as IOErrorKind}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use num_enum::{IntoPrimitive, TryFromPrimitive}; -use tracing::{debug, error}; +use tracing::error; const BROADCAST_CID: u32 = 0xFFFFFFFF; const PACKET_INITIAL_HEADER_SIZE: usize = 7; const PACKET_INITIAL_CMD_MASK: u8 = 0x80; const PACKET_CONT_HEADER_SIZE: usize = 5; +const PACKET_CONT_SEQ_MAX: u8 = 0x7F; #[derive(Debug, IntoPrimitive, TryFromPrimitive, Copy, Clone, PartialEq)] #[repr(u8)] @@ -121,17 +122,69 @@ impl HidMessageParser { if (self.packets.is_empty() && packet.len() < PACKET_INITIAL_HEADER_SIZE) || packet.len() < PACKET_CONT_HEADER_SIZE + 1 { - error!("Packet length in invalid"); + error!("Packet length is invalid"); return Err(IOError::new( IOErrorKind::InvalidInput, "Packet length is invalid", )); } + + // CID 0x00000000 is reserved (CTAP 2.2 §11.2.4); reject all-zero frames. if packet.iter().all(|&b| b == 0) { - debug!("Received unexpected packet of all zeroes, ignoring"); // ?! + error!("Received all-zero packet, rejecting"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "All-zero packet is not a valid CTAPHID frame", + )); + } + + if self.packets.is_empty() { + // First packet must be an initialization packet: high bit of + // byte 4 set (CTAP 2.2 §11.2.4). + if packet[4] & PACKET_INITIAL_CMD_MASK == 0 { + error!("First packet is not an initialization packet"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "First packet must be an initialization packet", + )); + } } else { - self.packets.push(Vec::from(packet)); + // Continuation packets: same CID as the initial packet, SEQ has + // high bit cleared, SEQ starts at 0 and increments monotonically. + let initial = &self.packets[0]; + if packet[..4] != initial[..4] { + error!("Continuation packet CID does not match initial packet"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Continuation packet CID mismatch", + )); + } + let seq = packet[4]; + if seq & PACKET_INITIAL_CMD_MASK != 0 { + error!(seq, "Unexpected init packet during continuation"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Unexpected initialization packet during continuation", + )); + } + let expected_seq = (self.packets.len() - 1) as u8; + if expected_seq > PACKET_CONT_SEQ_MAX { + error!(seq, "Continuation count exceeds maximum SEQ"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Too many continuation packets", + )); + } + if seq != expected_seq { + error!(seq, expected_seq, "Out-of-order continuation SEQ"); + return Err(IOError::new( + IOErrorKind::InvalidData, + "Out-of-order continuation SEQ", + )); + } } + + self.packets.push(Vec::from(packet)); if self.more_packets_needed() { Ok(HidMessageParserState::MorePacketsExpected) } else { @@ -292,4 +345,91 @@ mod tests { assert_eq!(msg.cmd, HidCommand::Msg); assert_eq!(msg.payload, vec![0x0A, 0x0B, 0x0C, 0x0D, 0x0E]); } + + #[test] + fn parse_continuation_with_wrong_cid_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Continuation from a different channel. + let err = parser + .update(&[0xD0, 0xD1, 0xD2, 0xD3, 0x00, 0x0B, 0x0C]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_continuation_with_non_zero_first_seq_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // First continuation must have SEQ=0. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x01, 0x0B, 0x0C]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_continuation_with_non_monotonic_seq_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x07, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x00, 0x0B, 0x0C]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Skipping SEQ=1 and jumping to SEQ=2 is not allowed. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x02, 0x0D, 0x0E]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_init_packet_after_init_is_rejected() { + let mut parser = HidMessageParser::new(); + assert_eq!( + parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0A]) + .unwrap(), + HidMessageParserState::MorePacketsExpected + ); + // Another init packet (high bit set on byte 4) for a new transaction. + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x83, 0x00, 0x05, 0x0B]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_all_zero_packet_is_rejected() { + let mut parser = HidMessageParser::new(); + let err = parser.update(&[0u8; 64]).unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } + + #[test] + fn parse_first_packet_must_be_init_packet() { + // High bit of byte 4 cleared means continuation; not allowed first. + let mut parser = HidMessageParser::new(); + let err = parser + .update(&[0xC0, 0xC1, 0xC2, 0xC3, 0x00, 0x00, 0x05, 0x0A]) + .unwrap_err(); + assert_eq!(err.kind(), IOErrorKind::InvalidData); + } }