Skip to content

Commit

Permalink
Add javadoc and inline comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
tbolis-at-mulesoft committed Nov 25, 2024
1 parent e957dab commit 1c4dfc3
Showing 1 changed file with 111 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import java.util.*;
import java.util.stream.Collectors;

/**
* Implementation of Einstein AI's embedding model service that extends DimensionAwareEmbeddingModel.
* This class handles the generation of text embeddings using Salesforce Einstein API.
*/
public class EinsteinEmbeddingModel extends DimensionAwareEmbeddingModel {

private static final Logger LOGGER = LoggerFactory.getLogger(EinsteinEmbeddingModel.class);
Expand All @@ -33,119 +37,161 @@ public class EinsteinEmbeddingModel extends DimensionAwareEmbeddingModel {
private final Integer dimensions;
private final String accessToken;


/**
* Private constructor used by the builder pattern.
* Initializes the embedding model with Salesforce credentials and model configuration.
*
* @param salesforceOrg The Salesforce organization identifier
* @param clientId OAuth client ID for authentication
* @param clientSecret OAuth client secret for authentication
* @param modelName Name of the Einstein embedding model to use
* @param dimensions Number of dimensions for the embeddings
*/
private EinsteinEmbeddingModel(String salesforceOrg, String clientId, String clientSecret, String modelName, Integer dimensions) {

// Default to SFDC text embedding model if none specified
this.modelName = Utils.getOrDefault(modelName, Constants.EMBEDDING_MODEL_NAME_SFDC_TEXT_EMBEDDING_ADA_002);
this.dimensions = dimensions;
this.accessToken = getAccessToken(salesforceOrg, clientId, clientSecret);
}

/**
* Returns the known dimension of the embedding model.
*
* @return The dimension size of the embeddings
*/
protected Integer knownDimension() {
return this.dimensions != null ? this.dimensions : EinsteinEmbeddingModelName.knownDimension(this.modelName());
}

/**
* Returns the name of the current embedding model.
*
* @return The model name
*/
public String modelName() {
return this.modelName;
}

/**
* Generates embeddings for a list of text segments.
*
* @param textSegments List of text segments to embed
* @return Response containing list of embeddings and token usage information
*/
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {

List<String> texts = (List<String>) textSegments.stream().map(TextSegment::text).collect(Collectors.toList());
// Convert TextSegments to plain strings
List<String> texts = textSegments.stream().map(TextSegment::text).collect(Collectors.toList());
return this.embedTexts(texts);
}

/**
* Internal method to process text strings and generate embeddings.
* Handles batching of requests to the Einstein API.
*
* @param texts List of text strings to embed
* @return Response containing embeddings and token usage
*/
private Response<List<Embedding>> embedTexts(List<String> texts) {

List<Embedding> embeddings = new ArrayList<>();

int tokenUsage = 0;

// Loop through each array in batch of 16
// Process texts in batches of 16 (Einstein API limit)
for(int x = 0; x < texts.size(); x += 16) {

// Extract current batch
List<String> batch = texts.subList(x, Math.min(x + 16, texts.size()));

// Generate embeddings for current batch
String response = generateEmbeddings(buildPayload(batch));
JSONObject jsonResponse = new JSONObject(response);

// Accumulate token usage
tokenUsage += jsonResponse.getJSONObject("parameters")
.getJSONObject("usage")
.getInt("total_tokens");

// Extract the 'embeddings' array
// Parse embeddings from response
JSONArray embeddingsArray = jsonResponse.getJSONArray("embeddings");

// Loop through each embedding object in the embeddings array
// Process each embedding in the response
for (int i = 0; i < embeddingsArray.length(); i++) {
// Extract the individual embedding object
JSONObject embeddingObject = embeddingsArray.getJSONObject(i);

// Get the 'embedding' array
JSONArray embeddingArray = embeddingObject.getJSONArray("embedding");

// Convert the 'embedding' JSONArray to a float array
// Convert JSON array to float array
float[] vector = new float[embeddingArray.length()];
for (int y = 0; y < embeddingArray.length(); y++) {
vector[y] = (float) embeddingArray.getDouble(y); // Convert to float
vector[y] = (float) embeddingArray.getDouble(y);
}

// Create an Embedding object and add it to the list
Embedding embedding = Embedding.from(vector);
embeddings.add(embedding);
embeddings.add(Embedding.from(vector));
}
}

return Response.from(embeddings, new TokenUsage(tokenUsage));
}

/**
* Authenticates with Salesforce and obtains an access token.
*
* @param salesforceOrg Salesforce organization identifier
* @param clientId OAuth client ID
* @param clientSecret OAuth client secret
* @return Access token for API calls
* @throws ModuleException if authentication fails
*/
private String getAccessToken(String salesforceOrg, String clientId, String clientSecret) {

String urlString = "https://" + salesforceOrg + ".my.salesforce.com/services/oauth2/token";
String params = "grant_type=client_credentials&client_id=" + clientId + "&client_secret=" + clientSecret;

try {
URL url = new URL(urlString);

HttpURLConnection conn = (HttpURLConnection) url.openConnection();

// Configure connection for OAuth token request
conn.setDoOutput(true);
conn.setRequestMethod("POST");
conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded");

// Write parameters to request body
try (OutputStream os = conn.getOutputStream()) {
byte[] input = params.getBytes(StandardCharsets.UTF_8);
os.write(input, 0, input.length);
}

int responseCode = conn.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
// Read and parse response
try (java.io.BufferedReader br = new java.io.BufferedReader(
new java.io.InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
StringBuilder response = new StringBuilder();
String responseLine;
while ((responseLine = br.readLine()) != null) {
response.append(responseLine.trim());
}
// Parse JSON response and extract access_token
JSONObject jsonResponse = new JSONObject(response.toString());
return jsonResponse.getString("access_token");
return new JSONObject(response.toString()).getString("access_token");
}
} else {
throw new ModuleException(
"Error while getting access token for \"EINSTEIN\" embedding model service. Response code: " + responseCode,
MuleVectorsErrorType.AI_SERVICES_FAILURE);
}
} catch (Exception e) {

throw new ModuleException(
"Error while getting access token for \"EINSTEIN\" embedding model service.",
MuleVectorsErrorType.AI_SERVICES_FAILURE,
e);
}
}

/**
* Creates and configures an HTTP connection for Einstein API requests.
*
* @param url The endpoint URL
* @param accessToken OAuth access token
* @return Configured HttpURLConnection
* @throws IOException if connection setup fails
*/
private HttpURLConnection getConnectionObject(URL url, String accessToken) throws IOException {

HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setDoOutput(true);
conn.setRequestMethod("POST");
Expand All @@ -156,28 +202,43 @@ private HttpURLConnection getConnectionObject(URL url, String accessToken) throw
return conn;
}

/**
* Builds JSON payload for single text embedding request.
*
* @param text Text to embed
* @return JSON string payload
*/
private static String buildPayload(String text) {

JSONArray input = new JSONArray();
input.put(text);
JSONObject jsonObject = new JSONObject();
jsonObject.put("input", input);
return jsonObject.toString();
}

/**
* Builds JSON payload for batch text embedding request.
*
* @param texts List of texts to embed
* @return JSON string payload
*/
private static String buildPayload(List<String> texts) {

JSONArray input = new JSONArray(texts);
JSONObject jsonObject = new JSONObject();
jsonObject.put("input", input);
return jsonObject.toString();
}

/**
* Makes the API call to Einstein to generate embeddings.
*
* @param payload JSON payload for the request
* @return JSON response string
* @throws ModuleException if the API call fails
*/
private String generateEmbeddings(String payload) {

try {

// Open connection
// Prepare connection
String urlString = URL_BASE + this.modelName + "/embeddings";
HttpURLConnection connection;
try {
Expand All @@ -190,17 +251,16 @@ private String generateEmbeddings(String payload) {
e);
}

// Write the request body to the OutputStream
// Send request
try (OutputStream os = connection.getOutputStream()) {
byte[] input = payload.getBytes(StandardCharsets.UTF_8);
os.write(input, 0, input.length); // Writing the payload to the connection
os.write(input, 0, input.length);
}

// After writing, check the response code to ensure the request was successful
int responseCode = connection.getResponseCode();

// Only read the response if the request was successful
if (responseCode == HttpURLConnection.HTTP_OK) {
// Read response
try (java.io.BufferedReader br = new java.io.BufferedReader(
new java.io.InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
StringBuilder response = new StringBuilder();
Expand All @@ -214,29 +274,31 @@ private String generateEmbeddings(String payload) {
throw new ModuleException(
"Error while generating embeddings with \"EINSTEIN\" embedding model service. Response code: " + responseCode,
MuleVectorsErrorType.AI_SERVICES_FAILURE);

}
} catch (ModuleException e) {

throw e;

} catch (Exception e) {

throw new ModuleException(
"Error while generating embeddings with \"EINSTEIN\" embedding model service.",
MuleVectorsErrorType.AI_SERVICES_FAILURE,
e);
}

}

/**
* Creates a new builder instance for EinsteinEmbeddingModel.
*
* @return A new builder instance
*/
public static EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder builder() {

return new EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder();
}

/**
* Builder class for EinsteinEmbeddingModel.
* Implements the Builder pattern for constructing EinsteinEmbeddingModel instances.
*/
public static class EinsteinEmbeddingModelBuilder {

private String salesforceOrg;
private String clientId;
private String clientSecret;
Expand All @@ -245,6 +307,9 @@ public static class EinsteinEmbeddingModelBuilder {
private Tokenizer tokenizer;
private HttpURLConnection connection;

public EinsteinEmbeddingModelBuilder() {
}

public EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder salesforceOrg(String salesforceOrg) {
this.salesforceOrg = salesforceOrg;
return this;
Expand Down Expand Up @@ -275,10 +340,11 @@ public EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder dimensions(Integer d
return this;
}

public EinsteinEmbeddingModelBuilder() {

}

/**
* Builds and returns a new EinsteinEmbeddingModel instance.
*
* @return A new EinsteinEmbeddingModel configured with the builder's parameters
*/
public EinsteinEmbeddingModel build() {
return new EinsteinEmbeddingModel(this.salesforceOrg, this.clientId, this.clientSecret, this.modelName, this.dimensions);
}
Expand Down

0 comments on commit 1c4dfc3

Please sign in to comment.