Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions compiler/cpp/src/thrift/generate/t_lua_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct
indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":read(iprot)" << '\n';
indent_up();

indent(out) << "iprot:incrementRecursionDepth()" << '\n';
indent(out) << "local ok, err = pcall(function()" << '\n';
indent_up();
indent(out) << "iprot:readStructBegin()" << '\n';

// while: Read in fields
Expand Down Expand Up @@ -460,6 +463,10 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct
indent_down();
indent(out) << "end" << '\n';
indent(out) << "iprot:readStructEnd()" << '\n';
indent_down();
indent(out) << "end)" << '\n';
indent(out) << "iprot:decrementRecursionDepth()" << '\n';
indent(out) << "if not ok then error(err, 0) end" << '\n';

// end function
indent_down();
Expand All @@ -478,6 +485,9 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct
indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":write(oprot)" << '\n';
indent_up();

indent(out) << "oprot:incrementRecursionDepth()" << '\n';
indent(out) << "local ok, err = pcall(function()" << '\n';
indent_up();
indent(out) << "oprot:writeStructBegin('" << tstruct->get_name() << "')" << '\n';
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
// To check element of self whether nil or not.
Expand All @@ -497,6 +507,10 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct
}
indent(out) << "oprot:writeFieldStop()" << '\n';
indent(out) << "oprot:writeStructEnd()" << '\n';
indent_down();
indent(out) << "end)" << '\n';
indent(out) << "oprot:decrementRecursionDepth()" << '\n';
indent(out) << "if not ok then error(err, 0) end" << '\n';

// end function
indent_down();
Expand Down
18 changes: 18 additions & 0 deletions lib/lua/TProtocol.lua
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ function TProtocolException:__errorCodeToString()
end
end

DEFAULT_RECURSION_DEPTH = 64

TProtocolBase = __TObject:new{
__type = 'TProtocolBase',
trans
Expand All @@ -63,9 +65,25 @@ function TProtocolBase:new(obj)
error('You must provide ' .. ttype(self) .. ' with a trans')
end

obj.recursionDepth = 0
return __TObject.new(self, obj)
end

function TProtocolBase:incrementRecursionDepth()
self.recursionDepth = self.recursionDepth + 1
if self.recursionDepth > DEFAULT_RECURSION_DEPTH then
self.recursionDepth = self.recursionDepth - 1
terror(TProtocolException:new{
message = 'Maximum recursion depth exceeded',
errorCode = TProtocolException.DEPTH_LIMIT
})
end
end

function TProtocolBase:decrementRecursionDepth()
self.recursionDepth = self.recursionDepth - 1
end

function TProtocolBase:writeMessageBegin(name, ttype, seqid) end
function TProtocolBase:writeMessageEnd() end
function TProtocolBase:writeStructBegin(name) end
Expand Down
91 changes: 91 additions & 0 deletions lib/lua/test/test_recursion_depth.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
--

-- Stub C extensions that are not needed for this test.
package.preload['libluabitwise'] = function()
return {
bor = function(a, b) return a | b end,
band = function(a, b) return a & b end,
buor = function(a, b) return a | b end,
buand = function(a, b) return a & b end,
ushiftl = function(a, n) return a << n end,
ushiftr = function(a, n) return a >> n end,
}
end
package.preload['liblualongnumber'] = function() return {} end
package.preload['libluabpack'] = function() return {} end

-- A minimal transport stub for TProtocolBase:new.
local DummyTransport = { readAll = function() return '' end }

require('Thrift')
require('TProtocol')

local passed = 0
local failed = 0

local function ok(cond, name)
if cond then
print('ok - ' .. name)
passed = passed + 1
else
print('not ok - ' .. name)
failed = failed + 1
end
end

-- Test 1: new protocol starts with recursionDepth == 0
local p = TProtocolBase:new{ trans = DummyTransport }
ok(p.recursionDepth == 0, 'new protocol starts with recursionDepth 0')

-- Test 2: incrementRecursionDepth allows up to DEFAULT_RECURSION_DEPTH
local p2 = TProtocolBase:new{ trans = DummyTransport }
local ok2 = pcall(function()
for i = 1, DEFAULT_RECURSION_DEPTH do
p2:incrementRecursionDepth()
end
end)
ok(ok2, 'allows exactly DEFAULT_RECURSION_DEPTH increments without error')
ok(p2.recursionDepth == DEFAULT_RECURSION_DEPTH,
'recursionDepth equals DEFAULT_RECURSION_DEPTH after max increments')

-- Test 3: one more increment throws (terror converts to string via __tostring)
local p3 = TProtocolBase:new{ trans = DummyTransport }
for i = 1, DEFAULT_RECURSION_DEPTH do p3:incrementRecursionDepth() end
local ok3b, err3 = xpcall(function() p3:incrementRecursionDepth() end, function(e) return e end)
ok(not ok3b, 'throws on depth limit exceeded')
ok(type(err3) == 'string' and err3:find('TProtocolException') ~= nil,
'error message contains TProtocolException')
ok(type(err3) == 'string' and
(err3:find('[Dd]epth') ~= nil or err3:find('[Ll]imit') ~= nil or err3:find('[Ee]xceeded') ~= nil),
'error message mentions depth limit')

-- Test 4: depth is not incremented on throw (stays at limit)
ok(p3.recursionDepth == DEFAULT_RECURSION_DEPTH,
'recursionDepth stays at limit after failed increment')

-- Test 5: decrementRecursionDepth restores capacity
p3:decrementRecursionDepth()
ok(p3.recursionDepth == DEFAULT_RECURSION_DEPTH - 1,
'decrementRecursionDepth reduces depth by 1')
local ok5 = pcall(function() p3:incrementRecursionDepth() end)
ok(ok5, 'increment succeeds again after decrement')

print(string.format('\n%d passed, %d failed', passed, failed))
if failed > 0 then os.exit(1) end
Loading