Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ references into your provided buffer:
- `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app
- `KeyingMaterial(KeyingMaterial, SrtpProfile)`: DTLS‑SRTP export
- `ApplicationData(&[u8])`: plaintext received from peer
- `CloseNotify`: peer sent a `close_notify` alert (graceful shutdown)

## Example (Sans‑IO loop)

Expand Down Expand Up @@ -108,6 +109,10 @@ fn example_event_loop(mut dtls: Dtls) -> Result<(), dimpl::Error> {
Output::ApplicationData(_data) => {
// Deliver plaintext to application
}
Output::CloseNotify => {
// Peer initiated graceful shutdown
break;
}
_ => {}
}
}
Expand Down
61 changes: 59 additions & 2 deletions src/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,17 @@ impl HybridClientHello {
// legacy_cookie: empty (DTLS 1.3 requires zero-length)
ch_body.push(0);

// cipher_suites: 1.3 suites first, then 1.2 suites (filtered by config)
// cipher_suites: 1.3 suites first, then non-PSK DTLS 1.2 suites.
// The DTLS 1.2 fallback (`Client12::new_from_hybrid`) is
// certificate-auth only and cannot complete a PSK suite.
let mut suites: ArrayVec<u16, 16> = ArrayVec::new();
for cs in config.dtls13_cipher_suites() {
suites.push(cs.suite().as_u16());
}
for cs in config.dtls12_cipher_suites() {
for cs in config
.dtls12_cipher_suites()
.filter(|cs| !cs.suite().is_psk())
{
suites.push(cs.suite().as_u16());
}
ch_body.extend_from_slice(&((suites.len() * 2) as u16).to_be_bytes());
Expand Down Expand Up @@ -453,6 +458,35 @@ fn server_hello_version_inner(packet: &[u8]) -> Option<DetectedVersion> {
#[cfg(test)]
mod tests {
use super::*;
use crate::PskResolver;
use crate::dtls12::message::Dtls12CipherSuite;

fn offered_cipher_suites(hybrid: &HybridClientHello) -> Vec<u16> {
let body = &hybrid.handshake_fragment[12..];
let mut offset = 2 + 32; // legacy_version + random

let session_id_len = body[offset] as usize;
offset += 1 + session_id_len;

let cookie_len = body[offset] as usize;
offset += 1 + cookie_len;

let suites_len = u16::from_be_bytes([body[offset], body[offset + 1]]) as usize;
offset += 2;

body[offset..offset + suites_len]
.chunks_exact(2)
.map(|chunk| u16::from_be_bytes([chunk[0], chunk[1]]))
.collect()
}

struct DummyResolver;

impl PskResolver for DummyResolver {
fn resolve(&self, _identity: &[u8]) -> Option<Vec<u8>> {
Some(b"0123456789abcdef".to_vec())
}
}

#[test]
fn hello_verify_request_is_dtls12() {
Expand Down Expand Up @@ -609,4 +643,27 @@ mod tests {
];
assert_eq!(server_hello_version(&pkt), DetectedVersion::Unknown);
}

#[test]
fn hybrid_client_hello_excludes_psk_dtls12_suites() {
let config = Arc::new(
Config::builder()
.with_psk_client(b"identity".to_vec(), Arc::new(DummyResolver))
.build()
.expect("config with PSK should build"),
);

assert!(
config.dtls12_cipher_suites().any(|cs| cs.suite().is_psk()),
"precondition: PSK-enabled config should expose a PSK DTLS 1.2 suite"
);

let hybrid = HybridClientHello::new(&config).expect("hybrid ClientHello should build");
let offered = offered_cipher_suites(&hybrid);

assert!(
!offered.contains(&Dtls12CipherSuite::PSK_AES128_CCM_8.as_u16()),
"auto client must not advertise PSK DTLS 1.2 suites it cannot use after fallback"
);
}
}
2 changes: 1 addition & 1 deletion src/crypto/rust_crypto/kx_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl EcdhKeyExchange {
match group {
NamedGroup::X25519 => {
use rand_core::OsRng;
let secret = x25519_dalek::EphemeralSecret::random_from_rng(&mut OsRng);
let secret = x25519_dalek::EphemeralSecret::random_from_rng(OsRng);
let public_key_obj = x25519_dalek::PublicKey::from(&secret);
buf.clear();
buf.extend_from_slice(public_key_obj.as_bytes());
Expand Down
44 changes: 40 additions & 4 deletions src/dtls12/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,10 @@ impl Client {
}

pub fn poll_output<'a>(&mut self, buf: &'a mut [u8]) -> Output<'a> {
let last_now = self.last_now;

if let Some(event) = self.local_events.pop_front() {
return event.into_output(buf, &self.server_certificates);
}

self.engine.poll_output(buf, last_now)
self.engine.poll_output(buf, self.last_now)
}

/// Explicitly start the handshake process by sending a ClientHello
Expand All @@ -214,6 +211,10 @@ impl Client {
/// This should only be called when the client is in the Running state,
/// after the handshake is complete.
pub fn send_application_data(&mut self, data: &[u8]) -> Result<(), Error> {
if self.state == State::Closed {
return Err(Error::ConnectionClosed);
}

if self.state != State::AwaitApplicationData {
self.queued_data.push(data.to_buf());
return Ok(());
Expand All @@ -229,6 +230,25 @@ impl Client {
Ok(())
}

/// Initiate graceful shutdown by sending a `close_notify` alert.
pub fn close(&mut self) -> Result<(), Error> {
if self.state == State::Closed {
return Ok(());
}
if self.state != State::AwaitApplicationData {
self.engine.abort();
self.state = State::Closed;
return Ok(());
}
self.engine
.create_record(ContentType::Alert, 1, false, |body| {
body.push(1); // level: warning
body.push(0); // description: close_notify
})?;
self.state = State::Closed;
Ok(())
}

fn make_progress(&mut self) -> Result<(), Error> {
loop {
let prev_state = self.state;
Expand Down Expand Up @@ -263,6 +283,7 @@ enum State {
AwaitNewSessionTicket,
AwaitFinished,
AwaitApplicationData,
Closed,
}

impl State {
Expand All @@ -284,6 +305,7 @@ impl State {
State::AwaitNewSessionTicket => "AwaitNewSessionTicket",
State::AwaitFinished => "AwaitFinished",
State::AwaitApplicationData => "AwaitApplicationData",
State::Closed => "Closed",
}
}

Expand All @@ -305,6 +327,7 @@ impl State {
State::AwaitNewSessionTicket => self.await_new_session_ticket(client),
State::AwaitFinished => self.await_finished(client),
State::AwaitApplicationData => self.await_application_data(client),
State::Closed => Ok(self),
}
}

Expand Down Expand Up @@ -1148,6 +1171,19 @@ impl State {
}

fn await_application_data(self, client: &mut Client) -> Result<Self, Error> {
if client.engine.close_notify_received() {
// RFC 5246 §7.2.1: respond with a reciprocal close_notify and
// close down immediately, discarding any pending writes.
client.engine.discard_pending_writes();
client
.engine
.create_record(ContentType::Alert, 1, false, |body| {
body.push(1); // level: warning
body.push(0); // description: close_notify
})?;
return Ok(State::Closed);
}

if !client.queued_data.is_empty() {
debug!(
"Sending queued application data: {}",
Expand Down
111 changes: 103 additions & 8 deletions src/dtls12/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::queue::{QueueRx, QueueTx};
use crate::buffer::{Buf, BufferPool, TmpBuf};
use crate::crypto::{Aad, Iv, Nonce};
use crate::dtls12::context::{AuthMode, CryptoContext};
use crate::dtls12::incoming::{Incoming, Record, RecordDecrypt};
use crate::dtls12::incoming::{Incoming, Record, RecordHandler};
use crate::dtls12::message::{Body, HashAlgorithm, Header, MessageType, ProtocolVersion, Sequence};
use crate::dtls12::message::{ContentType, DTLSRecord, Dtls12CipherSuite, Handshake};
use crate::timer::ExponentialBackoff;
Expand Down Expand Up @@ -88,6 +88,12 @@ pub struct Engine {

/// Whether we are ready to release application data from poll_output.
release_app_data: bool,

/// Whether a close_notify alert has been received from the peer.
close_notify_received: bool,

/// Whether [`Output::CloseNotify`] has already been emitted.
close_notify_reported: bool,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -136,6 +142,8 @@ impl Engine {
flight_timeout: Timeout::Unarmed,
connect_timeout: Timeout::Unarmed,
release_app_data: false,
close_notify_received: false,
close_notify_reported: false,
}
}

Expand Down Expand Up @@ -200,13 +208,9 @@ impl Engine {
Ok(())
}

/// Insert the Incoming using the logic:
///
/// 1. If it is a handshake, sort by the message_seq
/// 2. If it is not a handshake, sort by sequence_number
///
/// Insert a parsed datagram into the receive queue.
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
// Capacity guard
// Capacity guard before iterating records.
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
"Receive queue full (max {}): {:?}",
Expand Down Expand Up @@ -366,6 +370,11 @@ impl Engine {
return Output::Packet(p);
}

if self.close_notify_received && !self.close_notify_reported {
self.close_notify_reported = true;
return Output::CloseNotify;
}

let next_timeout = self.poll_timeout(now);

Output::Timeout(next_timeout)
Expand Down Expand Up @@ -895,6 +904,27 @@ impl Engine {
self.release_app_data = true;
}

/// Whether a close_notify alert has been received from the peer.
pub fn close_notify_received(&self) -> bool {
self.close_notify_received
}

/// Discard all pending outgoing data.
///
/// RFC 5246 §7.2.1: on receiving close_notify, discard any pending writes.
pub fn discard_pending_writes(&mut self) {
self.queue_tx.clear();
}

/// Abort the connection: flush all queued output, retransmission state, and
/// disable timers so that no further packets are emitted.
pub fn abort(&mut self) {
self.queue_tx.clear();
self.flight_saved_records.clear();
self.flight_timeout = Timeout::Disabled;
self.connect_timeout = Timeout::Disabled;
}

/// Pop a buffer from the buffer pool for temporary use
pub(crate) fn pop_buffer(&mut self) -> Buf {
self.buffers_free.pop()
Expand Down Expand Up @@ -1094,7 +1124,72 @@ impl Engine {
}
}

impl RecordDecrypt for Engine {
impl RecordHandler for Engine {
fn classify_record(&mut self, record: Record) -> Result<Option<Record>, Error> {
if record.record().content_type == ContentType::Alert {
let epoch = record.record().sequence.epoch;
if epoch == 0 {
if self.peer_encryption_enabled {
// Post-handshake: epoch 0 alerts are unauthenticated, discard.
self.push_buffer(record.into_buffer());
return Ok(None);
}

let fatal_description = {
let fragment = record.record().fragment(record.buffer());
(fragment.len() >= 2 && fragment[0] == 2).then(|| fragment[1])
};
self.push_buffer(record.into_buffer());

if let Some(description) = fatal_description {
return Err(Error::SecurityError(format!(
"Received fatal alert: level=2, description={}",
description
)));
}

return Ok(None);
}

if !self.peer_encryption_enabled {
// Epoch >= 1 before peer encryption is enabled must stay queued
// for re-parsing after enable_peer_encryption().
return Ok(Some(record));
}

let alert = {
let fragment = record.record().fragment(record.buffer());
(fragment.len() >= 2).then(|| (fragment[0], fragment[1]))
};
self.push_buffer(record.into_buffer());

if let Some((level, description)) = alert {
if description == 0 {
self.close_notify_received = true;
return Ok(None);
}

if level == 2 {
return Err(Error::SecurityError(format!(
"Received fatal alert: level={}, description={}",
level, description
)));
}
}

return Ok(None);
}

if self.close_notify_received
&& record.record().content_type == ContentType::ApplicationData
{
self.push_buffer(record.into_buffer());
return Ok(None);
}

Ok(Some(record))
}

fn is_peer_encryption_enabled(&self) -> bool {
self.peer_encryption_enabled
}
Expand Down
Loading
Loading