From 2db5638a2d1b9f05c59b835eb5ceeb4e7cedd9be Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Thu, 28 May 2026 01:45:58 +0200 Subject: [PATCH] THRIFT-6052: Harden Smalltalk protocol against deeply nested messages Client: st - Add recursionDepth instance variable to TProtocol - Add incrementRecursionDepth (limit 64) and decrementRecursionDepth methods; raises TProtocolError on excess - Add depthLimit class method to TProtocolError (returns 6) - Generator wraps struct read/write blocks with ensure: to guarantee decrementRecursionDepth always runs even when an exception is raised - Also fixes a pre-existing bug: struct_reader used oprot instead of iprot for readStructEnd Co-Authored-By: Claude Sonnet 4.6 --- .../cpp/src/thrift/generate/t_st_generator.cc | 12 +-- lib/st/test/TProtocolRecursionDepthTest.st | 73 +++++++++++++++++++ lib/st/thrift.st | 17 ++++- 3 files changed, 96 insertions(+), 6 deletions(-) create mode 100644 lib/st/test/TProtocolRecursionDepthTest.st diff --git a/compiler/cpp/src/thrift/generate/t_st_generator.cc b/compiler/cpp/src/thrift/generate/t_st_generator.cc index c1ad35577e8..6b4ff6f34ef 100644 --- a/compiler/cpp/src/thrift/generate/t_st_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_st_generator.cc @@ -706,9 +706,10 @@ string t_st_generator::struct_writer(t_struct* tstruct, string sname) { const vector& fields = tstruct->get_sorted_members(); vector::const_iterator fld_iter; - out << "[oprot writeStructBegin: " - << "(TStruct new name: '" + tstruct->get_name() + "')." << '\n'; + out << "[oprot incrementRecursionDepth." << '\n'; indent_up(); + out << indent() << "[oprot writeStructBegin: " + << "(TStruct new name: '" + tstruct->get_name() + "')." << '\n'; for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { bool optional = (*fld_iter)->get_req() == t_field::T_OPTIONAL; @@ -735,7 +736,7 @@ string t_st_generator::struct_writer(t_struct* tstruct, string sname) { out << "." << '\n'; } - out << indent() << "oprot writeFieldStop; writeStructEnd] value"; + out << indent() << "oprot writeFieldStop; writeStructEnd] ensure: [oprot decrementRecursionDepth]] value"; indent_down(); return out.str(); @@ -759,8 +760,9 @@ string t_st_generator::struct_reader(t_struct* tstruct, string clsName = "") { // This is nasty, but without it we'll break things by prefixing TResult. string name = ((capitalize(clsName) == "TResult") ? capitalize(clsName) : prefix(clsName)); out << indent() << val << " := " << name << " new." << '\n'; + out << indent() << "iprot incrementRecursionDepth." << '\n'; - out << indent() << "iprot readStructBegin." << '\n' << indent() << "[" << desc + out << indent() << "[iprot readStructBegin." << '\n' << indent() << "[" << desc << " := iprot readFieldBegin." << '\n' << indent() << desc << " type = TType stop] whileFalse: [|" << found << "|" << '\n'; indent_up(); @@ -779,7 +781,7 @@ string t_st_generator::struct_reader(t_struct* tstruct, string clsName = "") { out << indent() << found << " ifNil: [iprot skip: " << desc << " type]]." << '\n'; indent_down(); - out << indent() << "oprot readStructEnd." << '\n' << indent() << val << "] value"; + out << indent() << "iprot readStructEnd] ensure: [iprot decrementRecursionDepth]." << '\n' << indent() << val << "] value"; indent_down(); return out.str(); diff --git a/lib/st/test/TProtocolRecursionDepthTest.st b/lib/st/test/TProtocolRecursionDepthTest.st new file mode 100644 index 00000000000..fb76aae5bdf --- /dev/null +++ b/lib/st/test/TProtocolRecursionDepthTest.st @@ -0,0 +1,73 @@ +'Test suite for TProtocol recursion depth limiting. + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +'! + +TestCase subclass: #TProtocolRecursionDepthTest + instanceVariableNames: 'proto' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Tests'! + +!TProtocolRecursionDepthTest methodsFor: 'setup'! +setUp + "TProtocol is abstract; use TBinaryProtocol with nil transport for depth-only tests" + proto := TBinaryProtocol new! ! + +!TProtocolRecursionDepthTest methodsFor: 'tests'! +test01InitialDepthAllows64Increments + "64 increments must all succeed; depth starts at nil/0" + 1 to: 64 do: [:i | proto incrementRecursionDepth]. + self assert: true! ! + +!TProtocolRecursionDepthTest methodsFor: 'tests'! +test02The65thIncrementRaisesError + "The 65th increment must raise TProtocolError" + 1 to: 64 do: [:i | proto incrementRecursionDepth]. + self + should: [proto incrementRecursionDepth] + raise: TProtocolError! ! + +!TProtocolRecursionDepthTest methodsFor: 'tests'! +test03DepthRestoredAfterRejectedIncrement + "After a rejected increment, decrement 64 times then increment must succeed" + 1 to: 64 do: [:i | proto incrementRecursionDepth]. + [proto incrementRecursionDepth] on: TProtocolError do: [:e | "ignored"]. + 1 to: 64 do: [:i | proto decrementRecursionDepth]. + self shouldnt: [proto incrementRecursionDepth] raise: TProtocolError! ! + +!TProtocolRecursionDepthTest methodsFor: 'tests'! +test04IncrementDecrementBalance + "Paired increment/decrement must leave depth unchanged" + 1 to: 10 do: [:i | proto incrementRecursionDepth]. + 1 to: 10 do: [:i | proto decrementRecursionDepth]. + "Should still be able to do 64 more increments" + self shouldnt: [1 to: 64 do: [:i | proto incrementRecursionDepth]] raise: TProtocolError! ! + +!TProtocolRecursionDepthTest methodsFor: 'tests'! +test05EnsureDecrementsOnException + "ensure: cleanup must fire even when body raises an exception" + | raised | + raised := false. + proto incrementRecursionDepth. + [ + [Error signal: 'simulated error'] ensure: [proto decrementRecursionDepth] + ] on: Error do: [:e | raised := true]. + self assert: raised. + "Depth is back to 0; 64 increments must succeed" + self shouldnt: [1 to: 64 do: [:i | proto incrementRecursionDepth]] raise: TProtocolError! ! diff --git a/lib/st/thrift.st b/lib/st/thrift.st index b2dbc9768e4..451d56dfa00 100644 --- a/lib/st/thrift.st +++ b/lib/st/thrift.st @@ -69,6 +69,10 @@ sizeLimit unknown ^ 0! ! +!TProtocolError class methodsFor: 'as yet unclassified'! +depthLimit + ^ 6! ! + TError subclass: #TTransportError instanceVariableNames: '' classVariableNames: '' @@ -191,7 +195,7 @@ type: anInteger type := anInteger! ! Object subclass: #TProtocol - instanceVariableNames: 'transport' + instanceVariableNames: 'transport recursionDepth' classVariableNames: '' poolDictionaries: '' category: 'Thrift-Protocol'! @@ -507,6 +511,17 @@ transport transport: aTransport transport := aTransport! ! +!TProtocol methodsFor: 'as yet unclassified'! +incrementRecursionDepth + recursionDepth := (recursionDepth ifNil: [0]) + 1. + recursionDepth > 64 ifTrue: [ + recursionDepth := recursionDepth - 1. + TProtocolError signal: 'Maximum recursion depth exceeded']! ! + +!TProtocol methodsFor: 'as yet unclassified'! +decrementRecursionDepth + recursionDepth := (recursionDepth ifNil: [0]) - 1! ! + !TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! writeBool: aBool! !