diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index 712625af3..39d4a130c 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -12,7 +12,10 @@ */ package ai.djl.serving; +import ai.djl.repository.Artifact; import ai.djl.repository.FilenameUtils; +import ai.djl.repository.MRL; +import ai.djl.repository.Repository; import ai.djl.serving.models.ModelManager; import ai.djl.serving.models.WorkflowInfo; import ai.djl.serving.plugins.FolderScanPluginManager; @@ -31,7 +34,9 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.Slf4JLoggerFactory; +import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.net.MalformedURLException; import java.nio.file.Files; import java.nio.file.Path; @@ -39,6 +44,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; +import java.util.Properties; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -279,25 +286,8 @@ private void initModelStore() throws IOException { // Check folders to see if they can be models as well urls = Files.list(modelStore) - .filter( - p -> { - logger.info("Found file in model_store: {}", p); - try { - return !Files.isHidden(p) && Files.isDirectory(p) - || FilenameUtils.isArchiveFile(p.toString()); - } catch (IOException e) { - logger.warn("Failed to access file: " + p, e); - return false; - } - }) - .map( - p -> { - try { - return p.toUri().toURL().toString(); - } catch (MalformedURLException e) { - throw new AssertionError("Invalid path: " + p, e); - } - }) + .map(this::mapModelUrl) + .filter(Objects::nonNull) .collect(Collectors.toList()); } else { String[] modelsUrls = loadModels.split("[, ]+"); @@ -377,6 +367,99 @@ private void initModelStore() throws IOException { } } + String mapModelUrl(Path path) { + try { + logger.info("Found file in model_store: {}", path); + if (Files.isHidden(path) + || (!Files.isDirectory(path) + && !FilenameUtils.isArchiveFile(path.toString()))) { + return null; + } + + File[] files = path.toFile().listFiles(); + if (files != null && files.length == 1 && files[0].isDirectory()) { + // handle archive file contains folder name case + path = files[0].toPath().toAbsolutePath(); + } + + String url = path.toUri().toURL().toString(); + String modelName = ModelInfo.inferModelNameFromUrl(url); + String engine; + if (Files.isDirectory(path)) { + engine = inferEngine(path); + } else { + try { + Repository repository = Repository.newInstance("modelStore", url); + List mrls = repository.getResources(); + Artifact artifact = mrls.get(0).getDefaultArtifact(); + repository.prepare(artifact); + Path modelDir = repository.getResourceDirectory(artifact); + engine = inferEngine(modelDir); + } catch (IOException e) { + logger.warn("Failed to extract model: " + path, e); + return null; + } + } + if (engine == null) { + return null; + } + return modelName + "::" + engine + ":*=" + url; + } catch (MalformedURLException e) { + throw new AssertionError("Invalid path: " + path, e); + } catch (IOException e) { + logger.warn("Failed to access file: " + path, e); + return null; + } + } + + private String inferEngine(Path modelDir) { + Path file = modelDir.resolve("serving.properties"); + if (Files.isRegularFile(file)) { + Properties prop = new Properties(); + try (InputStream is = Files.newInputStream(file)) { + prop.load(is); + String engine = prop.getProperty("engine"); + if (engine != null) { + return engine; + } + } catch (IOException e) { + logger.warn("Failed read serving.properties file", e); + } + } + + String dirName = modelDir.toFile().getName(); + if (Files.isDirectory(modelDir.resolve("MAR-INF")) + || Files.isRegularFile(modelDir.resolve("model.py")) + || Files.isRegularFile(modelDir.resolve(dirName + ".py"))) { + // MMS/TorchServe + return "Python"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".pt"))) { + return "PyTorch"; + } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) { + return "TensorFlow"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + "-symbol.json"))) { + return "MXNet"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".onnx"))) { + return "OnnxRuntime"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".trt")) + || Files.isRegularFile(modelDir.resolve(dirName + ".uff"))) { + return "TensorRT"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".tflite"))) { + return "TFLite"; + } else if (Files.isRegularFile(modelDir.resolve("model")) + || Files.isRegularFile(modelDir.resolve("__model__")) + || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) { + return "PaddlePaddle"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".json"))) { + return "XGBoost"; + } else if (Files.isRegularFile(modelDir.resolve(dirName + ".dylib")) + || Files.isRegularFile(modelDir.resolve(dirName + ".so"))) { + return "DLR"; + } + logger.warn("Failed to detect engine of the model: " + modelDir); + return null; + } + private static void printHelp(String msg, Options options) { HelpFormatter formatter = new HelpFormatter(); formatter.setLeftPadding(1); diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 963b45ede..a4884922c 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -21,6 +21,7 @@ import ai.djl.serving.util.Connector; import ai.djl.util.JsonUtils; import ai.djl.util.Utils; +import ai.djl.util.ZipUtils; import ai.djl.util.cuda.CudaUtils; import com.google.gson.reflect.TypeToken; import io.netty.bootstrap.Bootstrap; @@ -55,6 +56,7 @@ import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.Slf4JLoggerFactory; +import java.io.BufferedWriter; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; @@ -62,6 +64,9 @@ import java.net.URL; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.security.GeneralSecurityException; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -110,6 +115,9 @@ public void beforeSuite() try (InputStream is = url.openStream()) { testImage = Utils.toByteArray(is); } + Path modelStore = Paths.get("build/models"); + Utils.deleteQuietly(modelStore); + Files.createDirectories(modelStore); String[] args = {"-f", "src/test/resources/config.properties"}; Arguments arguments = ConfigManagerTest.parseArguments(args); @@ -129,6 +137,83 @@ public void afterSuite() { server.stop(); } + @Test + public void testModelStore() throws IOException { + Path modelStore = Paths.get("build/models"); + Path modelDir = modelStore.resolve("test_model"); + Files.createDirectories(modelDir); + Path notModel = modelStore.resolve("non-model"); + Files.createFile(notModel); + + String url = server.mapModelUrl(notModel); // not a model dir + Assert.assertNull(url); + + url = server.mapModelUrl(modelDir); // empty folder + Assert.assertNull(url); + + String expected = modelDir.toUri().toURL().toString(); + + Path dlr = modelDir.resolve("test_model.so"); + Files.createFile(dlr); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::DLR:*=" + expected); + + Path xgb = modelDir.resolve("test_model.json"); + Files.createFile(xgb); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::XGBoost:*=" + expected); + + Path paddle = modelDir.resolve("__model__"); + Files.createFile(paddle); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::PaddlePaddle:*=" + expected); + + Path tflite = modelDir.resolve("test_model.tflite"); + Files.createFile(tflite); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::TFLite:*=" + expected); + + Path tensorRt = modelDir.resolve("test_model.uff"); + Files.createFile(tensorRt); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::TensorRT:*=" + expected); + + Path onnx = modelDir.resolve("test_model.onnx"); + Files.createFile(onnx); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::OnnxRuntime:*=" + expected); + + Path mxnet = modelDir.resolve("test_model-symbol.json"); + Files.createFile(mxnet); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::MXNet:*=" + expected); + + Path tensorflow = modelDir.resolve("saved_model.pb"); + Files.createFile(tensorflow); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::TensorFlow:*=" + expected); + + Path pytorch = modelDir.resolve("test_model.pt"); + Files.createFile(pytorch); + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::PyTorch:*=" + expected); + + Path prop = modelDir.resolve("serving.properties"); + try (BufferedWriter writer = Files.newBufferedWriter(prop)) { + writer.write("engine=MyEngine"); + } + url = server.mapModelUrl(modelDir); + Assert.assertEquals(url, "test_model::MyEngine:*=" + expected); + + Path mar = modelStore.resolve("torchServe.mar"); + Path torchServe = modelStore.resolve("torchServe"); + Files.createDirectories(torchServe.resolve("MAR-INF")); + ZipUtils.zip(torchServe, mar, false); + + url = server.mapModelUrl(mar); + Assert.assertEquals(url, "torchServe::Python:*=" + mar.toUri().toURL()); + } + @Test public void test() throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, diff --git a/serving/src/test/resources/config.properties b/serving/src/test/resources/config.properties index 9151e2f22..14d64bfc6 100644 --- a/serving/src/test/resources/config.properties +++ b/serving/src/test/resources/config.properties @@ -2,7 +2,7 @@ inference_address=https://127.0.0.1:8443 management_address=https://127.0.0.1:8443 # management_address=unix:/tmp/management.sock -# model_store=models +model_store=build/models load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz # model_url_pattern=.* # number_of_netty_threads=0