Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,78 @@ fn is_dtls12_psk_only(config: &Config) -> bool {
.is_some_and(|first| first.is_psk() && suites.all(|s| s.is_psk()))
}

/// Peek at a buffered DTLS 1.2 ClientHello to decide whether the auto-sense
/// server fallback should construct a PSK-mode Server12.
///
/// Walks the client's offered cipher suites in order and returns `true` iff
/// the first one allowed by `config` is a PSK suite. This mirrors the suite
/// selection inside `Server12` itself, so the chosen auth mode matches the
/// suite that `Server12` will pick once it reprocesses the ClientHello.
///
/// Returns `false` if `packet` is not a ClientHello or if parsing fails —
/// a fragmented ClientHello (fragment_offset > 0) is skipped and the next
/// buffered packet is tried by the caller.
fn client_hello_wants_psk(packet: &[u8], config: &Config) -> bool {
use dtls12::message::Dtls12CipherSuite;

// DTLS record header: content_type(1) + version(2) + epoch(2) + seq(6) + length(2) = 13
if packet.len() < 13 || packet[0] != 0x16 {
return false;
}
let record_len = u16::from_be_bytes([packet[11], packet[12]]) as usize;
let Some(record_body) = packet.get(13..13 + record_len) else {
return false;
};

// Handshake header: msg_type(1) + length(3) + message_seq(2) +
// fragment_offset(3) + fragment_length(3) = 12
if record_body.len() < 12 || record_body[0] != 0x01 {
return false;
}

let frag_off =
((record_body[6] as u32) << 16) | ((record_body[7] as u32) << 8) | record_body[8] as u32;
if frag_off != 0 {
return false;
}

let frag_len = ((record_body[9] as usize) << 16)
| ((record_body[10] as usize) << 8)
| record_body[11] as usize;
let Some(body) = record_body.get(12..12 + frag_len) else {
return false;
};

// ClientHello body: client_version(2) + random(32) + session_id(var) +
// cookie(var) + cipher_suites(var) + ...
let mut pos = 2 + 32;
let Some(&sid_len) = body.get(pos) else {
return false;
};
pos += 1 + sid_len as usize;
let Some(&cookie_len) = body.get(pos) else {
return false;
};
pos += 1 + cookie_len as usize;
if pos + 2 > body.len() {
return false;
}
let suites_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2;
if pos + suites_len > body.len() || suites_len % 2 != 0 {
return false;
}

let allowed: Vec<_> = config.dtls12_cipher_suites().map(|cs| cs.suite()).collect();
for chunk in body[pos..pos + suites_len].chunks_exact(2) {
let suite = Dtls12CipherSuite::from_u16(u16::from_be_bytes([chunk[0], chunk[1]]));
if allowed.contains(&suite) {
return suite.is_psk();
}
}
false
}

impl Dtls {
/// Create a new DTLS 1.2 instance in the server role.
///
Expand Down Expand Up @@ -548,7 +620,18 @@ impl Dtls {

let (config, cert, now, buffered) = server.into_parts();

let mut server12 = Server12::new(config, cert, now);
// A Server12 instance is either cert-auth or PSK-auth — the auth
// mode must be chosen before construction. Peek at the buffered
// ClientHello to see which cipher suite the server would pick,
// so PSK clients survive the fallback.
let use_psk =
config.psk().is_some() && buffered.iter().any(|p| client_hello_wants_psk(p, &config));

let mut server12 = if use_psk {
Server12::new_psk(config, now)
} else {
Server12::new(config, cert, now)
};
server12.handle_timeout(now)?;

self.inner = Some(Inner::Server12(server12));
Expand Down
53 changes: 52 additions & 1 deletion tests/auto/server_fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use std::sync::Arc;
use std::time::{Duration, Instant};

use dimpl::{Dtls, Error, Output, ProtocolVersion};
use dimpl::{Config, Dtls, Error, Output, ProtocolVersion, PskResolver};

use crate::common::*;

Expand Down Expand Up @@ -631,6 +631,57 @@ fn auto_server_fragmented_ch_no_cookie() {
assert_eq!(sv, Some(ProtocolVersion::DTLS1_3));
}

// ============================================================================
// Auto server + DTLS 1.2 PSK client → fallback picks PSK-mode Server12
// ============================================================================

/// Regression for https://github.com/algesten/dimpl/issues/100 — a
/// `Dtls::new_auto` server configured with `with_psk_server` must accept a
/// DTLS 1.2 PSK client. Before the fix the fallback always constructed a
/// certificate-auth Server12 and failed with "No mutually acceptable cipher
/// suite".
#[test]
#[cfg(feature = "rcgen")]
fn auto_server_psk_fallback_with_dtls12_psk_client() {
use dimpl::certificate::generate_self_signed_certificate;

let _ = env_logger::try_init();

struct FixedPsk;
impl PskResolver for FixedPsk {
fn resolve(&self, _identity: &[u8]) -> Option<Vec<u8>> {
Some(b"0123456789abcdef".to_vec())
}
}

let server_cert = generate_self_signed_certificate().unwrap();

let client_config = Arc::new(
Config::builder()
.with_psk_client(b"test-device".to_vec(), Arc::new(FixedPsk))
.build()
.expect("build PSK client config"),
);
let server_config = Arc::new(
Config::builder()
.with_psk_server(Some(b"hint".to_vec()), Arc::new(FixedPsk))
.build()
.expect("build PSK server config"),
);

let mut client = Dtls::new_12_psk(client_config, Instant::now());
client.set_active(true);

let mut server = Dtls::new_auto(server_config, server_cert, Instant::now());

let (cc, sc, cv, sv) = run_handshake(&mut client, &mut server);

assert!(cc, "PSK client should connect after auto-server fallback");
assert!(sc, "Auto server should connect to DTLS 1.2 PSK client");
assert_eq!(cv, Some(ProtocolVersion::DTLS1_2));
assert_eq!(sv, Some(ProtocolVersion::DTLS1_2));
}

/// Fragmented DTLS 1.3 ClientHello → keying material matches between client and auto server.
#[test]
#[cfg(feature = "rcgen")]
Expand Down
Loading