Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions Sources/Decoder/CodableCBORDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Foundation
final public class CodableCBORDecoder {
public var useStringKeys: Bool = false
public var dateStrategy: DateStrategy = .taggedAsEpochTimestamp
public var maximumDepth: Int = .max

struct _Options {
let useStringKeys: Bool
Expand All @@ -29,7 +30,7 @@ final public class CodableCBORDecoder {
}

var options: _Options {
return _Options(useStringKeys: self.useStringKeys, dateStrategy: self.dateStrategy)
return _Options(useStringKeys: self.useStringKeys, dateStrategy: self.dateStrategy, maximumDepth: self.maximumDepth)
}

public init() {}
Expand Down Expand Up @@ -66,6 +67,7 @@ final public class CodableCBORDecoder {
func setOptions(_ newOptions: _Options) {
self.useStringKeys = newOptions.useStringKeys
self.dateStrategy = newOptions.dateStrategy
self.maximumDepth = newOptions.maximumDepth
}
}

Expand All @@ -78,34 +80,55 @@ final class _CBORDecoder {
fileprivate var data: ArraySlice<UInt8>

let options: CodableCBORDecoder._Options
var currentDepth: Int

init(data: ArraySlice<UInt8>, options: CodableCBORDecoder._Options) {
init(data: ArraySlice<UInt8>, options: CodableCBORDecoder._Options, currentDepth: Int = 0) {
self.data = data
self.options = options
self.currentDepth = currentDepth
}
}

extension _CBORDecoder: Decoder {
func container<Key: CodingKey>(keyedBy type: Key.Type) throws -> KeyedDecodingContainer<Key> {
guard self.currentDepth < self.options.maximumDepth else {
let context = DecodingError.Context(
codingPath: self.codingPath,
debugDescription: "Maximum decoding depth of \(self.options.maximumDepth) exceeded"
)
throw DecodingError.dataCorrupted(context)
}

try ensureMap(self.data.first, keyType: Key.self)

let container = KeyedContainer<Key>(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options)
let container = KeyedContainer<Key>(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options, currentDepth: self.currentDepth)
self.container = container

return KeyedDecodingContainer(container)
}

func unkeyedContainer() throws -> UnkeyedDecodingContainer {
guard self.currentDepth < self.options.maximumDepth else {
let context = DecodingError.Context(
codingPath: self.codingPath,
debugDescription: "Maximum decoding depth of \(self.options.maximumDepth) exceeded"
)
throw DecodingError.dataCorrupted(context)
}

try ensureArray(self.data.first)

let container = UnkeyedContainer(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options)
// Check if this is a byte string (0x40-0x5f) being decoded as an array
let isByteString = (self.data.first ?? 0) >= 0x40 && (self.data.first ?? 0) <= 0x5f

let container = UnkeyedContainer(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options, currentDepth: self.currentDepth, isByteString: isByteString)
self.container = container

return container
}

func singleValueContainer() throws -> SingleValueDecodingContainer {
let container = SingleValueContainer(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options)
let container = SingleValueContainer(data: self.data, codingPath: self.codingPath, userInfo: self.userInfo, options: self.options, currentDepth: self.currentDepth)
self.container = container

return container
Expand Down Expand Up @@ -134,8 +157,9 @@ extension _CBORDecoder: Decoder {

func ensureArray(_ initialByte: UInt8?) throws {
switch initialByte {
case .some(0x80...0x9f):
// all good, continue
case .some(0x80...0x9f), .some(0x40...0x5f):
// all good, continue (arrays 0x80-0x9f and byte strings 0x40-0x5f)
// Byte strings can be decoded as arrays of UInt8
return
case nil:
let context = DecodingError.Context(
Expand Down
22 changes: 13 additions & 9 deletions Sources/Decoder/KeyedDecodingContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ extension _CBORDecoder {
var codingPath: [CodingKey]
var userInfo: [CodingUserInfoKey: Any]
let options: CodableCBORDecoder._Options
let currentDepth: Int

init(data: ArraySlice<UInt8>, codingPath: [CodingKey], userInfo: [CodingUserInfoKey : Any], options: CodableCBORDecoder._Options) {
init(data: ArraySlice<UInt8>, codingPath: [CodingKey], userInfo: [CodingUserInfoKey : Any], options: CodableCBORDecoder._Options, currentDepth: Int = 0) {
self.codingPath = codingPath
self.userInfo = userInfo
self.data = data
self.index = self.data.startIndex
self.options = options
self.currentDepth = currentDepth
}

func checkCanDecodeValue(forKey key: Key) throws {
Expand All @@ -41,7 +43,7 @@ extension _CBORDecoder {

var nestedContainers: [AnyCodingKey: CBORDecodingContainer] = [:]

let unkeyedContainer = UnkeyedContainer(data: self.data.suffix(from: self.index), codingPath: self.codingPath, userInfo: self.userInfo, options: self.options)
let unkeyedContainer = UnkeyedContainer(data: self.data.suffix(from: self.index), codingPath: self.codingPath, userInfo: self.userInfo, options: self.options, currentDepth: self.currentDepth)
unkeyedContainer.count = count * 2

var iterator = unkeyedContainer.nestedContainers.makeIterator()
Expand Down Expand Up @@ -94,7 +96,7 @@ extension _CBORDecoder {
// each key-value pair in the map.
let nextIndex = self.data.startIndex.advanced(by: 1)
let remainingData = self.data.suffix(from: nextIndex)
count = try? CBORDecoder(input: remainingData.map { $0 }).readPairsUntilBreak().keys.count
count = try? CBORDecoder(input: remainingData.map { $0 }, options: self.options.toCBOROptions()).readPairsUntilBreak().keys.count
default:
let context = DecodingError.Context(
codingPath: self.codingPath,
Expand Down Expand Up @@ -145,9 +147,10 @@ extension _CBORDecoder.KeyedContainer: KeyedDecodingContainerProtocol {
try checkCanDecodeValue(forKey: key)

let container = try self.nestedContainers()[anyCodingKeyForKey(key)]!
let decoder = CodableCBORDecoder()
decoder.setOptions(self.options)
return try decoder.decode(T.self, from: container.data)
let innerDecoder = _CBORDecoder(data: container.data, options: self.options, currentDepth: self.currentDepth + 1)
innerDecoder.codingPath = self.codingPath + [key]
innerDecoder.userInfo = self.userInfo
return try T(from: innerDecoder)
}

func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer {
Expand All @@ -170,17 +173,18 @@ extension _CBORDecoder.KeyedContainer: KeyedDecodingContainerProtocol {
data: anyCodingKeyedContainer.data,
codingPath: anyCodingKeyedContainer.codingPath,
userInfo: anyCodingKeyedContainer.userInfo,
options: anyCodingKeyedContainer.options
options: anyCodingKeyedContainer.options,
currentDepth: anyCodingKeyedContainer.currentDepth
)
return KeyedDecodingContainer(container)
}

func superDecoder() throws -> Decoder {
return _CBORDecoder(data: self.data, options: self.options)
return _CBORDecoder(data: self.data, options: self.options, currentDepth: self.currentDepth + 1)
}

func superDecoder(forKey key: Key) throws -> Decoder {
let decoder = _CBORDecoder(data: self.data, options: self.options)
let decoder = _CBORDecoder(data: self.data, options: self.options, currentDepth: self.currentDepth + 1)
decoder.codingPath = [key]

return decoder
Expand Down
39 changes: 22 additions & 17 deletions Sources/Decoder/SingleValueDecodingContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ extension _CBORDecoder {
var data: ArraySlice<UInt8>
var index: Data.Index
let options: CodableCBORDecoder._Options
let currentDepth: Int

init(data: ArraySlice<UInt8>, codingPath: [CodingKey], userInfo: [CodingUserInfoKey : Any], options: CodableCBORDecoder._Options) {
init(data: ArraySlice<UInt8>, codingPath: [CodingKey], userInfo: [CodingUserInfoKey : Any], options: CodableCBORDecoder._Options, currentDepth: Int = 0) {
self.codingPath = codingPath
self.userInfo = userInfo
self.data = data
self.index = self.data.startIndex
self.options = options
self.currentDepth = currentDepth
}

func checkCanDecode<T>(_ type: T.Type, format: UInt8) throws {
Expand All @@ -32,7 +34,7 @@ extension _CBORDecoder {

extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
func decodeNil() -> Bool {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
return false
}
switch cbor {
Expand All @@ -42,7 +44,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Bool.Type) throws -> Bool {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -55,7 +57,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: String.Type) throws -> String {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -68,22 +70,23 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Double.Type) throws -> Double {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
switch cbor {
case .double(let dbl): return dbl
case .float(let flt): return Double(flt)
case .half(let flt): return Double(flt)
case .date(let date): return date.timeIntervalSinceReferenceDate
default:
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.typeMismatch(Double.self, context)
}
}

func decode(_ type: Float.Type) throws -> Float {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -97,7 +100,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Int.Type) throws -> Int {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -111,7 +114,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Int8.Type) throws -> Int8 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -125,7 +128,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Int16.Type) throws -> Int16 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -139,7 +142,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Int32.Type) throws -> Int32 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -153,7 +156,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: Int64.Type) throws -> Int64 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -167,7 +170,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: UInt.Type) throws -> UInt {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -180,7 +183,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: UInt8.Type) throws -> UInt8 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -193,7 +196,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: UInt16.Type) throws -> UInt16 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -206,7 +209,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: UInt32.Type) throws -> UInt32 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -219,7 +222,7 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode(_ type: UInt64.Type) throws -> UInt64 {
guard let cbor = try? CBOR.decode(self.data.map { $0 }) else {
guard let cbor = try? CBOR.decode(self.data.map { $0 }, options: self.options.toCBOROptions()) else {
let context = DecodingError.Context(codingPath: self.codingPath, debugDescription: "Invalid format: \(self.data)")
throw DecodingError.dataCorrupted(context)
}
Expand All @@ -232,7 +235,9 @@ extension _CBORDecoder.SingleValueContainer: SingleValueDecodingContainer {
}

func decode<T: Decodable>(_ type: T.Type) throws -> T {
let decoder = _CBORDecoder(data: self.data, options: self.options)
let decoder = _CBORDecoder(data: self.data, options: self.options, currentDepth: self.currentDepth + 1)
decoder.codingPath = self.codingPath
decoder.userInfo = self.userInfo
let value = try T(from: decoder)
if let nextIndex = decoder.container?.index {
self.index = nextIndex
Expand Down
Loading