From ff3519a72cf49818cf2ecfe855d2d5e59d3e9fd3 Mon Sep 17 00:00:00 2001 From: Gwynne Raskind Date: Tue, 25 Oct 2022 21:20:31 -0500 Subject: [PATCH] Revise MySQLDataEncoder implementation to better handle various Codable conformances (no more corrupted data when superEncoder() is used) --- Sources/MySQLKit/MySQLDataEncoder.swift | 192 +++++++----------------- Tests/MySQLKitTests/MySQLKitTests.swift | 53 +++++++ 2 files changed, 108 insertions(+), 137 deletions(-) diff --git a/Sources/MySQLKit/MySQLDataEncoder.swift b/Sources/MySQLKit/MySQLDataEncoder.swift index eea67aa0..86381b92 100644 --- a/Sources/MySQLKit/MySQLDataEncoder.swift +++ b/Sources/MySQLKit/MySQLDataEncoder.swift @@ -11,13 +11,26 @@ public struct MySQLDataEncoder { if let custom = value as? MySQLDataConvertible, let data = custom.mysqlData { return data } else { - let encoder = _Encoder() - try value.encode(to: encoder) - if let data = encoder.data { - return data - } else { + let encoder = _Encoder(parent: self) + do { + try value.encode(to: encoder) + if let value = encoder.value { + return value + } else { + throw _Encoder.NonScalarValueSentinel() + } + } catch is _Encoder.NonScalarValueSentinel { var buffer = ByteBufferAllocator().buffer(capacity: 0) - try buffer.writeBytes(self.json.encode(_Wrapper(value))) +#if swift(<5.7) + struct _Wrapper: Encodable { + let encodable: Encodable + init(_ encodable: Encodable) { self.encodable = encodable } + func encode(to encoder: Encoder) throws { try self.encodable.encode(to: encoder) } + } + try buffer.writeBytes(self.json.encode(_Wrapper(value))) // Swift <5.7 will complain that "Encodable does not conform to Encodable" without the wrapper +#else + try buffer.writeBytes(self.json.encode(value)) +#endif return MySQLData( type: .string, format: .text, @@ -27,145 +40,50 @@ public struct MySQLDataEncoder { } } } - + private final class _Encoder: Encoder { - var codingPath: [CodingKey] { - return [] - } - - var userInfo: [CodingUserInfoKey : Any] { - return [:] - } - var data: MySQLData? - init() { + struct NonScalarValueSentinel: Error {} - } - - func container(keyedBy type: Key.Type) -> KeyedEncodingContainer where Key : CodingKey { - return .init(_KeyedValueEncoder(self)) - } - - func unkeyedContainer() -> UnkeyedEncodingContainer { - _UnkeyedEncoder(self) - } + var userInfo: [CodingUserInfoKey : Any] { [:] }; var codingPath: [CodingKey] { [] } + var parent: MySQLDataEncoder, value: MySQLData? + init(parent: MySQLDataEncoder) { self.parent = parent } + func container(keyedBy: K.Type) -> KeyedEncodingContainer { .init(_FailingKeyedContainer()) } + func unkeyedContainer() -> UnkeyedEncodingContainer { _TaintedEncoder() } func singleValueContainer() -> SingleValueEncodingContainer { - _SingleValueEncoder(self) - } - } - - struct DoJSON: Error {} - - private struct _UnkeyedEncoder: UnkeyedEncodingContainer { - var codingPath: [CodingKey] { - self.encoder.codingPath - } - var count: Int { - 0 - } - - let encoder: _Encoder - init(_ encoder: _Encoder) { - self.encoder = encoder - } - - - mutating func encodeNil() throws { } - - mutating func encode(_ value: T) throws - where T : Encodable - { } - - mutating func nestedContainer( - keyedBy keyType: NestedKey.Type - ) -> KeyedEncodingContainer - where NestedKey : CodingKey - { - self.encoder.container(keyedBy: NestedKey.self) - } - - mutating func nestedUnkeyedContainer() -> UnkeyedEncodingContainer { - self.encoder.unkeyedContainer() - } - - mutating func superEncoder() -> Encoder { - self.encoder + precondition(self.value == nil, "Requested multiple containers from the same encoder.") + return _SingleValueContainer(encoder: self) } - } - - private struct _KeyedValueEncoder: KeyedEncodingContainerProtocol where Key: CodingKey { - var codingPath: [CodingKey] { - return self.encoder.codingPath - } - - let encoder: _Encoder - init(_ encoder: _Encoder) { - self.encoder = encoder - } - - mutating func encodeNil(forKey key: Key) throws { } - - mutating func encode(_ value: T, forKey key: Key) throws - where T : Encodable - { } - mutating func nestedContainer( - keyedBy keyType: NestedKey.Type, - forKey key: Key - ) -> KeyedEncodingContainer where NestedKey : CodingKey { - self.encoder.container(keyedBy: NestedKey.self) + struct _SingleValueContainer: SingleValueEncodingContainer { + let encoder: _Encoder; var codingPath: [CodingKey] { self.encoder.codingPath } + func encodeNil() throws { self.encoder.value = .null } + func encode(_ value: T) throws { self.encoder.value = try self.encoder.parent.encode(value) } } - mutating func nestedUnkeyedContainer(forKey key: Key) -> UnkeyedEncodingContainer { - self.encoder.unkeyedContainer() + /// This pair of types is only necessary because we can't directly throw an error from various Encoder and + /// encoding container methods. We define duplicate types rather than the old implementation's use of a + /// no-action keyed container because it can save a significant amount of time otherwise spent uselessly calling + /// nested methods in some cases. + struct _TaintedEncoder: Encoder, UnkeyedEncodingContainer, SingleValueEncodingContainer { + var userInfo: [CodingUserInfoKey : Any] { [:] }; var codingPath: [CodingKey] { [] }; var count: Int { 0 } + func container(keyedBy: K.Type) -> KeyedEncodingContainer { .init(_FailingKeyedContainer()) } + func nestedContainer(keyedBy: K.Type) -> KeyedEncodingContainer { .init(_FailingKeyedContainer()) } + func unkeyedContainer() -> UnkeyedEncodingContainer { self } + func nestedUnkeyedContainer() -> UnkeyedEncodingContainer { self } + func singleValueContainer() -> SingleValueEncodingContainer { self } + func superEncoder() -> Encoder { self } + func encodeNil() throws { throw NonScalarValueSentinel() } + func encode(_: T) throws { throw NonScalarValueSentinel() } + } + struct _FailingKeyedContainer: KeyedEncodingContainerProtocol { + var codingPath: [CodingKey] { [] } + func encodeNil(forKey: K) throws { throw NonScalarValueSentinel() } + func encode(_: T, forKey: K) throws { throw NonScalarValueSentinel() } + func nestedContainer(keyedBy: NK.Type, forKey: K) -> KeyedEncodingContainer { .init(_FailingKeyedContainer()) } + func nestedUnkeyedContainer(forKey: K) -> UnkeyedEncodingContainer { _TaintedEncoder() } + func superEncoder() -> Encoder { _TaintedEncoder() } + func superEncoder(forKey: K) -> Encoder { _TaintedEncoder() } } - - mutating func superEncoder() -> Encoder { - self.encoder - } - - mutating func superEncoder(forKey key: Key) -> Encoder { - self.encoder - } - } - - - private struct _SingleValueEncoder: SingleValueEncodingContainer { - var codingPath: [CodingKey] { - return self.encoder.codingPath - } - - let encoder: _Encoder - init(_ encoder: _Encoder) { - self.encoder = encoder - } - - mutating func encodeNil() throws { - self.encoder.data = MySQLData.null - } - - mutating func encode(_ value: T) throws where T : Encodable { - if let convertible = value as? MySQLDataConvertible { - guard let data = convertible.mysqlData else { - throw EncodingError.invalidValue(convertible, EncodingError.Context( - codingPath: self.codingPath, - debugDescription: "Could not encode \(T.self) to MySQL data: \(value)" - )) - } - self.encoder.data = data - } else { - try value.encode(to: self.encoder) - } - } - } -} - -struct _Wrapper: Encodable { - let encodable: Encodable - init(_ encodable: Encodable) { - self.encodable = encodable - } - func encode(to encoder: Encoder) throws { - try self.encodable.encode(to: encoder) } } diff --git a/Tests/MySQLKitTests/MySQLKitTests.swift b/Tests/MySQLKitTests/MySQLKitTests.swift index 59217670..213dec7f 100644 --- a/Tests/MySQLKitTests/MySQLKitTests.swift +++ b/Tests/MySQLKitTests/MySQLKitTests.swift @@ -48,6 +48,59 @@ class MySQLKitTests: XCTestCase { XCTAssertEqual(rows, [foo]) } + /// Tests dealing with encoding of values whose `encode(to:)` implementation calls one of the `superEncoder()` + /// methods (most notably the implementation of `Codable` for Fluent's `Fields`, which we can't directly test + /// at this layer). + func testValuesThatUseSuperEncoder() throws { + struct UnusualType: Codable { + var prop1: String, prop2: [Bool], prop3: [[Bool]] + + // This is intentionally contrived - Fluent's implementation does Codable this roundabout way as a + // workaround for the interaction of property wrappers with optional properties; it serves no purpose + // here other than to demonstrate that the encoder supports it. + private enum CodingKeys: String, CodingKey { case prop1, prop2, prop3 } + init(prop1: String, prop2: [Bool], prop3: [[Bool]]) { (self.prop1, self.prop2, self.prop3) = (prop1, prop2, prop3) } + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.prop1 = try .init(from: container.superDecoder(forKey: .prop1)) + var acontainer = try container.nestedUnkeyedContainer(forKey: .prop2), ongoing: [Bool] = [] + while !acontainer.isAtEnd { ongoing.append(try Bool.init(from: acontainer.superDecoder())) } + self.prop2 = ongoing + var bcontainer = try container.nestedUnkeyedContainer(forKey: .prop3), bongoing: [[Bool]] = [] + while !bcontainer.isAtEnd { + var ccontainer = try bcontainer.nestedUnkeyedContainer(), congoing: [Bool] = [] + while !ccontainer.isAtEnd { congoing.append(try Bool.init(from: ccontainer.superDecoder())) } + bongoing.append(congoing) + } + self.prop3 = bongoing + } + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try self.prop1.encode(to: container.superEncoder(forKey: .prop1)) + var acontainer = container.nestedUnkeyedContainer(forKey: .prop2) + for val in self.prop2 { try val.encode(to: acontainer.superEncoder()) } + var bcontainer = container.nestedUnkeyedContainer(forKey: .prop3) + for arr in self.prop3 { + var ccontainer = bcontainer.nestedUnkeyedContainer() + for val in arr { try val.encode(to: ccontainer.superEncoder()) } + } + } + } + + let instance = UnusualType(prop1: "hello", prop2: [true, false, false, true], prop3: [[true, true], [false], [true], []]) + let encoded1 = try MySQLDataEncoder().encode(instance) + let encoded2 = try MySQLDataEncoder().encode([instance, instance]) + + XCTAssertEqual(encoded1.type, .string) + XCTAssertEqual(encoded2.type, .string) + + let decoded1 = try MySQLDataDecoder().decode(UnusualType.self, from: encoded1) + let decoded2 = try MySQLDataDecoder().decode([UnusualType].self, from: encoded2) + + XCTAssertEqual(decoded1.prop3, instance.prop3) + XCTAssertEqual(decoded2.count, 2) + } + var sql: SQLDatabase { self.mysql.sql() }