diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index ac1980d93d4..13c625036cc 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +import FirebaseAppCheckInterop import XCTest @testable import FirebaseVertexAI @@ -178,6 +179,43 @@ final class GenerativeModelTests: XCTestCase { _ = try await model.generateContent(testPrompt) } + func testGenerateContent_appCheck_validToken() async throws { + let appCheckToken = "test-valid-token" + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: RequestOptions(), + appCheck: AppCheckInteropFake(token: appCheckToken), + urlSession: urlSession + ) + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + appCheckToken: appCheckToken + ) + + _ = try await model.generateContent(testPrompt) + } + + func testGenerateContent_appCheck_tokenRefreshError() async throws { + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: RequestOptions(), + appCheck: AppCheckInteropFake(error: AppCheckErrorFake()), + urlSession: urlSession + ) + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + appCheckToken: AppCheckInteropFake.placeholderTokenValue + ) + + _ = try await model.generateContent(testPrompt) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol @@ -654,6 +692,45 @@ final class GenerativeModelTests: XCTestCase { .contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty })) } + func testGenerateContentStream_appCheck_validToken() async throws { + let appCheckToken = "test-valid-token" + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: RequestOptions(), + appCheck: AppCheckInteropFake(token: appCheckToken), + urlSession: urlSession + ) + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + appCheckToken: appCheckToken + ) + + let stream = model.generateContentStream(testPrompt) + for try await _ in stream {} + } + + func testGenerateContentStream_appCheck_tokenRefreshError() async throws { + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: RequestOptions(), + appCheck: AppCheckInteropFake(error: AppCheckErrorFake()), + urlSession: urlSession + ) + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + appCheckToken: AppCheckInteropFake.placeholderTokenValue + ) + + let stream = model.generateContentStream(testPrompt) + for try await _ in stream {} + } + func testGenerateContentStream_errorMidStream() async throws { MockURLProtocol.requestHandler = try httpRequestHandler( forResource: "streaming-failure-error-mid-stream", @@ -887,8 +964,8 @@ final class GenerativeModelTests: XCTestCase { private func httpRequestHandler(forResource name: String, withExtension ext: String, statusCode: Int = 200, - timeout: TimeInterval = URLRequest - .defaultTimeoutInterval()) throws -> ((URLRequest) throws -> ( + timeout: TimeInterval = URLRequest.defaultTimeoutInterval(), + appCheckToken: String? = nil) throws -> ((URLRequest) throws -> ( URLResponse, AsyncLineSequence? )) { @@ -897,6 +974,7 @@ final class GenerativeModelTests: XCTestCase { let requestURL = try XCTUnwrap(request.url) XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) XCTAssertEqual(request.timeoutInterval, timeout) + XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken) let response = try XCTUnwrap(HTTPURLResponse( url: requestURL, statusCode: statusCode, @@ -922,3 +1000,52 @@ private extension URLRequest { return URLRequest(url: placeholderURL).timeoutInterval } } + +class AppCheckInteropFake: NSObject, AppCheckInterop { + /// The placeholder token value returned when an error occurs + static let placeholderTokenValue = "placeholder-token" + + var token: String + var error: Error? + + private init(token: String, error: Error?) { + self.token = token + self.error = error + } + + convenience init(token: String) { + self.init(token: token, error: nil) + } + + convenience init(error: Error) { + self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error) + } + + func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop { + return AppCheckTokenResultInteropFake(token: token, error: error) + } + + func tokenDidChangeNotificationName() -> String { + fatalError("\(#function) not implemented.") + } + + func notificationTokenKey() -> String { + fatalError("\(#function) not implemented.") + } + + func notificationAppNameKey() -> String { + fatalError("\(#function) not implemented.") + } + + private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop { + var token: String + var error: Error? + + init(token: String, error: Error?) { + self.token = token + self.error = error + } + } +} + +struct AppCheckErrorFake: Error {}