Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion compiler/cpp/src/thrift/generate/t_javame_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,11 @@ void t_javame_generator::generate_java_struct_reader(ostream& out, t_struct* tst
vector<t_field*>::const_iterator f_iter;

// Declare stack tmp variables and read struct header
out << indent() << "TField field;" << '\n' << indent() << "iprot.readStructBegin();" << '\n';
out << indent() << "TField field;" << '\n';
indent(out) << "iprot.incrementRecursionDepth();" << '\n';
indent(out) << "try {" << '\n';
indent_up();
out << indent() << "iprot.readStructBegin();" << '\n';

// Loop over reading in fields
indent(out) << "while (true)" << '\n';
Expand Down Expand Up @@ -1358,6 +1362,13 @@ void t_javame_generator::generate_java_struct_reader(ostream& out, t_struct* tst
// performs various checks (e.g. check that all required fields are set)
indent(out) << "validate();" << '\n';

indent_down();
indent(out) << "} finally {" << '\n';
indent_up();
indent(out) << "iprot.decrementRecursionDepth();" << '\n';
indent_down();
indent(out) << "}" << '\n';

indent_down();
out << indent() << "}" << '\n' << '\n';
}
Expand Down Expand Up @@ -1400,6 +1411,9 @@ void t_javame_generator::generate_java_struct_writer(ostream& out, t_struct* tst
// performs various checks (e.g. check that all required fields are set)
indent(out) << "validate();" << '\n' << '\n';

indent(out) << "oprot.incrementRecursionDepth();" << '\n';
indent(out) << "try {" << '\n';
indent_up();
indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << '\n';

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand Down Expand Up @@ -1436,6 +1450,13 @@ void t_javame_generator::generate_java_struct_writer(ostream& out, t_struct* tst
out << indent() << "oprot.writeFieldStop();" << '\n' << indent() << "oprot.writeStructEnd();"
<< '\n';

indent_down();
indent(out) << "} finally {" << '\n';
indent_up();
indent(out) << "oprot.decrementRecursionDepth();" << '\n';
indent_down();
indent(out) << "}" << '\n';

indent_down();
out << indent() << "}" << '\n' << '\n';
}
Expand All @@ -1456,6 +1477,9 @@ void t_javame_generator::generate_java_struct_result_writer(ostream& out, t_stru
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;

indent(out) << "oprot.incrementRecursionDepth();" << '\n';
indent(out) << "try {" << '\n';
indent_up();
indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << '\n';

bool first = true;
Expand Down Expand Up @@ -1487,6 +1511,13 @@ void t_javame_generator::generate_java_struct_result_writer(ostream& out, t_stru
out << '\n' << indent() << "oprot.writeFieldStop();" << '\n' << indent()
<< "oprot.writeStructEnd();" << '\n';

indent_down();
indent(out) << "} finally {" << '\n';
indent_up();
indent(out) << "oprot.decrementRecursionDepth();" << '\n';
indent_down();
indent(out) << "}" << '\n';

indent_down();
out << indent() << "}" << '\n' << '\n';
}
Expand Down
23 changes: 23 additions & 0 deletions lib/javame/src/org/apache/thrift/protocol/TProtocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ private TProtocol() {}
*/
protected TTransport trans_;

/** Current recursion depth during struct serialization */
private int recursionDepth_ = 0;

private static final int DEFAULT_RECURSION_DEPTH = 64;

/**
* Constructor
*/
Expand All @@ -52,6 +57,24 @@ public TTransport getTransport() {
return trans_;
}

/**
* Increment recursion depth, checking against the limit.
*
* @throws TProtocolException with DEPTH_LIMIT if the limit is exceeded
*/
public void incrementRecursionDepth() throws TProtocolException {
if (recursionDepth_ >= DEFAULT_RECURSION_DEPTH) {
throw new TProtocolException(
TProtocolException.DEPTH_LIMIT, "Maximum recursion depth exceeded");
}
++recursionDepth_;
}

/** Decrement recursion depth. Must be called in a finally block. */
public void decrementRecursionDepth() {
--recursionDepth_;
}

/**
* Writing methods.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class TProtocolException extends TException {
public static final int SIZE_LIMIT = 3;
public static final int BAD_VERSION = 4;
public static final int NOT_IMPLEMENTED = 5;
public static final int DEPTH_LIMIT = 6;

protected int type_ = UNKNOWN;

Expand Down
152 changes: 152 additions & 0 deletions lib/javame/test/org/apache/thrift/protocol/TestRecursionDepth.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.thrift.protocol;

import org.apache.thrift.transport.TMemoryBuffer;

/**
* Standalone test for TProtocol recursion depth limiting in the JavaME library.
*
* Run via:
* javac -sourcepath src:test -d /tmp/javame_classes \
* test/org/apache/thrift/protocol/TestRecursionDepth.java \
* $(find src -name "*.java")
* java -cp /tmp/javame_classes org.apache.thrift.protocol.TestRecursionDepth
*/
public class TestRecursionDepth {

private static TProtocol makeProtocol() throws Exception {
TMemoryBuffer buf = new TMemoryBuffer(256);
return new TBinaryProtocol(buf);
}

private static void assertEqual(Object expected, Object actual, String msg) {
if (!expected.equals(actual)) {
throw new AssertionError(msg + ": expected " + expected + " but got " + actual);
}
}

private static void assertTrue(boolean condition, String msg) {
if (!condition) {
throw new AssertionError(msg);
}
}

/** 64 increments must succeed. */
static void testUnderLimitSucceeds() throws Exception {
TProtocol proto = makeProtocol();
for (int i = 0; i < 64; i++) {
proto.incrementRecursionDepth();
}
for (int i = 0; i < 64; i++) {
proto.decrementRecursionDepth();
}
System.out.println("PASS testUnderLimitSucceeds");
}

/** The 65th increment must throw DEPTH_LIMIT. */
static void testDepthLimitThrows() throws Exception {
TProtocol proto = makeProtocol();
for (int i = 0; i < 64; i++) {
proto.incrementRecursionDepth();
}
try {
proto.incrementRecursionDepth();
throw new AssertionError("Expected TProtocolException with DEPTH_LIMIT");
} catch (TProtocolException e) {
assertEqual(TProtocolException.DEPTH_LIMIT, e.getType(),
"exception type must be DEPTH_LIMIT");
}
for (int i = 0; i < 64; i++) {
proto.decrementRecursionDepth();
}
System.out.println("PASS testDepthLimitThrows");
}

/** Exception message must be non-empty. */
static void testDepthLimitMessage() throws Exception {
TProtocol proto = makeProtocol();
for (int i = 0; i < 64; i++) {
proto.incrementRecursionDepth();
}
try {
proto.incrementRecursionDepth();
throw new AssertionError("Expected TProtocolException");
} catch (TProtocolException e) {
assertTrue(e.getMessage() != null && !e.getMessage().isEmpty(),
"exception message must be non-empty");
}
for (int i = 0; i < 64; i++) {
proto.decrementRecursionDepth();
}
System.out.println("PASS testDepthLimitMessage");
}

/** After a failed increment the counter must be unchanged. */
static void testCounterRestoredAfterLimit() throws Exception {
TProtocol proto = makeProtocol();
for (int i = 0; i < 63; i++) {
proto.incrementRecursionDepth();
}
proto.incrementRecursionDepth(); // depth = 64
try {
proto.incrementRecursionDepth(); // must throw, depth stays 64
} catch (TProtocolException ignored) {}
// depth is still 64, one more increment should still throw
try {
proto.incrementRecursionDepth();
throw new AssertionError("Expected second DEPTH_LIMIT throw");
} catch (TProtocolException e) {
assertEqual(TProtocolException.DEPTH_LIMIT, e.getType(), "must still be DEPTH_LIMIT");
}
for (int i = 0; i < 64; i++) {
proto.decrementRecursionDepth();
}
System.out.println("PASS testCounterRestoredAfterLimit");
}

/** Increment/decrement pairs leave the counter at zero. */
static void testIncrementDecrementBalance() throws Exception {
TProtocol proto = makeProtocol();
for (int i = 0; i < 32; i++) {
proto.incrementRecursionDepth();
}
for (int i = 0; i < 32; i++) {
proto.decrementRecursionDepth();
}
// After balance, 64 more increments must succeed
for (int i = 0; i < 64; i++) {
proto.incrementRecursionDepth();
}
for (int i = 0; i < 64; i++) {
proto.decrementRecursionDepth();
}
System.out.println("PASS testIncrementDecrementBalance");
}

public static void main(String[] args) throws Exception {
testUnderLimitSucceeds();
testDepthLimitThrows();
testDepthLimitMessage();
testCounterRestoredAfterLimit();
testIncrementDecrementBalance();
System.out.println("All tests passed.");
}
}
Loading