diff --git a/lib/rb/lib/thrift/protocol/base_protocol.rb b/lib/rb/lib/thrift/protocol/base_protocol.rb index 2a13b692ec..f4e88ffa61 100644 --- a/lib/rb/lib/thrift/protocol/base_protocol.rb +++ b/lib/rb/lib/thrift/protocol/base_protocol.rb @@ -42,10 +42,25 @@ def initialize(type = UNKNOWN, message = nil) class BaseProtocol + DEFAULT_RECURSION_DEPTH = 64 + attr_reader :trans def initialize(trans) @trans = trans + @recursion_depth = 0 + end + + def increment_recursion_depth + @recursion_depth += 1 + if @recursion_depth > DEFAULT_RECURSION_DEPTH + @recursion_depth -= 1 + raise ProtocolException.new(ProtocolException::DEPTH_LIMIT, 'Maximum recursion depth exceeded') + end + end + + def decrement_recursion_depth + @recursion_depth -= 1 end def native? diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb index 691cf125f2..efee3f8df6 100644 --- a/lib/rb/lib/thrift/struct.rb +++ b/lib/rb/lib/thrift/struct.rb @@ -82,36 +82,46 @@ def inspect(skip_optional_nulls = true) end def read(iprot) - iprot.read_struct_begin - loop do - fname, ftype, fid = iprot.read_field_begin - break if (ftype == Types::STOP) - handle_message(iprot, fid, ftype) - iprot.read_field_end - end - iprot.read_struct_end - validate + iprot.increment_recursion_depth + begin + iprot.read_struct_begin + loop do + fname, ftype, fid = iprot.read_field_begin + break if (ftype == Types::STOP) + handle_message(iprot, fid, ftype) + iprot.read_field_end + end + iprot.read_struct_end + validate + ensure + iprot.decrement_recursion_depth + end end def write(oprot) validate - oprot.write_struct_begin(self.class.name) - each_field do |fid, field_info| - name = field_info[:name] - type = field_info[:type] - value = instance_variable_get("@#{name}") - unless value.nil? - if is_container? type - oprot.write_field_begin(name, type, fid) - write_container(oprot, value, field_info) - oprot.write_field_end - else - oprot.write_field(field_info, fid, value) + oprot.increment_recursion_depth + begin + oprot.write_struct_begin(self.class.name) + each_field do |fid, field_info| + name = field_info[:name] + type = field_info[:type] + value = instance_variable_get("@#{name}") + unless value.nil? + if is_container? type + oprot.write_field_begin(name, type, fid) + write_container(oprot, value, field_info) + oprot.write_field_end + else + oprot.write_field(field_info, fid, value) + end end end + oprot.write_field_stop + oprot.write_struct_end + ensure + oprot.decrement_recursion_depth end - oprot.write_field_stop - oprot.write_struct_end end def ==(other)