From bca5d0de80d19186ca635ce561e0ff4acf88afa4 Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Thu, 28 May 2026 01:10:31 +0200 Subject: [PATCH] THRIFT-6049: Harden Lua protocol recursion depth Client: lua Co-Authored-By: Claude Sonnet 4.6 --- .../src/thrift/generate/t_lua_generator.cc | 14 +++ lib/lua/TProtocol.lua | 18 ++++ lib/lua/test/test_recursion_depth.lua | 91 +++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 lib/lua/test/test_recursion_depth.lua diff --git a/compiler/cpp/src/thrift/generate/t_lua_generator.cc b/compiler/cpp/src/thrift/generate/t_lua_generator.cc index f7f8f054e08..54c1fd7097f 100644 --- a/compiler/cpp/src/thrift/generate/t_lua_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_lua_generator.cc @@ -421,6 +421,9 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":read(iprot)" << '\n'; indent_up(); + indent(out) << "iprot:incrementRecursionDepth()" << '\n'; + indent(out) << "local ok, err = pcall(function()" << '\n'; + indent_up(); indent(out) << "iprot:readStructBegin()" << '\n'; // while: Read in fields @@ -460,6 +463,10 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct indent_down(); indent(out) << "end" << '\n'; indent(out) << "iprot:readStructEnd()" << '\n'; + indent_down(); + indent(out) << "end)" << '\n'; + indent(out) << "iprot:decrementRecursionDepth()" << '\n'; + indent(out) << "if not ok then error(err, 0) end" << '\n'; // end function indent_down(); @@ -478,6 +485,9 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":write(oprot)" << '\n'; indent_up(); + indent(out) << "oprot:incrementRecursionDepth()" << '\n'; + indent(out) << "local ok, err = pcall(function()" << '\n'; + indent_up(); indent(out) << "oprot:writeStructBegin('" << tstruct->get_name() << "')" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { // To check element of self whether nil or not. @@ -497,6 +507,10 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct } indent(out) << "oprot:writeFieldStop()" << '\n'; indent(out) << "oprot:writeStructEnd()" << '\n'; + indent_down(); + indent(out) << "end)" << '\n'; + indent(out) << "oprot:decrementRecursionDepth()" << '\n'; + indent(out) << "if not ok then error(err, 0) end" << '\n'; // end function indent_down(); diff --git a/lib/lua/TProtocol.lua b/lib/lua/TProtocol.lua index f7a993f0b50..8a86e756970 100644 --- a/lib/lua/TProtocol.lua +++ b/lib/lua/TProtocol.lua @@ -48,6 +48,8 @@ function TProtocolException:__errorCodeToString() end end +DEFAULT_RECURSION_DEPTH = 64 + TProtocolBase = __TObject:new{ __type = 'TProtocolBase', trans @@ -63,9 +65,25 @@ function TProtocolBase:new(obj) error('You must provide ' .. ttype(self) .. ' with a trans') end + obj.recursionDepth = 0 return __TObject.new(self, obj) end +function TProtocolBase:incrementRecursionDepth() + self.recursionDepth = self.recursionDepth + 1 + if self.recursionDepth > DEFAULT_RECURSION_DEPTH then + self.recursionDepth = self.recursionDepth - 1 + terror(TProtocolException:new{ + message = 'Maximum recursion depth exceeded', + errorCode = TProtocolException.DEPTH_LIMIT + }) + end +end + +function TProtocolBase:decrementRecursionDepth() + self.recursionDepth = self.recursionDepth - 1 +end + function TProtocolBase:writeMessageBegin(name, ttype, seqid) end function TProtocolBase:writeMessageEnd() end function TProtocolBase:writeStructBegin(name) end diff --git a/lib/lua/test/test_recursion_depth.lua b/lib/lua/test/test_recursion_depth.lua new file mode 100644 index 00000000000..ddd6be8bf3f --- /dev/null +++ b/lib/lua/test/test_recursion_depth.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +-- Stub C extensions that are not needed for this test. +package.preload['libluabitwise'] = function() + return { + bor = function(a, b) return a | b end, + band = function(a, b) return a & b end, + buor = function(a, b) return a | b end, + buand = function(a, b) return a & b end, + ushiftl = function(a, n) return a << n end, + ushiftr = function(a, n) return a >> n end, + } +end +package.preload['liblualongnumber'] = function() return {} end +package.preload['libluabpack'] = function() return {} end + +-- A minimal transport stub for TProtocolBase:new. +local DummyTransport = { readAll = function() return '' end } + +require('Thrift') +require('TProtocol') + +local passed = 0 +local failed = 0 + +local function ok(cond, name) + if cond then + print('ok - ' .. name) + passed = passed + 1 + else + print('not ok - ' .. name) + failed = failed + 1 + end +end + +-- Test 1: new protocol starts with recursionDepth == 0 +local p = TProtocolBase:new{ trans = DummyTransport } +ok(p.recursionDepth == 0, 'new protocol starts with recursionDepth 0') + +-- Test 2: incrementRecursionDepth allows up to DEFAULT_RECURSION_DEPTH +local p2 = TProtocolBase:new{ trans = DummyTransport } +local ok2 = pcall(function() + for i = 1, DEFAULT_RECURSION_DEPTH do + p2:incrementRecursionDepth() + end +end) +ok(ok2, 'allows exactly DEFAULT_RECURSION_DEPTH increments without error') +ok(p2.recursionDepth == DEFAULT_RECURSION_DEPTH, + 'recursionDepth equals DEFAULT_RECURSION_DEPTH after max increments') + +-- Test 3: one more increment throws (terror converts to string via __tostring) +local p3 = TProtocolBase:new{ trans = DummyTransport } +for i = 1, DEFAULT_RECURSION_DEPTH do p3:incrementRecursionDepth() end +local ok3b, err3 = xpcall(function() p3:incrementRecursionDepth() end, function(e) return e end) +ok(not ok3b, 'throws on depth limit exceeded') +ok(type(err3) == 'string' and err3:find('TProtocolException') ~= nil, + 'error message contains TProtocolException') +ok(type(err3) == 'string' and + (err3:find('[Dd]epth') ~= nil or err3:find('[Ll]imit') ~= nil or err3:find('[Ee]xceeded') ~= nil), + 'error message mentions depth limit') + +-- Test 4: depth is not incremented on throw (stays at limit) +ok(p3.recursionDepth == DEFAULT_RECURSION_DEPTH, + 'recursionDepth stays at limit after failed increment') + +-- Test 5: decrementRecursionDepth restores capacity +p3:decrementRecursionDepth() +ok(p3.recursionDepth == DEFAULT_RECURSION_DEPTH - 1, + 'decrementRecursionDepth reduces depth by 1') +local ok5 = pcall(function() p3:incrementRecursionDepth() end) +ok(ok5, 'increment succeeds again after decrement') + +print(string.format('\n%d passed, %d failed', passed, failed)) +if failed > 0 then os.exit(1) end