Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/cpp/src/thrift/generate/t_rs_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ void t_rs_generator::render_struct_sync_write(t_struct* tstruct,
<< '\n';
indent_up();

f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n';
// write struct header to output protocol
// note: use the *original* struct name here
f_gen_ << indent()
Expand Down Expand Up @@ -1420,6 +1421,7 @@ void t_rs_generator::render_union_sync_write(const string& union_name, t_struct*
<< '\n';
indent_up();

f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n';
// write struct header to output protocol
// note: use the *original* struct name here
f_gen_ << indent()
Expand Down Expand Up @@ -1657,6 +1659,7 @@ void t_rs_generator::render_struct_sync_read(const string& struct_name,

indent_up();

f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n';
f_gen_ << indent() << "i_prot.read_struct_begin()?;" << '\n';

// create temporary variables: one for each field in the struct
Expand Down Expand Up @@ -1805,6 +1808,7 @@ void t_rs_generator::render_union_sync_read(const string& union_name, t_struct*
<< union_name << "> {" << '\n';
indent_up();

f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n';
// create temporary variables to hold the
// completed union as well as a count of fields read
f_gen_ << indent() << "let mut ret: Option<" << union_name << "> = None;" << '\n';
Expand Down
72 changes: 72 additions & 0 deletions lib/rs/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,50 @@ pub trait TSerializable: Sized {
// recursion.
const MAXIMUM_SKIP_DEPTH: i8 = 64;

// Per-thread recursion depth counter used by generated struct read/write code.
use std::cell::Cell;
thread_local! {
static RECURSION_DEPTH: Cell<u32> = Cell::new(0);
}
const DEFAULT_RECURSION_DEPTH: u32 = 64;

/// RAII guard that tracks struct-serialization recursion depth per thread.
///
/// Constructed by generated code at the start of every `read_from_in_protocol`
/// and `write_to_out_protocol` implementation. The guard increments a
/// thread-local counter on creation (returning an error when the limit is
/// reached) and decrements it when dropped, so the counter is always restored
/// even if the body returns an error.
#[derive(Debug)]
pub struct DepthGuard(());

impl DepthGuard {
/// Try to enter one more level of recursion.
///
/// Returns `Err(ProtocolError { kind: DepthLimit, .. })` when the limit is
/// exceeded and the counter is left unchanged.
pub fn new() -> crate::Result<DepthGuard> {
RECURSION_DEPTH.with(|d| {
let current = d.get();
if current >= DEFAULT_RECURSION_DEPTH {
Err(crate::Error::Protocol(ProtocolError::new(
ProtocolErrorKind::DepthLimit,
"Maximum recursion depth exceeded",
)))
} else {
d.set(current + 1);
Ok(DepthGuard(()))
}
})
}
}

impl Drop for DepthGuard {
fn drop(&mut self) {
RECURSION_DEPTH.with(|d| d.set(d.get() - 1));
}
}

/// Converts a stream of bytes into Thrift identifiers, primitives,
/// containers, or structs.
///
Expand Down Expand Up @@ -1133,4 +1177,32 @@ mod tests {
let data = build_struct_with_unknown_binary_field(&[]);
assert_eq!(read_struct_skipping_unknown(&data).unwrap(), 42);
}

// Test recursion depth limiting (THRIFT-6057)

#[test]
fn depth_guard_allows_up_to_limit() {
let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect();
}

#[test]
fn depth_guard_rejects_at_limit() {
let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect();
match DepthGuard::new() {
Err(crate::Error::Protocol(pe)) => {
assert_eq!(pe.kind, crate::ProtocolErrorKind::DepthLimit);
assert!(!pe.message.is_empty());
}
other => panic!("expected DepthLimit error, got {:?}", other),
}
}

#[test]
fn depth_guard_restores_depth_on_drop() {
{
let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect();
assert!(DepthGuard::new().is_err());
}
assert!(DepthGuard::new().is_ok());
}
}
Loading