From efefe5d947aa773188f04f8ac9bcce5c19fdb150 Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Thu, 28 May 2026 01:38:33 +0200 Subject: [PATCH] THRIFT-6051: Harden OCaml protocol against deeply nested messages Client: ocaml - Add recursion_depth_ field and increment/decrement_recursion_depth methods to Protocol.t in Thrift.ml - Move type exn_type and exception E before class virtual t to fix a pre-existing forward-reference bug (Protocol.INVALID_DATA in skip) - Generator wraps struct read/write with Fun.protect to ensure decrement_recursion_depth always runs even on exception - Limit is 64 levels; raises Protocol.E(DEPTH_LIMIT,...) on excess Co-Authored-By: Claude Sonnet 4.6 --- .../src/thrift/generate/t_ocaml_generator.cc | 12 +- lib/ocaml/src/Thrift.ml | 33 ++-- lib/ocaml/test/test_recursion_depth.ml | 146 ++++++++++++++++++ 3 files changed, 178 insertions(+), 13 deletions(-) create mode 100644 lib/ocaml/test/test_recursion_depth.ml diff --git a/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc b/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc index 34dda402a1a..6c628477f09 100644 --- a/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc @@ -761,6 +761,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst indent_up(); indent(out) << "let " << str << " = new " << sname << " in" << '\n'; indent_up(); + indent(out) << "iprot#increment_recursion_depth;" << '\n'; + indent(out) << "(Fun.protect ~finally:(fun () -> iprot#decrement_recursion_depth) (fun () ->" << '\n'; + indent_up(); indent(out) << "ignore(iprot#readStructBegin);" << '\n'; // Loop over reading in fields @@ -803,7 +806,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst indent_down(); indent(out) << "with Break -> ());" << '\n'; - indent(out) << "iprot#readStructEnd;" << '\n'; + indent(out) << "iprot#readStructEnd" << '\n'; + indent_down(); + indent(out) << "));" << '\n'; indent(out) << str << '\n' << '\n'; indent_down(); @@ -819,6 +824,9 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst indent(out) << "method write (oprot : Protocol.t) =" << '\n'; indent_up(); + indent(out) << "oprot#increment_recursion_depth;" << '\n'; + indent(out) << "Fun.protect ~finally:(fun () -> oprot#decrement_recursion_depth) (fun () ->" << '\n'; + indent_up(); indent(out) << "oprot#writeStructBegin \"" << name << "\";" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { @@ -878,6 +886,8 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst // Write the struct map out << indent() << "oprot#writeFieldStop;" << '\n' << indent() << "oprot#writeStructEnd" << '\n'; + indent_down(); + indent(out) << ")" << '\n'; indent_down(); } diff --git a/lib/ocaml/src/Thrift.ml b/lib/ocaml/src/Thrift.ml index 063459ba0c1..10fe4564bd0 100644 --- a/lib/ocaml/src/Thrift.ml +++ b/lib/ocaml/src/Thrift.ml @@ -156,10 +156,30 @@ struct | 4 -> ONEWAY | _ -> raise Thrift_error + type exn_type = + | UNKNOWN + | INVALID_DATA + | NEGATIVE_SIZE + | SIZE_LIMIT + | BAD_VERSION + | NOT_IMPLEMENTED + | DEPTH_LIMIT + + exception E of exn_type * string;; + class virtual t (trans: Transport.t) = object (self) val mutable trans_ = trans + val mutable recursion_depth_ = 0 method getTransport = trans_ + method increment_recursion_depth = + recursion_depth_ <- recursion_depth_ + 1; + if recursion_depth_ > 64 then begin + recursion_depth_ <- recursion_depth_ - 1; + raise (E (DEPTH_LIMIT, "Maximum recursion depth exceeded")) + end + method decrement_recursion_depth = + recursion_depth_ <- recursion_depth_ - 1 (* writing methods *) method virtual writeMessageBegin : string * message_type * int -> unit method virtual writeMessageEnd : unit @@ -246,7 +266,7 @@ struct self#readListEnd) | T_UTF8 -> () | T_UTF16 -> () - | _ -> raise (Protocol.E (Protocol.INVALID_DATA, "Invalid data")) + | _ -> raise (E (INVALID_DATA, "Invalid data")) end class virtual factory = @@ -254,17 +274,6 @@ struct method virtual getProtocol : Transport.t -> t end - type exn_type = - | UNKNOWN - | INVALID_DATA - | NEGATIVE_SIZE - | SIZE_LIMIT - | BAD_VERSION - | NOT_IMPLEMENTED - | DEPTH_LIMIT - - exception E of exn_type * string;; - end;; diff --git a/lib/ocaml/test/test_recursion_depth.ml b/lib/ocaml/test/test_recursion_depth.ml new file mode 100644 index 00000000000..ace87939bd2 --- /dev/null +++ b/lib/ocaml/test/test_recursion_depth.ml @@ -0,0 +1,146 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +(* Minimal no-op transport for testing *) +class null_transport = +object + inherit Transport.t + method isOpen = true + method opn = () + method close = () + method read _buf _off _len = 0 + method write _buf _off _len = () + method flush = () +end + +(* Minimal concrete protocol subclass for testing *) +class test_protocol trans = +object (self) + inherit Protocol.t trans + method writeMessageBegin _ = () + method writeMessageEnd = () + method writeStructBegin _ = () + method writeStructEnd = () + method writeFieldBegin _ = () + method writeFieldEnd = () + method writeFieldStop = () + method writeMapBegin _ = () + method writeMapEnd = () + method writeListBegin _ = () + method writeListEnd = () + method writeSetBegin _ = () + method writeSetEnd = () + method writeBool _ = () + method writeByte _ = () + method writeI16 _ = () + method writeI32 _ = () + method writeI64 _ = () + method writeDouble _ = () + method writeString _ = () + method writeBinary _ = () + method readMessageBegin = ("", Protocol.CALL, 0) + method readMessageEnd = () + method readStructBegin = "" + method readStructEnd = () + method readFieldBegin = ("", Protocol.T_STOP, 0) + method readFieldEnd = () + method readMapBegin = (Protocol.T_STOP, Protocol.T_STOP, 0) + method readMapEnd = () + method readListBegin = (Protocol.T_STOP, 0) + method readListEnd = () + method readSetBegin = (Protocol.T_STOP, 0) + method readSetEnd = () + method readBool = false + method readByte = 0 + method readI16 = 0 + method readI32 = 0l + method readI64 = 0L + method readDouble = 0.0 + method readString = "" + method readBinary = "" +end + +let passed = ref 0 +let failed = ref 0 + +let check label cond = + if cond then begin + Printf.printf "PASS: %s\n%!" label; + incr passed + end else begin + Printf.printf "FAIL: %s\n%!" label; + incr failed + end + +let () = + let trans = new null_transport in + let proto = new test_protocol trans in + + (* Test 1: initial depth is 0 (64 increments succeed) *) + (try + for _ = 1 to 64 do + proto#increment_recursion_depth + done; + check "64 increments succeed" true + with _ -> + check "64 increments succeed" false); + + (* Test 2: 65th increment raises DEPTH_LIMIT *) + (try + proto#increment_recursion_depth; + check "65th increment raises DEPTH_LIMIT" false + with Protocol.E (Protocol.DEPTH_LIMIT, _) -> + check "65th increment raises DEPTH_LIMIT" true); + + (* Test 3: depth stays at 64 after rejected increment *) + (* decrement 64 times should bring us back to 0 *) + (try + for _ = 1 to 64 do + proto#decrement_recursion_depth + done; + (* now one more increment should succeed (depth was 64, now 0) *) + proto#increment_recursion_depth; + proto#decrement_recursion_depth; + check "depth correctly maintained at limit" true + with _ -> + check "depth correctly maintained at limit" false); + + (* Test 4: decrement restores; fresh increments work *) + (try + proto#increment_recursion_depth; + proto#increment_recursion_depth; + proto#decrement_recursion_depth; + proto#decrement_recursion_depth; + check "increment/decrement balance" true + with _ -> + check "increment/decrement balance" false); + + (* Test 5: DEPTH_LIMIT exception carries correct message *) + (let proto2 = new test_protocol trans in + for _ = 1 to 64 do proto2#increment_recursion_depth done; + try + proto2#increment_recursion_depth; + check "DEPTH_LIMIT message non-empty" false + with Protocol.E (Protocol.DEPTH_LIMIT, msg) -> + check "DEPTH_LIMIT message non-empty" (String.length msg > 0)); + + Printf.printf "\nResults: %d passed, %d failed\n" !passed !failed; + if !failed > 0 then exit 1