-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fork
google-generative-ai
for Firebase
- Loading branch information
1 parent
7598e89
commit 0c5cd17
Showing
17 changed files
with
2,067 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
/// An object that represents a back-and-forth chat with a model, capturing the history and saving | ||
/// the context in memory between each message sent. | ||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
public class Chat { | ||
private let model: GenerativeModel | ||
|
||
/// Initializes a new chat representing a 1:1 conversation between model and user. | ||
init(model: GenerativeModel, history: [ModelContent]) { | ||
self.model = model | ||
self.history = history | ||
} | ||
|
||
/// The previous content from the chat that has been successfully sent and received from the | ||
/// model. This will be provided to the model for each message sent as context for the discussion. | ||
public var history: [ModelContent] | ||
|
||
/// See ``sendMessage(_:)-3ify5``. | ||
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws | ||
-> GenerateContentResponse { | ||
return try await sendMessage([ModelContent(parts: parts)]) | ||
} | ||
|
||
/// Sends a message using the existing history of this chat as context. If successful, the message | ||
/// and response will be added to the history. If unsuccessful, history will remain unchanged. | ||
/// - Parameter content: The new content to send as a single chat message. | ||
/// - Returns: The model's response if no error occurred. | ||
/// - Throws: A ``GenerateContentError`` if an error occurred. | ||
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws | ||
-> GenerateContentResponse { | ||
// Ensure that the new content has the role set. | ||
let newContent: [ModelContent] | ||
do { | ||
newContent = try content().map(populateContentRole(_:)) | ||
} catch let underlying { | ||
if let contentError = underlying as? ImageConversionError { | ||
throw GenerateContentError.promptImageContentError(underlying: contentError) | ||
} else { | ||
throw GenerateContentError.internalError(underlying: underlying) | ||
} | ||
} | ||
|
||
// Send the history alongside the new message as context. | ||
let request = history + newContent | ||
let result = try await model.generateContent(request) | ||
guard let reply = result.candidates.first?.content else { | ||
let error = NSError(domain: "com.google.generative-ai", | ||
code: -1, | ||
userInfo: [ | ||
NSLocalizedDescriptionKey: "No candidates with content available.", | ||
]) | ||
throw GenerateContentError.internalError(underlying: error) | ||
} | ||
|
||
// Make sure we inject the role into the content received. | ||
let toAdd = ModelContent(role: "model", parts: reply.parts) | ||
|
||
// Append the request and successful result to history, then return the value. | ||
history.append(contentsOf: newContent) | ||
history.append(toAdd) | ||
return result | ||
} | ||
|
||
/// See ``sendMessageStream(_:)-4abs3``. | ||
@available(macOS 12.0, *) | ||
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) | ||
-> AsyncThrowingStream<GenerateContentResponse, Error> { | ||
return try sendMessageStream([ModelContent(parts: parts)]) | ||
} | ||
|
||
/// Sends a message using the existing history of this chat as context. If successful, the message | ||
/// and response will be added to the history. If unsuccessful, history will remain unchanged. | ||
/// - Parameter content: The new content to send as a single chat message. | ||
/// - Returns: A stream containing the model's response or an error if an error occurred. | ||
@available(macOS 12.0, *) | ||
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) | ||
-> AsyncThrowingStream<GenerateContentResponse, Error> { | ||
let resolvedContent: [ModelContent] | ||
do { | ||
resolvedContent = try content() | ||
} catch let underlying { | ||
return AsyncThrowingStream { continuation in | ||
let error: Error | ||
if let contentError = underlying as? ImageConversionError { | ||
error = GenerateContentError.promptImageContentError(underlying: contentError) | ||
} else { | ||
error = GenerateContentError.internalError(underlying: underlying) | ||
} | ||
continuation.finish(throwing: error) | ||
} | ||
} | ||
|
||
return AsyncThrowingStream { continuation in | ||
Task { | ||
var aggregatedContent: [ModelContent] = [] | ||
|
||
// Ensure that the new content has the role set. | ||
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:)) | ||
|
||
// Send the history alongside the new message as context. | ||
let request = history + newContent | ||
let stream = model.generateContentStream(request) | ||
do { | ||
for try await chunk in stream { | ||
// Capture any content that's streaming. This should be populated if there's no error. | ||
if let chunkContent = chunk.candidates.first?.content { | ||
aggregatedContent.append(chunkContent) | ||
} | ||
|
||
// Pass along the chunk. | ||
continuation.yield(chunk) | ||
} | ||
} catch { | ||
// Rethrow the error that the underlying stream threw. Don't add anything to history. | ||
continuation.finish(throwing: error) | ||
return | ||
} | ||
|
||
// Save the request. | ||
history.append(contentsOf: newContent) | ||
|
||
// Aggregate the content to add it to the history before we finish. | ||
let aggregated = aggregatedChunks(aggregatedContent) | ||
history.append(aggregated) | ||
|
||
continuation.finish() | ||
} | ||
} | ||
} | ||
|
||
private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { | ||
var parts: [ModelContent.Part] = [] | ||
var combinedText = "" | ||
for aggregate in chunks { | ||
// Loop through all the parts, aggregating the text and adding the images. | ||
for part in aggregate.parts { | ||
switch part { | ||
case let .text(str): | ||
combinedText += str | ||
|
||
case .data(mimetype: _, _): | ||
// Don't combine it, just add to the content. If there's any text pending, add that as | ||
// a part. | ||
if !combinedText.isEmpty { | ||
parts.append(.text(combinedText)) | ||
combinedText = "" | ||
} | ||
|
||
parts.append(part) | ||
} | ||
} | ||
} | ||
|
||
if !combinedText.isEmpty { | ||
parts.append(.text(combinedText)) | ||
} | ||
|
||
return ModelContent(role: "model", parts: parts) | ||
} | ||
|
||
/// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions. | ||
private func populateContentRole(_ content: ModelContent) -> ModelContent { | ||
if content.role != nil { | ||
return content | ||
} else { | ||
return ModelContent(role: "user", parts: content.parts) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
struct CountTokensRequest { | ||
let model: String | ||
let contents: [ModelContent] | ||
let options: RequestOptions | ||
} | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
extension CountTokensRequest: Encodable { | ||
enum CodingKeys: CodingKey { | ||
case contents | ||
} | ||
} | ||
|
||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
extension CountTokensRequest: GenerativeAIRequest { | ||
typealias Response = CountTokensResponse | ||
|
||
var url: URL { | ||
URL(string: "\(GenerativeAISwift.baseURL)/\(options.apiVersion)/\(model):countTokens")! | ||
} | ||
} | ||
|
||
/// The model's response to a count tokens request. | ||
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) | ||
public struct CountTokensResponse: Decodable { | ||
/// The total number of tokens in the input given to the model as a prompt. | ||
public let totalTokens: Int | ||
} |
Oops, something went wrong.