diff --git a/lib/d/src/thrift/codegen/base.d b/lib/d/src/thrift/codegen/base.d index db549928c25..4973803214c 100644 --- a/lib/d/src/thrift/codegen/base.d +++ b/lib/d/src/thrift/codegen/base.d @@ -50,6 +50,10 @@ import thrift.internal.codegen; import thrift.protocol.base; import thrift.util.hashset; +// Thread-local recursion depth counter used by readStruct/writeStruct. +private uint currentRecursionDepth_; +private enum uint DEFAULT_MAX_RECURSION_DEPTH = 64; + /* * Thrift struct/service meta data, which is used to store information from * the interface definition files not representable in plain D, i.e. field @@ -594,6 +598,12 @@ template TIsSetFlags(T, alias fieldMetaData) { void readStruct(T, Protocol, alias fieldMetaData = cast(TFieldMeta[])null, bool pointerStruct = false)(auto ref T s, Protocol p) if (isTProtocol!Protocol) { + if (++currentRecursionDepth_ > DEFAULT_MAX_RECURSION_DEPTH) { + --currentRecursionDepth_; + throw new TProtocolException("Maximum recursion depth exceeded", + TProtocolException.Type.DEPTH_LIMIT); + } + scope(exit) --currentRecursionDepth_; mixin({ string code; @@ -813,6 +823,13 @@ void writeStruct(T, Protocol, alias fieldMetaData = cast(TFieldMeta[])null, return code; }()); + if (++currentRecursionDepth_ > DEFAULT_MAX_RECURSION_DEPTH) { + --currentRecursionDepth_; + throw new TProtocolException("Maximum recursion depth exceeded", + TProtocolException.Type.DEPTH_LIMIT); + } + scope(exit) --currentRecursionDepth_; + p.writeStructBegin(TStruct(T.stringof)); mixin({ string writeValueCode(ValueType)(string v, size_t level = 0) { @@ -982,6 +999,39 @@ unittest { assert(a.toHash() == b.toHash()); } +// Recursion depth limiting in readStruct and writeStruct. +unittest { + import std.exception : collectException; + + static struct Flat { + int x; + mixin TStructHelpers!([TFieldMeta("x", 1, TReq.OPTIONAL)]); + } + + // Save and restore the module-level depth counter so this test is isolated. + uint savedDepth = currentRecursionDepth_; + scope(exit) currentRecursionDepth_ = savedDepth; + + // At depth 64 the next call must raise DEPTH_LIMIT before touching the + // protocol, so passing null is safe here. + currentRecursionDepth_ = DEFAULT_MAX_RECURSION_DEPTH; + + auto readEx = collectException!TProtocolException( + readStruct!(Flat, TProtocol)(*(cast(Flat*)null), cast(TProtocol)null)); + assert(readEx !is null, "readStruct at depth limit must throw"); + assert(readEx.type == TProtocolException.Type.DEPTH_LIMIT); + // Increment was rolled back, so depth stays at the limit. + assert(currentRecursionDepth_ == DEFAULT_MAX_RECURSION_DEPTH); + + currentRecursionDepth_ = DEFAULT_MAX_RECURSION_DEPTH; + Flat f; + auto writeEx = collectException!TProtocolException( + writeStruct!(Flat, TProtocol)(f, cast(TProtocol)null)); + assert(writeEx !is null, "writeStruct at depth limit must throw"); + assert(writeEx.type == TProtocolException.Type.DEPTH_LIMIT); + assert(currentRecursionDepth_ == DEFAULT_MAX_RECURSION_DEPTH); +} + private { /* * Returns a D code string containing the matching TType value for a passed