Skip to content

Commit

Permalink
add finding BlockFactory feature in model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Lanking committed Apr 1, 2021
1 parent 10689d7 commit 68b80da
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 1 deletion.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
Expand Down Expand Up @@ -214,6 +215,14 @@ protected void setModelDir(Path modelDir) {
this.modelDir = modelDir.toAbsolutePath();
}

protected Block loadFromBlockFactory() {
BlockFactory factory = Utils.findImplementation(modelDir, null);
if (factory == null) {
return null;
}
return factory.newBlock(manager);
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
111 changes: 111 additions & 0 deletions api/src/main/java/ai/djl/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,38 @@
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Parameter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.List;
import java.util.Objects;
import java.util.Scanner;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A class containing utility methods. */
public final class Utils {

private static final Logger logger = LoggerFactory.getLogger(Utils.class);

private Utils() {}

/**
Expand Down Expand Up @@ -349,4 +360,104 @@ public static Path getCacheDir() {
}
return Paths.get(cacheDir);
}

/**
* scan classes files from a path to see if there is a matching implementation for a class.
*
* <p>For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
* @param className the name of the classes, pass null if name is unknown
* @param <T> the Template T for the output Class
* @return the Class implementation
*/
public static <T> T findImplementation(Path path, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
List<Path> jarFiles =
Files.list(path)
.filter(p -> p.toString().endsWith(".jar"))
.collect(Collectors.toList());
List<URL> urls = new ArrayList<>(jarFiles.size() + 1);
urls.add(classesDir.toUri().toURL());
for (Path p : jarFiles) {
urls.add(p.toUri().toURL());
}

ClassLoader parentCl = Thread.currentThread().getContextClassLoader();
ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl);
if (className != null && !className.isEmpty()) {
return initClass(cl, className);
}

T implemented = scanDirectory(cl, classesDir);
if (implemented != null) {
return implemented;
}

for (Path p : jarFiles) {
implemented = scanJarFile(cl, p);
if (implemented != null) {
return implemented;
}
}
} catch (IOException e) {
logger.debug("Failed to find Translator", e);
}
return null;
}

private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return null;
}
Collection<Path> files =
Files.walk(dir)
.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, className);
if (implemented != null) {
return implemented;
}
}
return null;
}

private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
JarEntry entry = en.nextElement();
String fileName = entry.getName();
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
T implemented = initClass(cl, fileName);
if (implemented != null) {
return implemented;
}
}
}
}
return null;
}

@SuppressWarnings("unchecked")
private static <T> T initClass(ClassLoader cl, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Constructor<T> constructor = (Constructor<T>) clazz.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.training.ParameterStore;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

public class BlockFactoryTest {
Expand Down Expand Up @@ -77,7 +86,59 @@ public void testBlockLoadingSaving()
}
}

static class TestBlockFactory implements BlockFactory {
@Test
public void testBlockFactoryLoadingFromZip()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
Path savedDir = Paths.get("build/testBlockFactory");
Path zipPath = prepareModel(savedDir);
// load model from here
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optModelPath(zipPath)
.optModelName("exported")
.build();
try (NDManager manager = NDManager.newBaseManager()) {
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> pred = model.newPredictor()) {
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
}
}
}

private Path prepareModel(Path savedDir)
throws IOException, ModelNotFoundException, MalformedModelException {
TestBlockFactory factory = new TestBlockFactory();
Model model = factory.getRemoveLastBlockModel();
try (NDManager manager = NDManager.newBaseManager()) {
Block block = model.getBlock();
block.forward(
new ParameterStore(manager, true),
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
true);
model.save(savedDir, "exported");
}
Path classDir = savedDir.resolve("classes/ai/djl/integration/tests/nn");
Files.createDirectories(classDir);
Files.copy(
Paths.get(
"build/classes/java/main/ai/djl/integration/tests/nn/BlockFactoryTest$TestBlockFactory.class"),
classDir.resolve("BlockFactoryTest$TestBlockFactory.class"));
Path zipPath = Paths.get("build/testBlockFactory.zip");
ZipUtils.zip(savedDir, zipPath, false);
return zipPath;
}

@BeforeTest
@AfterTest
private void cleanUp() {
Utils.deleteQuietly(Paths.get("build/testBlockFactory"));
Utils.deleteQuietly(Paths.get("build/testBlockFactory.zip"));
}

public static class TestBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1234567L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
// load MxSymbolBlock
Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
if (prefix == null) {
prefix = modelName;
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
Path modelFile = findModelFile(prefix);
if (modelFile == null) {
Expand Down

0 comments on commit 68b80da

Please sign in to comment.