diff --git a/src/lib.rs b/src/lib.rs index fc446caa..627de0ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -294,6 +294,78 @@ fn is_dtls12_psk_only(config: &Config) -> bool { .is_some_and(|first| first.is_psk() && suites.all(|s| s.is_psk())) } +/// Peek at a buffered DTLS 1.2 ClientHello to decide whether the auto-sense +/// server fallback should construct a PSK-mode Server12. +/// +/// Walks the client's offered cipher suites in order and returns `true` iff +/// the first one allowed by `config` is a PSK suite. This mirrors the suite +/// selection inside `Server12` itself, so the chosen auth mode matches the +/// suite that `Server12` will pick once it reprocesses the ClientHello. +/// +/// Returns `false` if `packet` is not a ClientHello or if parsing fails — +/// a fragmented ClientHello (fragment_offset > 0) is skipped and the next +/// buffered packet is tried by the caller. +fn client_hello_wants_psk(packet: &[u8], config: &Config) -> bool { + use dtls12::message::Dtls12CipherSuite; + + // DTLS record header: content_type(1) + version(2) + epoch(2) + seq(6) + length(2) = 13 + if packet.len() < 13 || packet[0] != 0x16 { + return false; + } + let record_len = u16::from_be_bytes([packet[11], packet[12]]) as usize; + let Some(record_body) = packet.get(13..13 + record_len) else { + return false; + }; + + // Handshake header: msg_type(1) + length(3) + message_seq(2) + + // fragment_offset(3) + fragment_length(3) = 12 + if record_body.len() < 12 || record_body[0] != 0x01 { + return false; + } + + let frag_off = + ((record_body[6] as u32) << 16) | ((record_body[7] as u32) << 8) | record_body[8] as u32; + if frag_off != 0 { + return false; + } + + let frag_len = ((record_body[9] as usize) << 16) + | ((record_body[10] as usize) << 8) + | record_body[11] as usize; + let Some(body) = record_body.get(12..12 + frag_len) else { + return false; + }; + + // ClientHello body: client_version(2) + random(32) + session_id(var) + + // cookie(var) + cipher_suites(var) + ... + let mut pos = 2 + 32; + let Some(&sid_len) = body.get(pos) else { + return false; + }; + pos += 1 + sid_len as usize; + let Some(&cookie_len) = body.get(pos) else { + return false; + }; + pos += 1 + cookie_len as usize; + if pos + 2 > body.len() { + return false; + } + let suites_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; + pos += 2; + if pos + suites_len > body.len() || suites_len % 2 != 0 { + return false; + } + + let allowed: Vec<_> = config.dtls12_cipher_suites().map(|cs| cs.suite()).collect(); + for chunk in body[pos..pos + suites_len].chunks_exact(2) { + let suite = Dtls12CipherSuite::from_u16(u16::from_be_bytes([chunk[0], chunk[1]])); + if allowed.contains(&suite) { + return suite.is_psk(); + } + } + false +} + impl Dtls { /// Create a new DTLS 1.2 instance in the server role. /// @@ -548,7 +620,18 @@ impl Dtls { let (config, cert, now, buffered) = server.into_parts(); - let mut server12 = Server12::new(config, cert, now); + // A Server12 instance is either cert-auth or PSK-auth — the auth + // mode must be chosen before construction. Peek at the buffered + // ClientHello to see which cipher suite the server would pick, + // so PSK clients survive the fallback. + let use_psk = + config.psk().is_some() && buffered.iter().any(|p| client_hello_wants_psk(p, &config)); + + let mut server12 = if use_psk { + Server12::new_psk(config, now) + } else { + Server12::new(config, cert, now) + }; server12.handle_timeout(now)?; self.inner = Some(Inner::Server12(server12)); diff --git a/tests/auto/server_fallback.rs b/tests/auto/server_fallback.rs index d4e9f978..b694b241 100644 --- a/tests/auto/server_fallback.rs +++ b/tests/auto/server_fallback.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use dimpl::{Dtls, Error, Output, ProtocolVersion}; +use dimpl::{Config, Dtls, Error, Output, ProtocolVersion, PskResolver}; use crate::common::*; @@ -631,6 +631,57 @@ fn auto_server_fragmented_ch_no_cookie() { assert_eq!(sv, Some(ProtocolVersion::DTLS1_3)); } +// ============================================================================ +// Auto server + DTLS 1.2 PSK client → fallback picks PSK-mode Server12 +// ============================================================================ + +/// Regression for https://github.com/algesten/dimpl/issues/100 — a +/// `Dtls::new_auto` server configured with `with_psk_server` must accept a +/// DTLS 1.2 PSK client. Before the fix the fallback always constructed a +/// certificate-auth Server12 and failed with "No mutually acceptable cipher +/// suite". +#[test] +#[cfg(feature = "rcgen")] +fn auto_server_psk_fallback_with_dtls12_psk_client() { + use dimpl::certificate::generate_self_signed_certificate; + + let _ = env_logger::try_init(); + + struct FixedPsk; + impl PskResolver for FixedPsk { + fn resolve(&self, _identity: &[u8]) -> Option> { + Some(b"0123456789abcdef".to_vec()) + } + } + + let server_cert = generate_self_signed_certificate().unwrap(); + + let client_config = Arc::new( + Config::builder() + .with_psk_client(b"test-device".to_vec(), Arc::new(FixedPsk)) + .build() + .expect("build PSK client config"), + ); + let server_config = Arc::new( + Config::builder() + .with_psk_server(Some(b"hint".to_vec()), Arc::new(FixedPsk)) + .build() + .expect("build PSK server config"), + ); + + let mut client = Dtls::new_12_psk(client_config, Instant::now()); + client.set_active(true); + + let mut server = Dtls::new_auto(server_config, server_cert, Instant::now()); + + let (cc, sc, cv, sv) = run_handshake(&mut client, &mut server); + + assert!(cc, "PSK client should connect after auto-server fallback"); + assert!(sc, "Auto server should connect to DTLS 1.2 PSK client"); + assert_eq!(cv, Some(ProtocolVersion::DTLS1_2)); + assert_eq!(sv, Some(ProtocolVersion::DTLS1_2)); +} + /// Fragmented DTLS 1.3 ClientHello → keying material matches between client and auto server. #[test] #[cfg(feature = "rcgen")]