Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add App Check tests in GenerativeModelTests #12590

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import XCTest

@testable import FirebaseVertexAI
Expand Down Expand Up @@ -178,6 +179,43 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(token: appCheckToken),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json",
appCheckToken: appCheckToken
)

_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json",
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -654,6 +692,45 @@ final class GenerativeModelTests: XCTestCase {
.contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty }))
}

func testGenerateContentStream_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(token: appCheckToken),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt",
appCheckToken: appCheckToken
)

let stream = model.generateContentStream(testPrompt)
for try await _ in stream {}
}

func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
urlSession: urlSession
)
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt",
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

let stream = model.generateContentStream(testPrompt)
for try await _ in stream {}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down Expand Up @@ -887,8 +964,8 @@ final class GenerativeModelTests: XCTestCase {
private func httpRequestHandler(forResource name: String,
withExtension ext: String,
statusCode: Int = 200,
timeout: TimeInterval = URLRequest
.defaultTimeoutInterval()) throws -> ((URLRequest) throws -> (
timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
appCheckToken: String? = nil) throws -> ((URLRequest) throws -> (
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
)) {
Expand All @@ -897,6 +974,7 @@ final class GenerativeModelTests: XCTestCase {
let requestURL = try XCTUnwrap(request.url)
XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
XCTAssertEqual(request.timeoutInterval, timeout)
XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
let response = try XCTUnwrap(HTTPURLResponse(
url: requestURL,
statusCode: statusCode,
Expand All @@ -922,3 +1000,52 @@ private extension URLRequest {
return URLRequest(url: placeholderURL).timeoutInterval
}
}

class AppCheckInteropFake: NSObject, AppCheckInterop {
/// The placeholder token value returned when an error occurs
static let placeholderTokenValue = "placeholder-token"

var token: String
var error: Error?

private init(token: String, error: Error?) {
self.token = token
self.error = error
}

convenience init(token: String) {
self.init(token: token, error: nil)
}

convenience init(error: Error) {
self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
}

func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
return AppCheckTokenResultInteropFake(token: token, error: error)
}

func tokenDidChangeNotificationName() -> String {
fatalError("\(#function) not implemented.")
}

func notificationTokenKey() -> String {
fatalError("\(#function) not implemented.")
}

func notificationAppNameKey() -> String {
fatalError("\(#function) not implemented.")
}

private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
var token: String
var error: Error?

init(token: String, error: Error?) {
self.token = token
self.error = error
}
}
}

struct AppCheckErrorFake: Error {}
Loading