Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: [vertexai] make client getters in VertexAI private #10550

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.vertexai;

import com.google.api.core.InternalApi;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.core.GaxProperties;
Expand Down Expand Up @@ -220,11 +221,30 @@ public void setApiEndpoint(String apiEndpoint) {
}
}

/**
 * Returns the {@link PredictionServiceClient}, backed by either gRPC or REST depending on the
 * configured {@link Transport}. The underlying client is lazily instantiated when the first
 * prediction API call is made.
 *
 * @return a {@link PredictionServiceClient} that sends requests to the backing service through
 *     method calls that map to the API methods.
 */
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
  // Dispatch to the transport-specific lazy initializer.
  return this.transport == Transport.GRPC
      ? getPredictionServiceGrpcClient()
      : getPredictionServiceRestClient();
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
* method calls that map to the API methods.
*/
public PredictionServiceClient getPredictionServiceClient() throws IOException {
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand Down Expand Up @@ -257,7 +277,7 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException {
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
* method calls that map to the API methods.
*/
public PredictionServiceClient getPredictionServiceRestClient() throws IOException {
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
Expand All @@ -284,14 +304,30 @@ public PredictionServiceClient getPredictionServiceRestClient() throws IOExcepti
return predictionServiceRestClient;
}

/**
 * Returns the {@link LlmUtilityServiceClient}, backed by either gRPC or REST depending on the
 * configured {@link Transport}. The underlying client is lazily instantiated when the first
 * API call is made.
 *
 * @return a {@link LlmUtilityServiceClient} that makes calls to the backing service through
 *     method calls that map to the API methods.
 */
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
  // Dispatch to the transport-specific lazy initializer.
  return this.transport == Transport.GRPC
      ? getLlmUtilityGrpcClient()
      : getLlmUtilityRestClient();
}

/**
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
* method calls that map to the API methods.
*/
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
Expand Down Expand Up @@ -319,12 +355,12 @@ public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {

/**
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
* first prediction API call is made.
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
* method calls that map to the API methods.
*/
public LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.google.cloud.vertexai.generativeai;

import com.google.api.core.BetaApi;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
Expand Down Expand Up @@ -289,11 +288,7 @@ public CountTokensResponse countTokens(List<Content> contents) throws IOExceptio
@BetaApi
private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getLlmUtilityRestClient().countTokens(request);
} else {
return vertexAi.getLlmUtilityClient().countTokens(request);
}
return vertexAi.getLlmUtilityClient().countTokens(request);
}

/**
Expand Down Expand Up @@ -520,11 +515,7 @@ public GenerateContentResponse generateContent(
*/
private GenerateContentResponse generateContent(GenerateContentRequest request)
throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request);
} else {
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
}
return vertexAi.getPredictionServiceClient().generateContentCallable().call(request);
}

/**
Expand Down Expand Up @@ -932,23 +923,13 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
*/
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest request) throws IOException {
if (vertexAi.getTransport() == Transport.REST) {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceRestClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
} else {
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
}
return new ResponseStream(
new ResponseStreamIteratorWithHistory(
vertexAi
.getPredictionServiceClient()
.streamGenerateContentCallable()
.call(request)
.iterator()));
}

/**
Expand Down
Loading