Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/cpp/src/thrift/generate/t_swift_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,8 @@ void t_swift_generator::generate_swift_union_reader(ostream& out, t_struct* tstr
indent(out) << "public static func read(from proto: TProtocol) throws -> "
<< tstruct->get_name();
block_open(out);
indent(out) << "try proto.incrementRecursionDepth()" << '\n';
indent(out) << "defer { proto.decrementRecursionDepth() }" << '\n';
indent(out) << "_ = try proto.readStructBegin()" << '\n';

indent(out) << "var ret: " << tstruct->get_name() << "?";
Expand Down Expand Up @@ -1139,6 +1141,8 @@ void t_swift_generator::generate_swift_struct_reader(ostream& out,
<< tstruct->get_name();

block_open(out);
indent(out) << "try proto.incrementRecursionDepth()" << '\n';
indent(out) << "defer { proto.decrementRecursionDepth() }" << '\n';
indent(out) << "_ = try proto.readStructBegin()" << '\n';

const vector<t_field*>& fields = tstruct->get_members();
Expand Down
3 changes: 2 additions & 1 deletion lib/swift/Sources/TBinaryProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public struct TBinaryProtocolVersion {

public class TBinaryProtocol: TProtocol {
public var messageSizeLimit: UInt32 = 0

public var recursionDepth: Int = 0

public var transport: TTransport

// class level properties for setting global config (useful for server in lieu of Factory design)
Expand Down
1 change: 1 addition & 0 deletions lib/swift/Sources/TCompactProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class TCompactProtocol: TProtocol {
static let maxVarintBytes = 10 // ceil(64/7); matches protobuf wire format

public var transport: TTransport
public var recursionDepth: Int = 0

var lastField: [UInt8] = []
var lastFieldId: UInt8 = 0
Expand Down
1 change: 1 addition & 0 deletions lib/swift/Sources/TJSONProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class TJSONProtocol: TProtocol {
static let Version: Int = 1

public var transport: TTransport
public var recursionDepth: Int = 0

// Temporary buffer used by several methods
private var tempBuffer: [UInt8] = [0,0,0,0]
Expand Down
15 changes: 14 additions & 1 deletion lib/swift/Sources/TProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ public enum TType: Int32 {
case uuid = 16
}

public protocol TProtocol {
public protocol TProtocol: AnyObject {
var transport: TTransport { get set }
var recursionDepth: Int { get set }
init(on transport: TTransport)
// Reading Methods

Expand Down Expand Up @@ -132,6 +133,18 @@ public extension TProtocol {
try writeMessageEnd()
}

func incrementRecursionDepth() throws {
recursionDepth += 1
if recursionDepth > 64 {
recursionDepth -= 1
throw TProtocolError(error: .depthLimit, message: "Maximum recursion depth exceeded")
}
}

func decrementRecursionDepth() {
recursionDepth -= 1
}

func skip(type: TType) throws {
try skip(type: type, depth: 0)
}
Expand Down
5 changes: 5 additions & 0 deletions lib/swift/Sources/TProtocolDecorator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class TProtocolDecorator: TProtocol {
private let proto: TProtocol
var transport: TTransport

var recursionDepth: Int {
get { proto.recursionDepth }
set { proto.recursionDepth = newValue }
}

init(proto: TProtocol) {
self.proto = proto
self.transport = proto.transport
Expand Down
2 changes: 2 additions & 0 deletions lib/swift/Sources/TStruct.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public extension TStruct {
static var thriftType: TType { return .struct }

func write(to proto: TProtocol) throws {
try proto.incrementRecursionDepth()
defer { proto.decrementRecursionDepth() }
// Write struct name first
try proto.writeStructBegin(name: Self.structName)

Expand Down
5 changes: 5 additions & 0 deletions lib/swift/Sources/TWrappedProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ open class TWrappedProtocol<Protocol: TProtocol> : TProtocol {
}
}

public var recursionDepth: Int {
get { concreteProtocol.recursionDepth }
set { concreteProtocol.recursionDepth = newValue }
}

public required init(on transport: TTransport) {
self.concreteProtocol = Protocol(on: transport)
}
Expand Down
1 change: 1 addition & 0 deletions lib/swift/Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ XCTMain([
testCase(ThriftTests.allTests),
testCase(TBinaryProtocolTests.allTests),
testCase(TCompactProtocolTests.allTests),
testCase(TRecursionDepthTests.allTests),
])
80 changes: 80 additions & 0 deletions lib/swift/Tests/ThriftTests/TRecursionDepthTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

import XCTest
@testable import Thrift

class TRecursionDepthTests: XCTestCase {

var transport: TMemoryBufferTransport!
var proto: TBinaryProtocol!

override func setUp() {
super.setUp()
transport = TMemoryBufferTransport(flushHandler: { $0.reset(readBuffer: $1) })
proto = TBinaryProtocol(on: transport)
}

func testInitialDepthIsZero() {
XCTAssertEqual(proto.recursionDepth, 0)
}

func testIncrementAllowsUpToLimit() {
XCTAssertNoThrow(try {
for _ in 0..<64 {
try self.proto.incrementRecursionDepth()
}
}())
XCTAssertEqual(proto.recursionDepth, 64)
}

func testIncrementThrowsAtLimitPlusOne() {
for _ in 0..<64 { try? proto.incrementRecursionDepth() }
XCTAssertThrowsError(try proto.incrementRecursionDepth()) { error in
guard let err = error as? TProtocolError else {
XCTFail("Expected TProtocolError, got \(error)")
return
}
if case .depthLimit = err.error { } else {
XCTFail("Expected .depthLimit, got \(err.error)")
}
}
}

func testDepthStaysAtLimitAfterFailedIncrement() {
for _ in 0..<64 { try? proto.incrementRecursionDepth() }
try? proto.incrementRecursionDepth()
XCTAssertEqual(proto.recursionDepth, 64)
}

func testDecrementRestoresCapacity() {
for _ in 0..<64 { try? proto.incrementRecursionDepth() }
proto.decrementRecursionDepth()
XCTAssertEqual(proto.recursionDepth, 63)
XCTAssertNoThrow(try proto.incrementRecursionDepth())
}

static var allTests = [
("testInitialDepthIsZero", testInitialDepthIsZero),
("testIncrementAllowsUpToLimit", testIncrementAllowsUpToLimit),
("testIncrementThrowsAtLimitPlusOne", testIncrementThrowsAtLimitPlusOne),
("testDepthStaysAtLimitAfterFailedIncrement", testDepthStaysAtLimitAfterFailedIncrement),
("testDecrementRestoresCapacity", testDecrementRestoresCapacity),
]
}
Loading