From 68b80dad1617514ac511c51988092c4f591b6442 Mon Sep 17 00:00:00 2001 From: Lanking Date: Wed, 31 Mar 2021 15:11:01 -0700 Subject: [PATCH] add finding BlockFactory feature in model loading --- api/src/main/java/ai/djl/BaseModel.java | 9 ++ api/src/main/java/ai/djl/util/Utils.java | 111 ++++++++++++++++++ .../tests/nn/BlockFactoryTest.java | 63 +++++++++- .../java/ai/djl/mxnet/engine/MxModel.java | 4 + .../java/ai/djl/pytorch/engine/PtModel.java | 5 + 5 files changed, 191 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index d6fa016fc18..fc58ed0c12c 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.BlockFactory; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.training.Trainer; @@ -214,6 +215,14 @@ protected void setModelDir(Path modelDir) { this.modelDir = modelDir.toAbsolutePath(); } + protected Block loadFromBlockFactory() { + BlockFactory factory = Utils.findImplementation(modelDir, null); + if (factory == null) { + return null; + } + return factory.newBlock(manager); + } + /** {@inheritDoc} */ @Override public void save(Path modelPath, String newModelName) throws IOException { diff --git a/api/src/main/java/ai/djl/util/Utils.java b/api/src/main/java/ai/djl/util/Utils.java index 94c203dcff7..d1c943aba57 100644 --- a/api/src/main/java/ai/djl/util/Utils.java +++ b/api/src/main/java/ai/djl/util/Utils.java @@ -15,27 +15,38 @@ import ai.djl.ndarray.NDArray; import ai.djl.nn.Parameter; import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Constructor; +import java.net.URL; +import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardCopyOption; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.Enumeration; import java.util.List; import java.util.Objects; import java.util.Scanner; +import java.util.jar.JarEntry; +import java.util.jar.JarFile; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** A class containing utility methods. */ public final class Utils { + private static final Logger logger = LoggerFactory.getLogger(Utils.class); + private Utils() {} /** @@ -349,4 +360,104 @@ public static Path getCacheDir() { } return Paths.get(cacheDir); } + + /** + * scan classes files from a path to see if there is a matching implementation for a class. + * + *

For .class file, this function expects them in classes/your/package/ClassName.class + * + * @param path the path to scan from + * @param className the name of the classes, pass null if name is unknown + * @param the Template T for the output Class + * @return the Class implementation + */ + public static T findImplementation(Path path, String className) { + try { + Path classesDir = path.resolve("classes"); + // we only consider .class files and skip .java files + List jarFiles = + Files.list(path) + .filter(p -> p.toString().endsWith(".jar")) + .collect(Collectors.toList()); + List urls = new ArrayList<>(jarFiles.size() + 1); + urls.add(classesDir.toUri().toURL()); + for (Path p : jarFiles) { + urls.add(p.toUri().toURL()); + } + + ClassLoader parentCl = Thread.currentThread().getContextClassLoader(); + ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl); + if (className != null && !className.isEmpty()) { + return initClass(cl, className); + } + + T implemented = scanDirectory(cl, classesDir); + if (implemented != null) { + return implemented; + } + + for (Path p : jarFiles) { + implemented = scanJarFile(cl, p); + if (implemented != null) { + return implemented; + } + } + } catch (IOException e) { + logger.debug("Failed to find Translator", e); + } + return null; + } + + private static T scanDirectory(ClassLoader cl, Path dir) throws IOException { + if (!Files.isDirectory(dir)) { + logger.debug("Directory not exists: {}", dir); + return null; + } + Collection files = + Files.walk(dir) + .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = dir.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); + T implemented = initClass(cl, className); + if (implemented != null) { + return implemented; + } + } + return null; + } + + private static T scanJarFile(ClassLoader cl, Path path) throws IOException { + try (JarFile jarFile = new JarFile(path.toFile())) { + Enumeration en = jarFile.entries(); + while (en.hasMoreElements()) { + JarEntry entry = en.nextElement(); + String fileName = entry.getName(); + if (fileName.endsWith(".class")) { + fileName = fileName.substring(0, fileName.lastIndexOf('.')); + fileName = fileName.replace('/', '.'); + T implemented = initClass(cl, fileName); + if (implemented != null) { + return implemented; + } + } + } + } + return null; + } + + @SuppressWarnings("unchecked") + private static T initClass(ClassLoader cl, String className) { + try { + Class clazz = Class.forName(className, true, cl); + Constructor constructor = (Constructor) clazz.getConstructor(); + return constructor.newInstance(); + } catch (Throwable e) { + logger.trace("Not able to load Object", e); + } + return null; + } } diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java index f65e3052be0..00a12bf2f1a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockFactoryTest.java @@ -31,16 +31,25 @@ import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; import ai.djl.testing.Assertions; import ai.djl.training.ParameterStore; import ai.djl.training.util.ProgressBar; import ai.djl.translate.NoopTranslator; import ai.djl.translate.TranslateException; +import ai.djl.util.Utils; +import ai.djl.util.ZipUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.testng.Assert; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; public class BlockFactoryTest { @@ -77,7 +86,59 @@ public void testBlockLoadingSaving() } } - static class TestBlockFactory implements BlockFactory { + @Test + public void testBlockFactoryLoadingFromZip() + throws MalformedModelException, ModelNotFoundException, IOException, + TranslateException { + Path savedDir = Paths.get("build/testBlockFactory"); + Path zipPath = prepareModel(savedDir); + // load model from here + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(zipPath) + .optModelName("exported") + .build(); + try (NDManager manager = NDManager.newBaseManager()) { + try (ZooModel model = ModelZoo.loadModel(criteria); + Predictor pred = model.newPredictor()) { + NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32)))); + Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10)); + } + } + } + + private Path prepareModel(Path savedDir) + throws IOException, ModelNotFoundException, MalformedModelException { + TestBlockFactory factory = new TestBlockFactory(); + Model model = factory.getRemoveLastBlockModel(); + try (NDManager manager = NDManager.newBaseManager()) { + Block block = model.getBlock(); + block.forward( + new ParameterStore(manager, true), + new NDList(manager.ones(new Shape(1, 3, 32, 32))), + true); + model.save(savedDir, "exported"); + } + Path classDir = savedDir.resolve("classes/ai/djl/integration/tests/nn"); + Files.createDirectories(classDir); + Files.copy( + Paths.get( + "build/classes/java/main/ai/djl/integration/tests/nn/BlockFactoryTest$TestBlockFactory.class"), + classDir.resolve("BlockFactoryTest$TestBlockFactory.class")); + Path zipPath = Paths.get("build/testBlockFactory.zip"); + ZipUtils.zip(savedDir, zipPath, false); + return zipPath; + } + + @BeforeTest + @AfterTest + private void cleanUp() { + Utils.deleteQuietly(Paths.get("build/testBlockFactory")); + Utils.deleteQuietly(Paths.get("build/testBlockFactory.zip")); + } + + public static class TestBlockFactory implements BlockFactory { private static final long serialVersionUID = 1234567L; diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index 5c675a9519c..839bf8703aa 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -100,6 +100,10 @@ public void load(Path modelPath, String prefix, Map options) } } + if (block == null) { + block = loadFromBlockFactory(); + } + if (block == null) { // load MxSymbolBlock Path symbolFile = modelDir.resolve(prefix + "-symbol.json"); diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 20ee7430fb2..1d9fcf587c2 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -64,6 +64,11 @@ public void load(Path modelPath, String prefix, Map options) if (prefix == null) { prefix = modelName; } + + if (block == null) { + block = loadFromBlockFactory(); + } + if (block == null) { Path modelFile = findModelFile(prefix); if (modelFile == null) {