Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generative-ai-swift tests in Vertex AI #12585

Merged
merged 4 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/workflows/vertexai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ jobs:
- name: Xcode
run: sudo xcode-select -s /Applications/${{ matrix.xcode }}.app/Contents/Developer
- name: Initialize xcodebuild
run: xcodebuild -list
# TODO: Add unit tests and switch from `spmbuildonly` to `spm`.
- name: Build
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAI ${{ matrix.target }} spmbuildonly
run: scripts/setup_spm_tests.sh
- name: Build and run tests
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAIUnit ${{ matrix.target }} spm
11 changes: 9 additions & 2 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
// limitations under the License.

import Foundation
@testable import GoogleGenerativeAI
import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
final class ChatTests: XCTestCase {
var urlSession: URLSession!
Expand Down Expand Up @@ -46,7 +47,13 @@ final class ChatTests: XCTestCase {
return (response, fileURL.lines)
}

let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
let model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
andrewheard marked this conversation as resolved.
Show resolved Hide resolved
urlSession: urlSession
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = chat.sendMessageStream(input)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"reason": "API_KEY_INVALID",
"domain": "googleapis.com",
"metadata": {
"service": "generativelanguage.googleapis.com"
"service": "staging-firebaseml.sandbox.googleapis.com"
}
},
{
Expand Down

This file was deleted.

This file was deleted.

90 changes: 42 additions & 48 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

@testable import GoogleGenerativeAI
import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
final class GenerativeModelTests: XCTestCase {
let testPrompt = "What sorts of questions can I ask you?"
Expand All @@ -32,7 +33,13 @@ final class GenerativeModelTests: XCTestCase {
let configuration = URLSessionConfiguration.default
configuration.protocolClasses = [MockURLProtocol.self]
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
urlSession: urlSession
)
}

override func tearDown() {
Expand Down Expand Up @@ -163,6 +170,8 @@ final class GenerativeModelTests: XCTestCase {
// Model name is prefixed with "models/".
name: "models/test-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -181,10 +190,13 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.invalidAPIKey(message) {
XCTAssertEqual(message, "API key not valid. Please pass a valid API key.")
} catch let GenerateContentError.internalError(error as RPCError) {
XCTAssertEqual(error.httpResponseCode, 400)
XCTAssertEqual(error.status, .invalidArgument)
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can Vertex have an invalid API key? Shouldn't there be a message about a GoogleService-Info.plist?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would only happen if the API key in the GoogleService-Info.plist is invalid (e.g., if the key was revoked after the app was shipped). The message/code/status checked here are the ones returned by the backend. I don't think we have specific error cases for revoked/invalid API keys in our other SDKs (everything would be broken in that scenario), which was why I removed the case that was in https://github.com/google/generative-ai-swift/blob/48a0c2f11935a17132492583d4aad51c5a407bcb/Sources/GoogleAI/GenerateContentError.swift#L32-L33 but we could bring it back.

return
} catch {
XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)")
XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
}
}

Expand Down Expand Up @@ -342,24 +354,6 @@ final class GenerativeModelTests: XCTestCase {
}
}

func testGenerateContent_failure_unsupportedUserLocation() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unsupported-user-location",
withExtension: "json",
statusCode: 400
)

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.")
} catch GenerateContentError.unsupportedUserLocation {
return
}

XCTFail("Expected an unsupported user location error.")
}

func testGenerateContent_failure_nonHTTPResponse() async throws {
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

Expand Down Expand Up @@ -468,6 +462,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -490,8 +485,10 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
} catch GenerateContentError.invalidAPIKey {
// invalidAPIKey error is as expected, nothing else to check.
} catch let GenerateContentError.internalError(error as RPCError) {
XCTAssertEqual(error.httpResponseCode, 400)
XCTAssertEqual(error.status, .invalidArgument)
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
return
}

Expand Down Expand Up @@ -747,26 +744,6 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Expected an internal decoding error.")
}

func testGenerateContentStream_failure_unsupportedUserLocation() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unsupported-user-location",
withExtension: "json",
statusCode: 400
)

let stream = model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch GenerateContentError.unsupportedUserLocation {
return
}

XCTFail("Expected an unsupported user location error.")
}

func testGenerateContentStream_requestOptions_customTimeout() async throws {
let expectedTimeout = 150.0
MockURLProtocol
Expand All @@ -780,6 +757,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand Down Expand Up @@ -837,6 +815,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -851,23 +830,38 @@ final class GenerativeModelTests: XCTestCase {
let modelName = "my-model"
let modelResourceName = "models/\(modelName)"

model = GenerativeModel(name: modelName, apiKey: "API_KEY")
model = GenerativeModel(
name: modelName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_modelsPrefix() async throws {
let modelResourceName = "models/my-model"

model = GenerativeModel(name: modelResourceName, apiKey: "API_KEY")
model = GenerativeModel(
name: modelResourceName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_tunedModelsPrefix() async throws {
let tunedModelResourceName = "tunedModels/my-model"

model = GenerativeModel(name: tunedModelResourceName, apiKey: "API_KEY")
model = GenerativeModel(
name: tunedModelResourceName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, tunedModelResourceName)
}
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import CoreGraphics
import CoreImage
import GoogleGenerativeAI
import FirebaseVertexAI
import XCTest
#if canImport(UIKit)
import UIKit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import GoogleGenerativeAI
import FirebaseCore
import FirebaseVertexAI
import XCTest
#if canImport(AppKit)
import AppKit // For NSImage extensions.
Expand All @@ -21,8 +22,9 @@ import XCTest
#endif

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
final class GoogleGenerativeAITests: XCTestCase {
final class VertexAIAPITests: XCTestCase {
func codeSamples() async throws {
let app = FirebaseApp.app()
let config = GenerationConfig(temperature: 0.2,
topP: 0.1,
topK: 16,
Expand All @@ -32,16 +34,40 @@ final class GoogleGenerativeAITests: XCTestCase {
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Permutations without optional arguments.
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY")
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", safetySettings: filters)
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config)

// All arguments passed.
let genAI = GenerativeModel(name: "gemini-1.0-pro",
apiKey: "API_KEY",
generationConfig: config, // Optional
safetySettings: filters // Optional
// TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's preventing them from being added?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll just need to check those parameters (and requestOptions) when determining if it's the same instance in VertexAIComponent:

// MARK: - VertexAIProvider conformance
func vertexAI(location: String, modelResourceName: String) -> VertexAI {
os_unfair_lock_lock(&instancesLock)
// Unlock before the function returns.
defer { os_unfair_lock_unlock(&instancesLock) }
if let instance = instances[modelResourceName] {
return instance
}
let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName)
instances[modelResourceName] = newInstance
return newInstance
}

Note: Those parameters are currently all structs so I can't use them with @objc in their current form (might be able to use the [String:Any] trick that Morgan mentioned).

let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
let _ = VertexAI.generativeModel(
app: app!,
modelName: "gemini-1.0-pro",
location: "us-central1"
)

// TODO: Add safetySettings to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// safetySettings: filters
// )
// TODO: Add generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config
// )

// All arguments passed.
// TODO: Add safetySettings and generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let genAI = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config, // Optional
// safetySettings: filters // Optional
// )

// Full Typed Usage
let pngData = Data() // ....
let contents = [ModelContent(role: "user",
Expand Down
9 changes: 9 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,15 @@ let package = Package(
],
path: "FirebaseVertexAI/Sources"
),
.testTarget(
name: "FirebaseVertexAIUnit",
dependencies: ["FirebaseVertexAI"],
path: "FirebaseVertexAI/Tests/Unit",
resources: [
.process("CountTokenResponses"),
.process("GenerateContentResponses"),
]
),
] + firestoreTargets(),
cLanguageStandard: .c99,
cxxLanguageStandard: CXXLanguageStandard.gnucxx14
Expand Down
Loading
Loading