diff --git a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift index 1b4bdb2e..804fd3aa 100644 --- a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift +++ b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift @@ -257,31 +257,52 @@ extension FileTranslator { default: return .infer(.primitive) } } - let repetitionKind: MultipartPartInfo.RepetitionKind - let candidateSource: MultipartPartInfo.ContentTypeSource - switch try schema.dereferenced(in: components) { - case .null, .not: return nil - case .boolean, .number, .integer: - repetitionKind = .single - candidateSource = .infer(.primitive) - case .string(_, let context): - repetitionKind = .single - candidateSource = try inferStringContent(context) - case .object, .all, .one, .any, .fragment: - repetitionKind = .single - candidateSource = .infer(.complex) - case .array(_, let context): - repetitionKind = .array - if let items = context.items { - switch items { - case .null, .not: return nil - case .boolean, .number, .integer: candidateSource = .infer(.primitive) - case .string(_, let context): candidateSource = try inferStringContent(context) - case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex) - } - } else { + func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? { + // If all schemas are primitive, the allOf/anyOf/oneOf is also primitive. + // These cannot be binary, so only primitive vs complex. + for schema in schemas { + guard let (_, kind) = try inferSchema(schema) else { return nil } + guard case .infer(.primitive) = kind else { return kind } + } + return .infer(.primitive) + } + func inferSchema(_ schema: DereferencedJSONSchema) throws -> ( + MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource + )? { + let repetitionKind: MultipartPartInfo.RepetitionKind + let candidateSource: MultipartPartInfo.ContentTypeSource + switch schema { + case .null, .not: return nil + case .boolean, .number, .integer: + repetitionKind = .single + candidateSource = .infer(.primitive) + case .string(_, let context): + repetitionKind = .single + candidateSource = try inferStringContent(context) + case .object, .fragment: + repetitionKind = .single candidateSource = .infer(.complex) + case .all(of: let schemas, _), .one(of: let schemas, _), .any(of: let schemas, _): + repetitionKind = .single + guard let value = try inferAllOfAnyOfOneOf(schemas) else { return nil } + candidateSource = value + case .array(_, let context): + repetitionKind = .array + if let items = context.items { + switch items { + case .null, .not: return nil + case .boolean, .number, .integer: candidateSource = .infer(.primitive) + case .string(_, let context): candidateSource = try inferStringContent(context) + case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex) + } + } else { + candidateSource = .infer(.complex) + } } + return (repetitionKind, candidateSource) + } + guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else { + return nil } let finalContentTypeSource: MultipartPartInfo.ContentTypeSource if let encoding, let contentType = encoding.contentType { @@ -301,9 +322,23 @@ extension FileTranslator { let resolvedSchema: JSONSchema if isOptional { resolvedSchema = baseSchema.optionalSchemaObject() } else { resolvedSchema = baseSchema } return (info, resolvedSchema) + } else if repetitionKind == .array { + let isOptional = try typeMatcher.isOptional(schema, components: components) + guard case .array(_, let context) = schema.value else { + preconditionFailure("Array repetition should always use an array schema.") + } + let elementSchema: JSONSchema = context.items ?? .fragment + let resolvedSchema: JSONSchema + if isOptional { + resolvedSchema = elementSchema.optionalSchemaObject() + } else { + resolvedSchema = elementSchema + } + return (info, resolvedSchema) } return (info, schema) } + /// Parses the names of component schemas used by multipart request and response bodies. /// /// The result is used to inform how a schema is generated. diff --git a/Tests/OpenAPIGeneratorCoreTests/Translator/Multipart/Test_MultipartContentInspector.swift b/Tests/OpenAPIGeneratorCoreTests/Translator/Multipart/Test_MultipartContentInspector.swift new file mode 100644 index 00000000..c388a550 --- /dev/null +++ b/Tests/OpenAPIGeneratorCoreTests/Translator/Multipart/Test_MultipartContentInspector.swift @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftOpenAPIGenerator open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import XCTest +import OpenAPIKit +@testable import _OpenAPIGeneratorCore + +class Test_MultipartContentInspector: Test_Core { + func testSerializationStrategy() throws { + let translator = makeTypesTranslator() + func _test( + schemaIn: JSONSchema, + encoding: OpenAPI.Content.Encoding? = nil, + source: MultipartPartInfo.ContentTypeSource, + repetition: MultipartPartInfo.RepetitionKind, + schemaOut: JSONSchema, + file: StaticString = #file, + line: UInt = #line + ) throws { + let (info, actualSchemaOut) = try XCTUnwrap( + translator.parseMultipartPartInfo(schema: schemaIn, encoding: encoding, foundIn: "") + ) + XCTAssertEqual(info.repetition, repetition, file: file, line: line) + XCTAssertEqual(info.contentTypeSource, source, file: file, line: line) + XCTAssertEqual(actualSchemaOut, schemaOut, file: file, line: line) + } + try _test(schemaIn: .object, source: .infer(.complex), repetition: .single, schemaOut: .object) + try _test(schemaIn: .array(items: .object), source: .infer(.complex), repetition: .array, schemaOut: .object) + try _test( + schemaIn: .string, + source: .infer(.primitive), + repetition: .single, + schemaOut: .string(contentEncoding: .binary) + ) + try _test( + schemaIn: .integer, + source: .infer(.primitive), + repetition: .single, + schemaOut: .string(contentEncoding: .binary) + ) + try _test( + schemaIn: .boolean, + source: .infer(.primitive), + repetition: .single, + schemaOut: .string(contentEncoding: .binary) + ) + try _test( + schemaIn: .string(allowedValues: ["foo"]), + source: .infer(.primitive), + repetition: .single, + schemaOut: .string(contentEncoding: .binary) + ) + try _test( + schemaIn: .array(items: .string), + source: .infer(.primitive), + repetition: .array, + schemaOut: .string(contentEncoding: .binary) + ) + try _test( + schemaIn: .any(of: .string, .string(allowedValues: ["foo"])), + source: .infer(.primitive), + repetition: .single, + schemaOut: .string(contentEncoding: .binary) + ) + } +}