diff --git a/compiler/cpp/src/thrift/generate/t_javame_generator.cc b/compiler/cpp/src/thrift/generate/t_javame_generator.cc index e2c3a395cdc..4c08be9733f 100644 --- a/compiler/cpp/src/thrift/generate/t_javame_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_javame_generator.cc @@ -1304,7 +1304,11 @@ void t_javame_generator::generate_java_struct_reader(ostream& out, t_struct* tst vector::const_iterator f_iter; // Declare stack tmp variables and read struct header - out << indent() << "TField field;" << '\n' << indent() << "iprot.readStructBegin();" << '\n'; + out << indent() << "TField field;" << '\n'; + indent(out) << "iprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try {" << '\n'; + indent_up(); + out << indent() << "iprot.readStructBegin();" << '\n'; // Loop over reading in fields indent(out) << "while (true)" << '\n'; @@ -1358,6 +1362,13 @@ void t_javame_generator::generate_java_struct_reader(ostream& out, t_struct* tst // performs various checks (e.g. check that all required fields are set) indent(out) << "validate();" << '\n'; + indent_down(); + indent(out) << "} finally {" << '\n'; + indent_up(); + indent(out) << "iprot.decrementRecursionDepth();" << '\n'; + indent_down(); + indent(out) << "}" << '\n'; + indent_down(); out << indent() << "}" << '\n' << '\n'; } @@ -1400,6 +1411,9 @@ void t_javame_generator::generate_java_struct_writer(ostream& out, t_struct* tst // performs various checks (e.g. check that all required fields are set) indent(out) << "validate();" << '\n' << '\n'; + indent(out) << "oprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try {" << '\n'; + indent_up(); indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { @@ -1436,6 +1450,13 @@ void t_javame_generator::generate_java_struct_writer(ostream& out, t_struct* tst out << indent() << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();" << '\n'; + indent_down(); + indent(out) << "} finally {" << '\n'; + indent_up(); + indent(out) << "oprot.decrementRecursionDepth();" << '\n'; + indent_down(); + indent(out) << "}" << '\n'; + indent_down(); out << indent() << "}" << '\n' << '\n'; } @@ -1456,6 +1477,9 @@ void t_javame_generator::generate_java_struct_result_writer(ostream& out, t_stru const vector& fields = tstruct->get_sorted_members(); vector::const_iterator f_iter; + indent(out) << "oprot.incrementRecursionDepth();" << '\n'; + indent(out) << "try {" << '\n'; + indent_up(); indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << '\n'; bool first = true; @@ -1487,6 +1511,13 @@ void t_javame_generator::generate_java_struct_result_writer(ostream& out, t_stru out << '\n' << indent() << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();" << '\n'; + indent_down(); + indent(out) << "} finally {" << '\n'; + indent_up(); + indent(out) << "oprot.decrementRecursionDepth();" << '\n'; + indent_down(); + indent(out) << "}" << '\n'; + indent_down(); out << indent() << "}" << '\n' << '\n'; } diff --git a/lib/javame/src/org/apache/thrift/protocol/TProtocol.java b/lib/javame/src/org/apache/thrift/protocol/TProtocol.java index 710e6d4c1cf..48cd57b9784 100644 --- a/lib/javame/src/org/apache/thrift/protocol/TProtocol.java +++ b/lib/javame/src/org/apache/thrift/protocol/TProtocol.java @@ -38,6 +38,11 @@ private TProtocol() {} */ protected TTransport trans_; + /** Current recursion depth during struct serialization */ + private int recursionDepth_ = 0; + + private static final int DEFAULT_RECURSION_DEPTH = 64; + /** * Constructor */ @@ -52,6 +57,24 @@ public TTransport getTransport() { return trans_; } + /** + * Increment recursion depth, checking against the limit. + * + * @throws TProtocolException with DEPTH_LIMIT if the limit is exceeded + */ + public void incrementRecursionDepth() throws TProtocolException { + if (recursionDepth_ >= DEFAULT_RECURSION_DEPTH) { + throw new TProtocolException( + TProtocolException.DEPTH_LIMIT, "Maximum recursion depth exceeded"); + } + ++recursionDepth_; + } + + /** Decrement recursion depth. Must be called in a finally block. */ + public void decrementRecursionDepth() { + --recursionDepth_; + } + /** * Writing methods. */ diff --git a/lib/javame/src/org/apache/thrift/protocol/TProtocolException.java b/lib/javame/src/org/apache/thrift/protocol/TProtocolException.java index 248815beccd..870f1b93923 100644 --- a/lib/javame/src/org/apache/thrift/protocol/TProtocolException.java +++ b/lib/javame/src/org/apache/thrift/protocol/TProtocolException.java @@ -35,6 +35,7 @@ public class TProtocolException extends TException { public static final int SIZE_LIMIT = 3; public static final int BAD_VERSION = 4; public static final int NOT_IMPLEMENTED = 5; + public static final int DEPTH_LIMIT = 6; protected int type_ = UNKNOWN; diff --git a/lib/javame/test/org/apache/thrift/protocol/TestRecursionDepth.java b/lib/javame/test/org/apache/thrift/protocol/TestRecursionDepth.java new file mode 100644 index 00000000000..403d397e4ec --- /dev/null +++ b/lib/javame/test/org/apache/thrift/protocol/TestRecursionDepth.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.transport.TMemoryBuffer; + +/** + * Standalone test for TProtocol recursion depth limiting in the JavaME library. + * + * Run via: + * javac -sourcepath src:test -d /tmp/javame_classes \ + * test/org/apache/thrift/protocol/TestRecursionDepth.java \ + * $(find src -name "*.java") + * java -cp /tmp/javame_classes org.apache.thrift.protocol.TestRecursionDepth + */ +public class TestRecursionDepth { + + private static TProtocol makeProtocol() throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(256); + return new TBinaryProtocol(buf); + } + + private static void assertEqual(Object expected, Object actual, String msg) { + if (!expected.equals(actual)) { + throw new AssertionError(msg + ": expected " + expected + " but got " + actual); + } + } + + private static void assertTrue(boolean condition, String msg) { + if (!condition) { + throw new AssertionError(msg); + } + } + + /** 64 increments must succeed. */ + static void testUnderLimitSucceeds() throws Exception { + TProtocol proto = makeProtocol(); + for (int i = 0; i < 64; i++) { + proto.incrementRecursionDepth(); + } + for (int i = 0; i < 64; i++) { + proto.decrementRecursionDepth(); + } + System.out.println("PASS testUnderLimitSucceeds"); + } + + /** The 65th increment must throw DEPTH_LIMIT. */ + static void testDepthLimitThrows() throws Exception { + TProtocol proto = makeProtocol(); + for (int i = 0; i < 64; i++) { + proto.incrementRecursionDepth(); + } + try { + proto.incrementRecursionDepth(); + throw new AssertionError("Expected TProtocolException with DEPTH_LIMIT"); + } catch (TProtocolException e) { + assertEqual(TProtocolException.DEPTH_LIMIT, e.getType(), + "exception type must be DEPTH_LIMIT"); + } + for (int i = 0; i < 64; i++) { + proto.decrementRecursionDepth(); + } + System.out.println("PASS testDepthLimitThrows"); + } + + /** Exception message must be non-empty. */ + static void testDepthLimitMessage() throws Exception { + TProtocol proto = makeProtocol(); + for (int i = 0; i < 64; i++) { + proto.incrementRecursionDepth(); + } + try { + proto.incrementRecursionDepth(); + throw new AssertionError("Expected TProtocolException"); + } catch (TProtocolException e) { + assertTrue(e.getMessage() != null && !e.getMessage().isEmpty(), + "exception message must be non-empty"); + } + for (int i = 0; i < 64; i++) { + proto.decrementRecursionDepth(); + } + System.out.println("PASS testDepthLimitMessage"); + } + + /** After a failed increment the counter must be unchanged. */ + static void testCounterRestoredAfterLimit() throws Exception { + TProtocol proto = makeProtocol(); + for (int i = 0; i < 63; i++) { + proto.incrementRecursionDepth(); + } + proto.incrementRecursionDepth(); // depth = 64 + try { + proto.incrementRecursionDepth(); // must throw, depth stays 64 + } catch (TProtocolException ignored) {} + // depth is still 64, one more increment should still throw + try { + proto.incrementRecursionDepth(); + throw new AssertionError("Expected second DEPTH_LIMIT throw"); + } catch (TProtocolException e) { + assertEqual(TProtocolException.DEPTH_LIMIT, e.getType(), "must still be DEPTH_LIMIT"); + } + for (int i = 0; i < 64; i++) { + proto.decrementRecursionDepth(); + } + System.out.println("PASS testCounterRestoredAfterLimit"); + } + + /** Increment/decrement pairs leave the counter at zero. */ + static void testIncrementDecrementBalance() throws Exception { + TProtocol proto = makeProtocol(); + for (int i = 0; i < 32; i++) { + proto.incrementRecursionDepth(); + } + for (int i = 0; i < 32; i++) { + proto.decrementRecursionDepth(); + } + // After balance, 64 more increments must succeed + for (int i = 0; i < 64; i++) { + proto.incrementRecursionDepth(); + } + for (int i = 0; i < 64; i++) { + proto.decrementRecursionDepth(); + } + System.out.println("PASS testIncrementDecrementBalance"); + } + + public static void main(String[] args) throws Exception { + testUnderLimitSucceeds(); + testDepthLimitThrows(); + testDepthLimitMessage(); + testCounterRestoredAfterLimit(); + testIncrementDecrementBalance(); + System.out.println("All tests passed."); + } +}