diff --git a/compiler/cpp/src/thrift/generate/t_swift_generator.cc b/compiler/cpp/src/thrift/generate/t_swift_generator.cc index 56411a15d1b..9a1139db57c 100644 --- a/compiler/cpp/src/thrift/generate/t_swift_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_swift_generator.cc @@ -1044,6 +1044,8 @@ void t_swift_generator::generate_swift_union_reader(ostream& out, t_struct* tstr indent(out) << "public static func read(from proto: TProtocol) throws -> " << tstruct->get_name(); block_open(out); + indent(out) << "try proto.incrementRecursionDepth()" << '\n'; + indent(out) << "defer { proto.decrementRecursionDepth() }" << '\n'; indent(out) << "_ = try proto.readStructBegin()" << '\n'; indent(out) << "var ret: " << tstruct->get_name() << "?"; @@ -1139,6 +1141,8 @@ void t_swift_generator::generate_swift_struct_reader(ostream& out, << tstruct->get_name(); block_open(out); + indent(out) << "try proto.incrementRecursionDepth()" << '\n'; + indent(out) << "defer { proto.decrementRecursionDepth() }" << '\n'; indent(out) << "_ = try proto.readStructBegin()" << '\n'; const vector& fields = tstruct->get_members(); diff --git a/lib/swift/Sources/TBinaryProtocol.swift b/lib/swift/Sources/TBinaryProtocol.swift index 766027e729a..318e2dd8f7b 100644 --- a/lib/swift/Sources/TBinaryProtocol.swift +++ b/lib/swift/Sources/TBinaryProtocol.swift @@ -26,7 +26,8 @@ public struct TBinaryProtocolVersion { public class TBinaryProtocol: TProtocol { public var messageSizeLimit: UInt32 = 0 - + public var recursionDepth: Int = 0 + public var transport: TTransport // class level properties for setting global config (useful for server in lieu of Factory design) diff --git a/lib/swift/Sources/TCompactProtocol.swift b/lib/swift/Sources/TCompactProtocol.swift index 9400faf2070..e2d13765459 100644 --- a/lib/swift/Sources/TCompactProtocol.swift +++ b/lib/swift/Sources/TCompactProtocol.swift @@ -50,6 +50,7 @@ public class TCompactProtocol: TProtocol { static let maxVarintBytes = 10 // ceil(64/7); matches protobuf wire format public var transport: TTransport + public var recursionDepth: Int = 0 var lastField: [UInt8] = [] var lastFieldId: UInt8 = 0 diff --git a/lib/swift/Sources/TJSONProtocol.swift b/lib/swift/Sources/TJSONProtocol.swift index b1e41c74559..a9f960b9f7d 100644 --- a/lib/swift/Sources/TJSONProtocol.swift +++ b/lib/swift/Sources/TJSONProtocol.swift @@ -29,6 +29,7 @@ public class TJSONProtocol: TProtocol { static let Version: Int = 1 public var transport: TTransport + public var recursionDepth: Int = 0 // Temporary buffer used by several methods private var tempBuffer: [UInt8] = [0,0,0,0] diff --git a/lib/swift/Sources/TProtocol.swift b/lib/swift/Sources/TProtocol.swift index e3a7a0ec408..725d22333d1 100644 --- a/lib/swift/Sources/TProtocol.swift +++ b/lib/swift/Sources/TProtocol.swift @@ -43,8 +43,9 @@ public enum TType: Int32 { case uuid = 16 } -public protocol TProtocol { +public protocol TProtocol: AnyObject { var transport: TTransport { get set } + var recursionDepth: Int { get set } init(on transport: TTransport) // Reading Methods @@ -132,6 +133,18 @@ public extension TProtocol { try writeMessageEnd() } + func incrementRecursionDepth() throws { + recursionDepth += 1 + if recursionDepth > 64 { + recursionDepth -= 1 + throw TProtocolError(error: .depthLimit, message: "Maximum recursion depth exceeded") + } + } + + func decrementRecursionDepth() { + recursionDepth -= 1 + } + func skip(type: TType) throws { try skip(type: type, depth: 0) } diff --git a/lib/swift/Sources/TProtocolDecorator.swift b/lib/swift/Sources/TProtocolDecorator.swift index e831f27a8a0..bb77a9ca496 100644 --- a/lib/swift/Sources/TProtocolDecorator.swift +++ b/lib/swift/Sources/TProtocolDecorator.swift @@ -24,6 +24,11 @@ class TProtocolDecorator: TProtocol { private let proto: TProtocol var transport: TTransport + var recursionDepth: Int { + get { proto.recursionDepth } + set { proto.recursionDepth = newValue } + } + init(proto: TProtocol) { self.proto = proto self.transport = proto.transport diff --git a/lib/swift/Sources/TStruct.swift b/lib/swift/Sources/TStruct.swift index d0a1a4bcd21..e603b6837fa 100644 --- a/lib/swift/Sources/TStruct.swift +++ b/lib/swift/Sources/TStruct.swift @@ -33,6 +33,8 @@ public extension TStruct { static var thriftType: TType { return .struct } func write(to proto: TProtocol) throws { + try proto.incrementRecursionDepth() + defer { proto.decrementRecursionDepth() } // Write struct name first try proto.writeStructBegin(name: Self.structName) diff --git a/lib/swift/Sources/TWrappedProtocol.swift b/lib/swift/Sources/TWrappedProtocol.swift index 8e47bd58e88..6a47232c0b1 100644 --- a/lib/swift/Sources/TWrappedProtocol.swift +++ b/lib/swift/Sources/TWrappedProtocol.swift @@ -35,6 +35,11 @@ open class TWrappedProtocol : TProtocol { } } + public var recursionDepth: Int { + get { concreteProtocol.recursionDepth } + set { concreteProtocol.recursionDepth = newValue } + } + public required init(on transport: TTransport) { self.concreteProtocol = Protocol(on: transport) } diff --git a/lib/swift/Tests/LinuxMain.swift b/lib/swift/Tests/LinuxMain.swift index a1c3f92105f..83b921152ce 100644 --- a/lib/swift/Tests/LinuxMain.swift +++ b/lib/swift/Tests/LinuxMain.swift @@ -24,4 +24,5 @@ XCTMain([ testCase(ThriftTests.allTests), testCase(TBinaryProtocolTests.allTests), testCase(TCompactProtocolTests.allTests), + testCase(TRecursionDepthTests.allTests), ]) diff --git a/lib/swift/Tests/ThriftTests/TRecursionDepthTests.swift b/lib/swift/Tests/ThriftTests/TRecursionDepthTests.swift new file mode 100644 index 00000000000..adb39a44939 --- /dev/null +++ b/lib/swift/Tests/ThriftTests/TRecursionDepthTests.swift @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import XCTest +@testable import Thrift + +class TRecursionDepthTests: XCTestCase { + + var transport: TMemoryBufferTransport! + var proto: TBinaryProtocol! + + override func setUp() { + super.setUp() + transport = TMemoryBufferTransport(flushHandler: { $0.reset(readBuffer: $1) }) + proto = TBinaryProtocol(on: transport) + } + + func testInitialDepthIsZero() { + XCTAssertEqual(proto.recursionDepth, 0) + } + + func testIncrementAllowsUpToLimit() { + XCTAssertNoThrow(try { + for _ in 0..<64 { + try self.proto.incrementRecursionDepth() + } + }()) + XCTAssertEqual(proto.recursionDepth, 64) + } + + func testIncrementThrowsAtLimitPlusOne() { + for _ in 0..<64 { try? proto.incrementRecursionDepth() } + XCTAssertThrowsError(try proto.incrementRecursionDepth()) { error in + guard let err = error as? TProtocolError else { + XCTFail("Expected TProtocolError, got \(error)") + return + } + if case .depthLimit = err.error { } else { + XCTFail("Expected .depthLimit, got \(err.error)") + } + } + } + + func testDepthStaysAtLimitAfterFailedIncrement() { + for _ in 0..<64 { try? proto.incrementRecursionDepth() } + try? proto.incrementRecursionDepth() + XCTAssertEqual(proto.recursionDepth, 64) + } + + func testDecrementRestoresCapacity() { + for _ in 0..<64 { try? proto.incrementRecursionDepth() } + proto.decrementRecursionDepth() + XCTAssertEqual(proto.recursionDepth, 63) + XCTAssertNoThrow(try proto.incrementRecursionDepth()) + } + + static var allTests = [ + ("testInitialDepthIsZero", testInitialDepthIsZero), + ("testIncrementAllowsUpToLimit", testIncrementAllowsUpToLimit), + ("testIncrementThrowsAtLimitPlusOne", testIncrementThrowsAtLimitPlusOne), + ("testDepthStaysAtLimitAfterFailedIncrement", testDepthStaysAtLimitAfterFailedIncrement), + ("testDecrementRestoresCapacity", testDecrementRestoresCapacity), + ] +}