From 6148c13f2bf8c78f6d2c3de52da9c62a7c912164 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 19 Aug 2024 17:33:15 -0400 Subject: [PATCH] [Vertex AI] Make `uri` optional in `Citation` and add `title` field (#13520) --- FirebaseVertexAI/CHANGELOG.md | 4 ++ .../Sources/GenerateContentResponse.swift | 26 +++++++++++-- FirebaseVertexAI/Tests/Unit/ChatTests.swift | 4 +- .../Tests/Unit/GenerativeModelTests.swift | 37 +++++++++++++------ 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index 1e15a43755e..6c90c0f6825 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -1,3 +1,7 @@ +# Unreleased +- [fixed] Resolved a decoding error for citations without a `uri` and added + support for decoding `title` fields, which were previously ignored. (#13518) + # 10.29.0 - [feature] Added community support for watchOS. (#13215) diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift index 78273830915..631eb228575 100644 --- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -125,8 +125,11 @@ public struct Citation { /// The exclusive end of a sequence in a model response that derives from a cited source. public let endIndex: Int - /// A link to the cited source. - public let uri: String + /// A link to the cited source, if available. + public let uri: String? + + /// The title of the cited source, if available. + public let title: String? /// The license the cited source work is distributed under, if specified. public let license: String? @@ -303,6 +306,7 @@ extension Citation: Decodable { case startIndex case endIndex case uri + case title case license } @@ -310,8 +314,22 @@ extension Citation: Decodable { let container = try decoder.container(keyedBy: CodingKeys.self) startIndex = try container.decodeIfPresent(Int.self, forKey: .startIndex) ?? 0 endIndex = try container.decode(Int.self, forKey: .endIndex) - uri = try container.decode(String.self, forKey: .uri) - license = try container.decodeIfPresent(String.self, forKey: .license) + if let uri = try container.decodeIfPresent(String.self, forKey: .uri), !uri.isEmpty { + self.uri = uri + } else { + uri = nil + } + if let title = try container.decodeIfPresent(String.self, forKey: .title), !title.isEmpty { + self.title = title + } else { + title = nil + } + if let license = try container.decodeIfPresent(String.self, forKey: .license), + !license.isEmpty { + self.license = license + } else { + license = nil + } } } diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift index 48aca6786c2..389fcec1c5f 100644 --- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift +++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift @@ -39,9 +39,9 @@ final class ChatTests: XCTestCase { // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see // https://developer.apple.com/documentation/foundation/urlprotocol for details. - guard #unavailable(watchOS 2) else { + #if os(watchOS) throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.") - } + #endif // os(watchOS) MockURLProtocol.requestHandler = { request in let response = HTTPURLResponse( url: request.url!, diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index 10f0f84fedb..9acd27e187f 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -116,17 +116,20 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1") XCTAssertEqual(citationSource1.startIndex, 0) XCTAssertEqual(citationSource1.endIndex, 128) + XCTAssertNil(citationSource1.title) XCTAssertNil(citationSource1.license) let citationSource2 = try XCTUnwrap(citationMetadata.citationSources[1]) - XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2") + XCTAssertEqual(citationSource2.title, "some-citation-2") XCTAssertEqual(citationSource2.startIndex, 130) XCTAssertEqual(citationSource2.endIndex, 265) + XCTAssertNil(citationSource2.uri) XCTAssertNil(citationSource2.license) let citationSource3 = try XCTUnwrap(citationMetadata.citationSources[2]) XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3") XCTAssertEqual(citationSource3.startIndex, 272) XCTAssertEqual(citationSource3.endIndex, 431) XCTAssertEqual(citationSource3.license, "mit") + XCTAssertNil(citationSource3.title) } func testGenerateContent_success_quoteReply() async throws { @@ -951,13 +954,25 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(lastCandidate.finishReason, .stop) XCTAssertEqual(citations.count, 6) XCTAssertTrue(citations - .contains(where: { - $0.startIndex == 574 && $0.endIndex == 705 && !$0.uri.isEmpty && $0.license == "" - })) + .contains { + $0.startIndex == 0 && $0.endIndex == 128 + && $0.uri == "https://www.example.com/some-citation-1" && $0.title == nil + && $0.license == nil + }) XCTAssertTrue(citations - .contains(where: { - $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty && $0.license == "" - })) + .contains { + $0.startIndex == 130 && $0.endIndex == 265 && $0.uri == nil + && $0.title == "some-citation-2" && $0.license == nil + }) + XCTAssertTrue(citations + .contains { + $0.startIndex == 272 && $0.endIndex == 431 + && $0.uri == "https://www.example.com/some-citation-3" && $0.title == nil + && $0.license == "mit" + }) + XCTAssertFalse(citations.contains { $0.uri?.isEmpty ?? false }) + XCTAssertFalse(citations.contains { $0.title?.isEmpty ?? false }) + XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false }) } func testGenerateContentStream_appCheck_validToken() async throws { @@ -1283,9 +1298,9 @@ final class GenerativeModelTests: XCTestCase { )) { // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see // https://developer.apple.com/documentation/foundation/urlprotocol for details. - guard #unavailable(watchOS 2) else { + #if os(watchOS) throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.") - } + #endif // os(watchOS) return { request in // This is *not* an HTTPURLResponse let response = URLResponse( @@ -1309,9 +1324,9 @@ final class GenerativeModelTests: XCTestCase { )) { // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see // https://developer.apple.com/documentation/foundation/urlprotocol for details. - guard #unavailable(watchOS 2) else { + #if os(watchOS) throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.") - } + #endif // os(watchOS) let fileURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext)) return { request in let requestURL = try XCTUnwrap(request.url)