Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion compiler/cpp/src/thrift/generate/t_ocaml_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst
indent_up();
indent(out) << "let " << str << " = new " << sname << " in" << '\n';
indent_up();
indent(out) << "iprot#increment_recursion_depth;" << '\n';
indent(out) << "(Fun.protect ~finally:(fun () -> iprot#decrement_recursion_depth) (fun () ->" << '\n';
indent_up();
indent(out) << "ignore(iprot#readStructBegin);" << '\n';

// Loop over reading in fields
Expand Down Expand Up @@ -803,7 +806,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst
indent_down();
indent(out) << "with Break -> ());" << '\n';

indent(out) << "iprot#readStructEnd;" << '\n';
indent(out) << "iprot#readStructEnd" << '\n';
indent_down();
indent(out) << "));" << '\n';

indent(out) << str << '\n' << '\n';
indent_down();
Expand All @@ -819,6 +824,9 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst

indent(out) << "method write (oprot : Protocol.t) =" << '\n';
indent_up();
indent(out) << "oprot#increment_recursion_depth;" << '\n';
indent(out) << "Fun.protect ~finally:(fun () -> oprot#decrement_recursion_depth) (fun () ->" << '\n';
indent_up();
indent(out) << "oprot#writeStructBegin \"" << name << "\";" << '\n';

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand Down Expand Up @@ -878,6 +886,8 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst

// Write the struct map
out << indent() << "oprot#writeFieldStop;" << '\n' << indent() << "oprot#writeStructEnd" << '\n';
indent_down();
indent(out) << ")" << '\n';

indent_down();
}
Expand Down
33 changes: 21 additions & 12 deletions lib/ocaml/src/Thrift.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,30 @@ struct
| 4 -> ONEWAY
| _ -> raise Thrift_error

type exn_type =
| UNKNOWN
| INVALID_DATA
| NEGATIVE_SIZE
| SIZE_LIMIT
| BAD_VERSION
| NOT_IMPLEMENTED
| DEPTH_LIMIT

exception E of exn_type * string;;

class virtual t (trans: Transport.t) =
object (self)
val mutable trans_ = trans
val mutable recursion_depth_ = 0
method getTransport = trans_
method increment_recursion_depth =
recursion_depth_ <- recursion_depth_ + 1;
if recursion_depth_ > 64 then begin
recursion_depth_ <- recursion_depth_ - 1;
raise (E (DEPTH_LIMIT, "Maximum recursion depth exceeded"))
end
method decrement_recursion_depth =
recursion_depth_ <- recursion_depth_ - 1
(* writing methods *)
method virtual writeMessageBegin : string * message_type * int -> unit
method virtual writeMessageEnd : unit
Expand Down Expand Up @@ -246,25 +266,14 @@ struct
self#readListEnd)
| T_UTF8 -> ()
| T_UTF16 -> ()
| _ -> raise (Protocol.E (Protocol.INVALID_DATA, "Invalid data"))
| _ -> raise (E (INVALID_DATA, "Invalid data"))
end

class virtual factory =
object
method virtual getProtocol : Transport.t -> t
end

type exn_type =
| UNKNOWN
| INVALID_DATA
| NEGATIVE_SIZE
| SIZE_LIMIT
| BAD_VERSION
| NOT_IMPLEMENTED
| DEPTH_LIMIT

exception E of exn_type * string;;

end;;


Expand Down
146 changes: 146 additions & 0 deletions lib/ocaml/test/test_recursion_depth.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
(*
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*)

open Thrift

(* Minimal no-op transport for testing *)
class null_transport =
object
inherit Transport.t
method isOpen = true
method opn = ()
method close = ()
method read _buf _off _len = 0
method write _buf _off _len = ()
method flush = ()
end

(* Minimal concrete protocol subclass for testing *)
class test_protocol trans =
object (self)
inherit Protocol.t trans
method writeMessageBegin _ = ()
method writeMessageEnd = ()
method writeStructBegin _ = ()
method writeStructEnd = ()
method writeFieldBegin _ = ()
method writeFieldEnd = ()
method writeFieldStop = ()
method writeMapBegin _ = ()
method writeMapEnd = ()
method writeListBegin _ = ()
method writeListEnd = ()
method writeSetBegin _ = ()
method writeSetEnd = ()
method writeBool _ = ()
method writeByte _ = ()
method writeI16 _ = ()
method writeI32 _ = ()
method writeI64 _ = ()
method writeDouble _ = ()
method writeString _ = ()
method writeBinary _ = ()
method readMessageBegin = ("", Protocol.CALL, 0)
method readMessageEnd = ()
method readStructBegin = ""
method readStructEnd = ()
method readFieldBegin = ("", Protocol.T_STOP, 0)
method readFieldEnd = ()
method readMapBegin = (Protocol.T_STOP, Protocol.T_STOP, 0)
method readMapEnd = ()
method readListBegin = (Protocol.T_STOP, 0)
method readListEnd = ()
method readSetBegin = (Protocol.T_STOP, 0)
method readSetEnd = ()
method readBool = false
method readByte = 0
method readI16 = 0
method readI32 = 0l
method readI64 = 0L
method readDouble = 0.0
method readString = ""
method readBinary = ""
end

let passed = ref 0
let failed = ref 0

let check label cond =
if cond then begin
Printf.printf "PASS: %s\n%!" label;
incr passed
end else begin
Printf.printf "FAIL: %s\n%!" label;
incr failed
end

let () =
let trans = new null_transport in
let proto = new test_protocol trans in

(* Test 1: initial depth is 0 (64 increments succeed) *)
(try
for _ = 1 to 64 do
proto#increment_recursion_depth
done;
check "64 increments succeed" true
with _ ->
check "64 increments succeed" false);

(* Test 2: 65th increment raises DEPTH_LIMIT *)
(try
proto#increment_recursion_depth;
check "65th increment raises DEPTH_LIMIT" false
with Protocol.E (Protocol.DEPTH_LIMIT, _) ->
check "65th increment raises DEPTH_LIMIT" true);

(* Test 3: depth stays at 64 after rejected increment *)
(* decrement 64 times should bring us back to 0 *)
(try
for _ = 1 to 64 do
proto#decrement_recursion_depth
done;
(* now one more increment should succeed (depth was 64, now 0) *)
proto#increment_recursion_depth;
proto#decrement_recursion_depth;
check "depth correctly maintained at limit" true
with _ ->
check "depth correctly maintained at limit" false);

(* Test 4: decrement restores; fresh increments work *)
(try
proto#increment_recursion_depth;
proto#increment_recursion_depth;
proto#decrement_recursion_depth;
proto#decrement_recursion_depth;
check "increment/decrement balance" true
with _ ->
check "increment/decrement balance" false);

(* Test 5: DEPTH_LIMIT exception carries correct message *)
(let proto2 = new test_protocol trans in
for _ = 1 to 64 do proto2#increment_recursion_depth done;
try
proto2#increment_recursion_depth;
check "DEPTH_LIMIT message non-empty" false
with Protocol.E (Protocol.DEPTH_LIMIT, msg) ->
check "DEPTH_LIMIT message non-empty" (String.length msg > 0));

Printf.printf "\nResults: %d passed, %d failed\n" !passed !failed;
if !failed > 0 then exit 1
Loading