Skip to content

Commit

Permalink
[Vertex AI] Make uri optional in Citation and add title field (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Aug 19, 2024
1 parent 22099aa commit 6148c13
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 17 deletions.
4 changes: 4 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Unreleased
- [fixed] Resolved a decoding error for citations without a `uri` and added
support for decoding `title` fields, which were previously ignored. (#13518)

# 10.29.0
- [feature] Added community support for watchOS. (#13215)

Expand Down
26 changes: 22 additions & 4 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ public struct Citation {
/// The exclusive end of a sequence in a model response that derives from a cited source.
public let endIndex: Int

/// A link to the cited source.
public let uri: String
/// A link to the cited source, if available.
public let uri: String?

/// The title of the cited source, if available.
public let title: String?

/// The license the cited source work is distributed under, if specified.
public let license: String?
Expand Down Expand Up @@ -303,15 +306,30 @@ extension Citation: Decodable {
case startIndex
case endIndex
case uri
case title
case license
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
startIndex = try container.decodeIfPresent(Int.self, forKey: .startIndex) ?? 0
endIndex = try container.decode(Int.self, forKey: .endIndex)
uri = try container.decode(String.self, forKey: .uri)
license = try container.decodeIfPresent(String.self, forKey: .license)
if let uri = try container.decodeIfPresent(String.self, forKey: .uri), !uri.isEmpty {
self.uri = uri
} else {
uri = nil
}
if let title = try container.decodeIfPresent(String.self, forKey: .title), !title.isEmpty {
self.title = title
} else {
title = nil
}
if let license = try container.decodeIfPresent(String.self, forKey: .license),
!license.isEmpty {
self.license = license
} else {
license = nil
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ final class ChatTests: XCTestCase {

// Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
// https://developer.apple.com/documentation/foundation/urlprotocol for details.
guard #unavailable(watchOS 2) else {
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
}
#endif // os(watchOS)
MockURLProtocol.requestHandler = { request in
let response = HTTPURLResponse(
url: request.url!,
Expand Down
37 changes: 26 additions & 11 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,20 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
XCTAssertEqual(citationSource1.startIndex, 0)
XCTAssertEqual(citationSource1.endIndex, 128)
XCTAssertNil(citationSource1.title)
XCTAssertNil(citationSource1.license)
let citationSource2 = try XCTUnwrap(citationMetadata.citationSources[1])
XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2")
XCTAssertEqual(citationSource2.title, "some-citation-2")
XCTAssertEqual(citationSource2.startIndex, 130)
XCTAssertEqual(citationSource2.endIndex, 265)
XCTAssertNil(citationSource2.uri)
XCTAssertNil(citationSource2.license)
let citationSource3 = try XCTUnwrap(citationMetadata.citationSources[2])
XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
XCTAssertEqual(citationSource3.startIndex, 272)
XCTAssertEqual(citationSource3.endIndex, 431)
XCTAssertEqual(citationSource3.license, "mit")
XCTAssertNil(citationSource3.title)
}

func testGenerateContent_success_quoteReply() async throws {
Expand Down Expand Up @@ -951,13 +954,25 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(lastCandidate.finishReason, .stop)
XCTAssertEqual(citations.count, 6)
XCTAssertTrue(citations
.contains(where: {
$0.startIndex == 574 && $0.endIndex == 705 && !$0.uri.isEmpty && $0.license == ""
}))
.contains {
$0.startIndex == 0 && $0.endIndex == 128
&& $0.uri == "https://www.example.com/some-citation-1" && $0.title == nil
&& $0.license == nil
})
XCTAssertTrue(citations
.contains(where: {
$0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty && $0.license == ""
}))
.contains {
$0.startIndex == 130 && $0.endIndex == 265 && $0.uri == nil
&& $0.title == "some-citation-2" && $0.license == nil
})
XCTAssertTrue(citations
.contains {
$0.startIndex == 272 && $0.endIndex == 431
&& $0.uri == "https://www.example.com/some-citation-3" && $0.title == nil
&& $0.license == "mit"
})
XCTAssertFalse(citations.contains { $0.uri?.isEmpty ?? false })
XCTAssertFalse(citations.contains { $0.title?.isEmpty ?? false })
XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false })
}

func testGenerateContentStream_appCheck_validToken() async throws {
Expand Down Expand Up @@ -1283,9 +1298,9 @@ final class GenerativeModelTests: XCTestCase {
)) {
// Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
// https://developer.apple.com/documentation/foundation/urlprotocol for details.
guard #unavailable(watchOS 2) else {
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
}
#endif // os(watchOS)
return { request in
// This is *not* an HTTPURLResponse
let response = URLResponse(
Expand All @@ -1309,9 +1324,9 @@ final class GenerativeModelTests: XCTestCase {
)) {
// Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see
// https://developer.apple.com/documentation/foundation/urlprotocol for details.
guard #unavailable(watchOS 2) else {
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
}
#endif // os(watchOS)
let fileURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))
return { request in
let requestURL = try XCTUnwrap(request.url)
Expand Down

0 comments on commit 6148c13

Please sign in to comment.