Skip to content

Commit

Permalink
feat: [vertexai] add fluent API in ChatSession (#10597)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617901539

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Mar 21, 2024
1 parent a8aa591 commit 5c3d93e
Show file tree
Hide file tree
Showing 5 changed files with 417 additions and 370 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package com.google.cloud.vertexai;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.api.core.InternalApi;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
Expand All @@ -31,10 +28,8 @@
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.base.Strings;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable {
private Transport transport = Transport.GRPC;
// The clients will be instantiated lazily
private PredictionServiceClient predictionServiceClient = null;
private PredictionServiceClient predictionServiceRestClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
private final ReentrantLock lock = new ReentrantLock();
private LlmUtilityServiceClient llmUtilityRestClient = null;

/**
* Construct a VertexAI instance.
Expand Down Expand Up @@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException {

/** Sets the value for {@link #getTransport()}. */
public void setTransport(Transport transport) {
checkNotNull(transport, "Transport can't be null.");
if (this.transport == transport) {
return;
}

this.transport = transport;
resetClients();
}

/** Sets the value for {@link #getApiEndpoint()}. */
public void setApiEndpoint(String apiEndpoint) {
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
if (this.apiEndpoint == apiEndpoint) {
return;
}
this.apiEndpoint = apiEndpoint;
resetClients();
}

private void resetClients() {
if (this.predictionServiceClient != null) {
this.predictionServiceClient.close();
this.predictionServiceClient = null;
}

if (this.predictionServiceRestClient != null) {
this.predictionServiceRestClient.close();
this.predictionServiceRestClient = null;
}

if (this.llmUtilityClient != null) {
this.llmUtilityClient.close();
this.llmUtilityClient = null;
}

if (this.llmUtilityRestClient != null) {
this.llmUtilityRestClient.close();
this.llmUtilityRestClient = null;
}
}

/**
Expand All @@ -237,47 +230,78 @@ private void resetClients() {
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient != null) {
return predictionServiceClient;
if (this.transport == Transport.GRPC) {
return getPredictionServiceGrpcClient();
} else {
return getPredictionServiceRestClient();
}
lock.lock();
try {
if (predictionServiceClient == null) {
PredictionServiceSettings settings = getPredictionServiceSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settings);
defaultCredentialsProviderLogger.setLevel(previousLevel);
}

/**
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
if (predictionServiceClient == null) {
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
return predictionServiceClient;
} finally {
lock.unlock();
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceClient;
}

private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
PredictionServiceSettings.Builder builder;
if (transport == Transport.REST) {
builder = PredictionServiceSettings.newHttpJsonBuilder();
} else {
builder = PredictionServiceSettings.newBuilder();
}
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
builder.setCredentialsProvider(this.credentialsProvider);
/**
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
* first prediction API call is made.
*
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
* method calls that map to the API methods.
*/
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
if (predictionServiceRestClient == null) {
PredictionServiceSettings.Builder settingsBuilder =
PredictionServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
builder.setHeaderProvider(headerProvider);
return builder.build();
return predictionServiceRestClient;
}

/**
Expand All @@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (llmUtilityClient != null) {
return llmUtilityClient;
if (this.transport == Transport.GRPC) {
return getLlmUtilityGrpcClient();
} else {
return getLlmUtilityRestClient();
}
lock.lock();
try {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settings);
defaultCredentialsProviderLogger.setLevel(previousLevel);
}

/**
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
if (llmUtilityClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
return llmUtilityClient;
} finally {
lock.unlock();
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityClient;
}

private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
LlmUtilityServiceSettings.Builder settingsBuilder;
if (transport == Transport.REST) {
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
} else {
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
}
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
/**
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
* first API call is made.
*
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
* method calls that map to the API methods.
*/
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
if (llmUtilityRestClient == null) {
LlmUtilityServiceSettings.Builder settingsBuilder =
LlmUtilityServiceSettings.newHttpJsonBuilder();
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
if (this.credentialsProvider != null) {
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
HeaderProvider headerProvider =
FixedHeaderProvider.create(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
settingsBuilder.setHeaderProvider(headerProvider);
return settingsBuilder.build();
return llmUtilityRestClient;
}

/** Closes the VertexAI instance together with all its instantiated clients. */
Expand All @@ -338,8 +393,14 @@ public void close() {
if (predictionServiceClient != null) {
predictionServiceClient.close();
}
if (predictionServiceRestClient != null) {
predictionServiceRestClient.close();
}
if (llmUtilityClient != null) {
llmUtilityClient.close();
}
if (llmUtilityRestClient != null) {
llmUtilityRestClient.close();
}
}
}
Loading

0 comments on commit 5c3d93e

Please sign in to comment.