diff --git a/compiler/cpp/src/thrift/generate/t_rs_generator.cc b/compiler/cpp/src/thrift/generate/t_rs_generator.cc index f97a808f845..cdd6e7494d4 100644 --- a/compiler/cpp/src/thrift/generate/t_rs_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_rs_generator.cc @@ -1387,6 +1387,7 @@ void t_rs_generator::render_struct_sync_write(t_struct* tstruct, << '\n'; indent_up(); + f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n'; // write struct header to output protocol // note: use the *original* struct name here f_gen_ << indent() @@ -1420,6 +1421,7 @@ void t_rs_generator::render_union_sync_write(const string& union_name, t_struct* << '\n'; indent_up(); + f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n'; // write struct header to output protocol // note: use the *original* struct name here f_gen_ << indent() @@ -1657,6 +1659,7 @@ void t_rs_generator::render_struct_sync_read(const string& struct_name, indent_up(); + f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n'; f_gen_ << indent() << "i_prot.read_struct_begin()?;" << '\n'; // create temporary variables: one for each field in the struct @@ -1805,6 +1808,7 @@ void t_rs_generator::render_union_sync_read(const string& union_name, t_struct* << union_name << "> {" << '\n'; indent_up(); + f_gen_ << indent() << "let _depth_guard = thrift::protocol::DepthGuard::new()?;" << '\n'; // create temporary variables to hold the // completed union as well as a count of fields read f_gen_ << indent() << "let mut ret: Option<" << union_name << "> = None;" << '\n'; diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs index 0d952af34a0..ee7ba773a01 100644 --- a/lib/rs/src/protocol/mod.rs +++ b/lib/rs/src/protocol/mod.rs @@ -116,6 +116,50 @@ pub trait TSerializable: Sized { // recursion. const MAXIMUM_SKIP_DEPTH: i8 = 64; +// Per-thread recursion depth counter used by generated struct read/write code. +use std::cell::Cell; +thread_local! { + static RECURSION_DEPTH: Cell = Cell::new(0); +} +const DEFAULT_RECURSION_DEPTH: u32 = 64; + +/// RAII guard that tracks struct-serialization recursion depth per thread. +/// +/// Constructed by generated code at the start of every `read_from_in_protocol` +/// and `write_to_out_protocol` implementation. The guard increments a +/// thread-local counter on creation (returning an error when the limit is +/// reached) and decrements it when dropped, so the counter is always restored +/// even if the body returns an error. +#[derive(Debug)] +pub struct DepthGuard(()); + +impl DepthGuard { + /// Try to enter one more level of recursion. + /// + /// Returns `Err(ProtocolError { kind: DepthLimit, .. })` when the limit is + /// exceeded and the counter is left unchanged. + pub fn new() -> crate::Result { + RECURSION_DEPTH.with(|d| { + let current = d.get(); + if current >= DEFAULT_RECURSION_DEPTH { + Err(crate::Error::Protocol(ProtocolError::new( + ProtocolErrorKind::DepthLimit, + "Maximum recursion depth exceeded", + ))) + } else { + d.set(current + 1); + Ok(DepthGuard(())) + } + }) + } +} + +impl Drop for DepthGuard { + fn drop(&mut self) { + RECURSION_DEPTH.with(|d| d.set(d.get() - 1)); + } +} + /// Converts a stream of bytes into Thrift identifiers, primitives, /// containers, or structs. /// @@ -1133,4 +1177,32 @@ mod tests { let data = build_struct_with_unknown_binary_field(&[]); assert_eq!(read_struct_skipping_unknown(&data).unwrap(), 42); } + + // Test recursion depth limiting (THRIFT-6057) + + #[test] + fn depth_guard_allows_up_to_limit() { + let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect(); + } + + #[test] + fn depth_guard_rejects_at_limit() { + let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect(); + match DepthGuard::new() { + Err(crate::Error::Protocol(pe)) => { + assert_eq!(pe.kind, crate::ProtocolErrorKind::DepthLimit); + assert!(!pe.message.is_empty()); + } + other => panic!("expected DepthLimit error, got {:?}", other), + } + } + + #[test] + fn depth_guard_restores_depth_on_drop() { + { + let _guards: Vec<_> = (0..64).map(|_| DepthGuard::new().unwrap()).collect(); + assert!(DepthGuard::new().is_err()); + } + assert!(DepthGuard::new().is_ok()); + } }