Skip to content

Commit

Permalink
Apply spotless to ml-algorithms package (#1610)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Ohlsen <[email protected]>
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
ohltyler authored and dhrubo-os committed Nov 28, 2023
1 parent f823f72 commit dfa7b46
Show file tree
Hide file tree
Showing 88 changed files with 2,641 additions and 1,935 deletions.
10 changes: 10 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.18.0'
}

repositories {
Expand Down Expand Up @@ -96,3 +97,12 @@ jacocoTestCoverageVerification {
dependsOn jacocoTestReport
}
check.dependsOn jacocoTestCoverageVerification

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
24 changes: 12 additions & 12 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

package org.opensearch.ml.engine;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;

/**
* This is the interface to all ml algorithms.
Expand Down Expand Up @@ -92,10 +94,7 @@ public Path getDeployModelRootPath() {
}

public Path getDeployModelChunkPath(String modelId, Integer chunkNumber) {
return mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER)
.resolve(modelId)
.resolve("chunks")
.resolve(chunkNumber + "");
return mlModelsCachePath.resolve(DEPLOY_MODEL_FOLDER).resolve(modelId).resolve("chunks").resolve(chunkNumber + "");
}

public Path getModelCachePath(String modelId, String modelName, String version) {
Expand Down Expand Up @@ -145,7 +144,8 @@ public MLOutput predict(Input input, MLModel model) {
public MLOutput trainAndPredict(Input input) {
validateMLInput(input);
MLInput mlInput = (MLInput) input;
TrainAndPredictable trainAndPredictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
TrainAndPredictable trainAndPredictable = MLEngineClassLoader
.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
if (trainAndPredictable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
Expand Down Expand Up @@ -180,7 +180,7 @@ private void validateMLInput(Input input) {
throw new IllegalArgumentException("Input data set should not be null");
}
if (inputDataset instanceof DataFrameInputDataset) {
DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame();
DataFrame dataFrame = ((DataFrameInputDataset) inputDataset).getDataFrame();
if (dataFrame == null || dataFrame.size() == 0) {
throw new IllegalArgumentException("Input data frame should not be null or empty");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,6 @@

package org.opensearch.ml.engine;

import org.apache.commons.beanutils.BeanUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.annotation.Function;
import org.reflections.Reflections;

import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
Expand All @@ -22,6 +13,14 @@
import java.util.Map;
import java.util.Set;

import org.apache.commons.beanutils.BeanUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.annotation.Function;
import org.reflections.Reflections;

public class MLEngineClassLoader {

Expand Down Expand Up @@ -138,7 +137,7 @@ public static <T, S, I extends Object> S initInstance(T type, I in, Class<?> con
} catch (Exception e) {
Throwable cause = e.getCause();
if (cause instanceof MLException) {
throw (MLException)cause;
throw (MLException) cause;
} else {
logger.error("Failed to init instance for type " + type, e);
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.ml.engine;

import org.opensearch.ml.common.MLModel;

import java.util.Map;

import org.opensearch.ml.common.MLModel;

public interface MLExecutable extends Executable {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,10 @@

package org.opensearch.ml.engine;

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import com.google.gson.stream.JsonReader;
import lombok.extern.log4j.Log4j2;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks;

import java.io.File;
import java.io.FileReader;
Expand All @@ -31,10 +25,18 @@
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;

import com.google.gson.stream.JsonReader;

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import lombok.extern.log4j.Log4j2;

@Log4j2
public class ModelHelper {
Expand All @@ -53,7 +55,11 @@ public ModelHelper(MLEngine mlEngine) {
this.mlEngine = mlEngine;
}

public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelInput> listener) {
public void downloadPrebuiltModelConfig(
String taskId,
MLRegisterModelInput registerModelInput,
ActionListener<MLRegisterModelInput> listener
) {
String modelName = registerModelInput.getModelName();
String version = registerModelInput.getVersion();
MLModelFormat modelFormat = registerModelInput.getModelFormat();
Expand Down Expand Up @@ -90,7 +96,6 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
.modelNodeIds(modelNodeIds)
.modelGroupId(modelGroupId)
.functionName(FunctionName.from((String) config.get("model_task_type")));

config.entrySet().forEach(entry -> {
switch (entry.getKey().toString()) {
case MLRegisterModelInput.MODEL_FORMAT_FIELD:
Expand All @@ -108,19 +113,24 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
configBuilder.allConfig(configEntry.getValue().toString());
break;
case TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD:
configBuilder.embeddingDimension(((Double)configEntry.getValue()).intValue());
configBuilder.embeddingDimension(((Double) configEntry.getValue()).intValue());
break;
case TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD:
configBuilder.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
configBuilder
.frameworkType(TextEmbeddingModelConfig.FrameworkType.from(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.POOLING_MODE_FIELD:
configBuilder.poolingMode(TextEmbeddingModelConfig.PoolingMode.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT)));
configBuilder
.poolingMode(
TextEmbeddingModelConfig.PoolingMode
.from(configEntry.getValue().toString().toUpperCase(Locale.ROOT))
);
break;
case TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD:
configBuilder.normalizeResult(Boolean.parseBoolean(configEntry.getValue().toString()));
break;
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double)configEntry.getValue()).intValue());
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
default:
break;
Expand Down Expand Up @@ -149,11 +159,13 @@ public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List mode
String modelName = registerModelInput.getModelName();
String version = registerModelInput.getVersion();
MLModelFormat modelFormat = registerModelInput.getModelFormat();
for (Object meta: modelMetaList) {
String name = (String) ((Map<String, Object>)meta).get("name");
List<String> versions = (List) ((Map<String, Object>)meta).get("version");
List<String> formats = (List) ((Map<String, Object>)meta).get("format");
if (name.equals(modelName) && versions.contains(version.toLowerCase(Locale.ROOT)) && formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
for (Object meta : modelMetaList) {
String name = (String) ((Map<String, Object>) meta).get("name");
List<String> versions = (List) ((Map<String, Object>) meta).get("version");
List<String> formats = (List) ((Map<String, Object>) meta).get("format");
if (name.equals(modelName)
&& versions.contains(version.toLowerCase(Locale.ROOT))
&& formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
return true;
}
}
Expand Down Expand Up @@ -193,11 +205,20 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
* @param modelContentHash model content hash value
* @param listener action listener
*/
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(
MLModelFormat modelFormat,
String taskId,
String modelName,
String version,
String url,
String modelContentHash,
FunctionName functionName,
ActionListener<Map<String, Object>> listener
) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
String modelPath = registerModelPath +".zip";
String modelPath = registerModelPath + ".zip";
Path modelPartsPath = registerModelPath.resolve("chunks");
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
Expand All @@ -224,7 +245,8 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
}
}

public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName)
throws IOException {
boolean hasPtFile = false;
boolean hasOnnxFile = false;
boolean hasTokenizerFile = false;
Expand All @@ -249,7 +271,13 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
}
}

private static boolean hasModelFile(MLModelFormat modelFormat, MLModelFormat targetModelFormat, String fileExtension, boolean hasModelFile, String fileName) {
private static boolean hasModelFile(
MLModelFormat modelFormat,
MLModelFormat targetModelFormat,
String fileExtension,
boolean hasModelFile,
String fileName
) {
if (fileName.endsWith(fileExtension)) {
if (modelFormat != targetModelFormat) {
throw new IllegalArgumentException("Model format is " + modelFormat + ", but find " + fileExtension + " file");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

package org.opensearch.ml.engine;

import java.util.Map;

import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.engine.encryptor.Encryptor;

import java.util.Map;

/**
* This is machine learning algorithms predict interface.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;


/**
* This is machine learning algorithms train and predict interface.
*/
Expand Down
Loading

0 comments on commit dfa7b46

Please sign in to comment.