Skip to content

Commit

Permalink
Add App Check tests in GenerativeModelTests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 19, 2024
1 parent c68e1e3 commit 9bf8216
Showing 1 changed file with 129 additions and 2 deletions.
131 changes: 129 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import XCTest

@testable import FirebaseVertexAI
Expand Down Expand Up @@ -178,6 +179,43 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(token: appCheckToken),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json",
appCheckToken: appCheckToken
)

_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json",
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -654,6 +692,45 @@ final class GenerativeModelTests: XCTestCase {
.contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty }))
}

func testGenerateContentStream_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(token: appCheckToken),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt",
appCheckToken: appCheckToken
)

let stream = model.generateContentStream(testPrompt)
for try await _ in stream {}
}

func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt",
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

let stream = model.generateContentStream(testPrompt)
for try await _ in stream {}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down Expand Up @@ -887,8 +964,8 @@ final class GenerativeModelTests: XCTestCase {
private func httpRequestHandler(forResource name: String,
withExtension ext: String,
statusCode: Int = 200,
timeout: TimeInterval = URLRequest
.defaultTimeoutInterval()) throws -> ((URLRequest) throws -> (
timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
appCheckToken: String? = nil) throws -> ((URLRequest) throws -> (
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
)) {
Expand All @@ -897,6 +974,7 @@ final class GenerativeModelTests: XCTestCase {
let requestURL = try XCTUnwrap(request.url)
XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
XCTAssertEqual(request.timeoutInterval, timeout)
XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
let response = try XCTUnwrap(HTTPURLResponse(
url: requestURL,
statusCode: statusCode,
Expand All @@ -922,3 +1000,52 @@ private extension URLRequest {
return URLRequest(url: placeholderURL).timeoutInterval
}
}

class AppCheckInteropFake: NSObject, AppCheckInterop {
/// The placeholder token value returned when an error occurs
static let placeholderTokenValue = "placeholder-token"

var token: String
var error: Error?

private init(token: String, error: Error?) {
self.token = token
self.error = error
}

convenience init(token: String) {
self.init(token: token, error: nil)
}

convenience init(error: Error) {
self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
}

func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
return AppCheckTokenResultInteropFake(token: token, error: error)
}

func tokenDidChangeNotificationName() -> String {
fatalError("\(#function) not implemented.")
}

func notificationTokenKey() -> String {
fatalError("\(#function) not implemented.")
}

func notificationAppNameKey() -> String {
fatalError("\(#function) not implemented.")
}

private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
var token: String
var error: Error?

init(token: String, error: Error?) {
self.token = token
self.error = error
}
}
}

struct AppCheckErrorFake: Error {}

0 comments on commit 9bf8216

Please sign in to comment.