From 1c4dfc362711c745519b9fc2a481055d8a20e4af Mon Sep 17 00:00:00 2001 From: Tommaso Bolis Date: Mon, 25 Nov 2024 10:56:39 +0100 Subject: [PATCH] Add javadoc and inline comments. --- .../einstein/EinsteinEmbeddingModel.java | 156 +++++++++++++----- 1 file changed, 111 insertions(+), 45 deletions(-) diff --git a/src/main/java/org/mule/extension/vectors/internal/model/einstein/EinsteinEmbeddingModel.java b/src/main/java/org/mule/extension/vectors/internal/model/einstein/EinsteinEmbeddingModel.java index ed4d765..6379182 100644 --- a/src/main/java/org/mule/extension/vectors/internal/model/einstein/EinsteinEmbeddingModel.java +++ b/src/main/java/org/mule/extension/vectors/internal/model/einstein/EinsteinEmbeddingModel.java @@ -23,6 +23,10 @@ import java.util.*; import java.util.stream.Collectors; +/** + * Implementation of Einstein AI's embedding model service that extends DimensionAwareEmbeddingModel. + * This class handles the generation of text embeddings using Salesforce Einstein API. + */ public class EinsteinEmbeddingModel extends DimensionAwareEmbeddingModel { private static final Logger LOGGER = LoggerFactory.getLogger(EinsteinEmbeddingModel.class); @@ -33,85 +37,122 @@ public class EinsteinEmbeddingModel extends DimensionAwareEmbeddingModel { private final Integer dimensions; private final String accessToken; - + /** + * Private constructor used by the builder pattern. + * Initializes the embedding model with Salesforce credentials and model configuration. + * + * @param salesforceOrg The Salesforce organization identifier + * @param clientId OAuth client ID for authentication + * @param clientSecret OAuth client secret for authentication + * @param modelName Name of the Einstein embedding model to use + * @param dimensions Number of dimensions for the embeddings + */ private EinsteinEmbeddingModel(String salesforceOrg, String clientId, String clientSecret, String modelName, Integer dimensions) { - + // Default to SFDC text embedding model if none specified this.modelName = Utils.getOrDefault(modelName, Constants.EMBEDDING_MODEL_NAME_SFDC_TEXT_EMBEDDING_ADA_002); this.dimensions = dimensions; this.accessToken = getAccessToken(salesforceOrg, clientId, clientSecret); } + /** + * Returns the known dimension of the embedding model. + * + * @return The dimension size of the embeddings + */ protected Integer knownDimension() { return this.dimensions != null ? this.dimensions : EinsteinEmbeddingModelName.knownDimension(this.modelName()); } + /** + * Returns the name of the current embedding model. + * + * @return The model name + */ public String modelName() { return this.modelName; } + /** + * Generates embeddings for a list of text segments. + * + * @param textSegments List of text segments to embed + * @return Response containing list of embeddings and token usage information + */ public Response> embedAll(List textSegments) { - - List texts = (List) textSegments.stream().map(TextSegment::text).collect(Collectors.toList()); + // Convert TextSegments to plain strings + List texts = textSegments.stream().map(TextSegment::text).collect(Collectors.toList()); return this.embedTexts(texts); } + /** + * Internal method to process text strings and generate embeddings. + * Handles batching of requests to the Einstein API. + * + * @param texts List of text strings to embed + * @return Response containing embeddings and token usage + */ private Response> embedTexts(List texts) { - List embeddings = new ArrayList<>(); - int tokenUsage = 0; - // Loop through each array in batch of 16 + // Process texts in batches of 16 (Einstein API limit) for(int x = 0; x < texts.size(); x += 16) { - + // Extract current batch List batch = texts.subList(x, Math.min(x + 16, texts.size())); + // Generate embeddings for current batch String response = generateEmbeddings(buildPayload(batch)); JSONObject jsonResponse = new JSONObject(response); + // Accumulate token usage tokenUsage += jsonResponse.getJSONObject("parameters") .getJSONObject("usage") .getInt("total_tokens"); - // Extract the 'embeddings' array + // Parse embeddings from response JSONArray embeddingsArray = jsonResponse.getJSONArray("embeddings"); - // Loop through each embedding object in the embeddings array + // Process each embedding in the response for (int i = 0; i < embeddingsArray.length(); i++) { - // Extract the individual embedding object JSONObject embeddingObject = embeddingsArray.getJSONObject(i); - - // Get the 'embedding' array JSONArray embeddingArray = embeddingObject.getJSONArray("embedding"); - // Convert the 'embedding' JSONArray to a float array + // Convert JSON array to float array float[] vector = new float[embeddingArray.length()]; for (int y = 0; y < embeddingArray.length(); y++) { - vector[y] = (float) embeddingArray.getDouble(y); // Convert to float + vector[y] = (float) embeddingArray.getDouble(y); } - // Create an Embedding object and add it to the list - Embedding embedding = Embedding.from(vector); - embeddings.add(embedding); + embeddings.add(Embedding.from(vector)); } } return Response.from(embeddings, new TokenUsage(tokenUsage)); } + /** + * Authenticates with Salesforce and obtains an access token. + * + * @param salesforceOrg Salesforce organization identifier + * @param clientId OAuth client ID + * @param clientSecret OAuth client secret + * @return Access token for API calls + * @throws ModuleException if authentication fails + */ private String getAccessToken(String salesforceOrg, String clientId, String clientSecret) { - String urlString = "https://" + salesforceOrg + ".my.salesforce.com/services/oauth2/token"; String params = "grant_type=client_credentials&client_id=" + clientId + "&client_secret=" + clientSecret; try { URL url = new URL(urlString); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + + // Configure connection for OAuth token request conn.setDoOutput(true); conn.setRequestMethod("POST"); conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded"); + // Write parameters to request body try (OutputStream os = conn.getOutputStream()) { byte[] input = params.getBytes(StandardCharsets.UTF_8); os.write(input, 0, input.length); @@ -119,6 +160,7 @@ private String getAccessToken(String salesforceOrg, String clientId, String clie int responseCode = conn.getResponseCode(); if (responseCode == HttpURLConnection.HTTP_OK) { + // Read and parse response try (java.io.BufferedReader br = new java.io.BufferedReader( new java.io.InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { StringBuilder response = new StringBuilder(); @@ -126,9 +168,7 @@ private String getAccessToken(String salesforceOrg, String clientId, String clie while ((responseLine = br.readLine()) != null) { response.append(responseLine.trim()); } - // Parse JSON response and extract access_token - JSONObject jsonResponse = new JSONObject(response.toString()); - return jsonResponse.getString("access_token"); + return new JSONObject(response.toString()).getString("access_token"); } } else { throw new ModuleException( @@ -136,7 +176,6 @@ private String getAccessToken(String salesforceOrg, String clientId, String clie MuleVectorsErrorType.AI_SERVICES_FAILURE); } } catch (Exception e) { - throw new ModuleException( "Error while getting access token for \"EINSTEIN\" embedding model service.", MuleVectorsErrorType.AI_SERVICES_FAILURE, @@ -144,8 +183,15 @@ private String getAccessToken(String salesforceOrg, String clientId, String clie } } + /** + * Creates and configures an HTTP connection for Einstein API requests. + * + * @param url The endpoint URL + * @param accessToken OAuth access token + * @return Configured HttpURLConnection + * @throws IOException if connection setup fails + */ private HttpURLConnection getConnectionObject(URL url, String accessToken) throws IOException { - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); conn.setDoOutput(true); conn.setRequestMethod("POST"); @@ -156,8 +202,13 @@ private HttpURLConnection getConnectionObject(URL url, String accessToken) throw return conn; } + /** + * Builds JSON payload for single text embedding request. + * + * @param text Text to embed + * @return JSON string payload + */ private static String buildPayload(String text) { - JSONArray input = new JSONArray(); input.put(text); JSONObject jsonObject = new JSONObject(); @@ -165,19 +216,29 @@ private static String buildPayload(String text) { return jsonObject.toString(); } + /** + * Builds JSON payload for batch text embedding request. + * + * @param texts List of texts to embed + * @return JSON string payload + */ private static String buildPayload(List texts) { - JSONArray input = new JSONArray(texts); JSONObject jsonObject = new JSONObject(); jsonObject.put("input", input); return jsonObject.toString(); } + /** + * Makes the API call to Einstein to generate embeddings. + * + * @param payload JSON payload for the request + * @return JSON response string + * @throws ModuleException if the API call fails + */ private String generateEmbeddings(String payload) { - try { - - // Open connection + // Prepare connection String urlString = URL_BASE + this.modelName + "/embeddings"; HttpURLConnection connection; try { @@ -190,17 +251,16 @@ private String generateEmbeddings(String payload) { e); } - // Write the request body to the OutputStream + // Send request try (OutputStream os = connection.getOutputStream()) { byte[] input = payload.getBytes(StandardCharsets.UTF_8); - os.write(input, 0, input.length); // Writing the payload to the connection + os.write(input, 0, input.length); } - // After writing, check the response code to ensure the request was successful int responseCode = connection.getResponseCode(); - // Only read the response if the request was successful if (responseCode == HttpURLConnection.HTTP_OK) { + // Read response try (java.io.BufferedReader br = new java.io.BufferedReader( new java.io.InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) { StringBuilder response = new StringBuilder(); @@ -214,29 +274,31 @@ private String generateEmbeddings(String payload) { throw new ModuleException( "Error while generating embeddings with \"EINSTEIN\" embedding model service. Response code: " + responseCode, MuleVectorsErrorType.AI_SERVICES_FAILURE); - } } catch (ModuleException e) { - throw e; - } catch (Exception e) { - throw new ModuleException( "Error while generating embeddings with \"EINSTEIN\" embedding model service.", MuleVectorsErrorType.AI_SERVICES_FAILURE, e); } - } + /** + * Creates a new builder instance for EinsteinEmbeddingModel. + * + * @return A new builder instance + */ public static EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder builder() { - return new EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder(); } + /** + * Builder class for EinsteinEmbeddingModel. + * Implements the Builder pattern for constructing EinsteinEmbeddingModel instances. + */ public static class EinsteinEmbeddingModelBuilder { - private String salesforceOrg; private String clientId; private String clientSecret; @@ -245,6 +307,9 @@ public static class EinsteinEmbeddingModelBuilder { private Tokenizer tokenizer; private HttpURLConnection connection; + public EinsteinEmbeddingModelBuilder() { + } + public EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder salesforceOrg(String salesforceOrg) { this.salesforceOrg = salesforceOrg; return this; @@ -275,10 +340,11 @@ public EinsteinEmbeddingModel.EinsteinEmbeddingModelBuilder dimensions(Integer d return this; } - public EinsteinEmbeddingModelBuilder() { - - } - + /** + * Builds and returns a new EinsteinEmbeddingModel instance. + * + * @return A new EinsteinEmbeddingModel configured with the builder's parameters + */ public EinsteinEmbeddingModel build() { return new EinsteinEmbeddingModel(this.salesforceOrg, this.clientId, this.clientSecret, this.modelName, this.dimensions); }