Skip to content

Commit

Permalink
add finding BlockFactory feature in model loading (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 authored Apr 11, 2021
1 parent 3c9619a commit ab491ce
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 6 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -51,6 +53,7 @@ public abstract class BaseModel implements Model {

private static final Logger logger = LoggerFactory.getLogger(BaseModel.class);
private static final int MODEL_VERSION = 1;

protected Path modelDir;
protected Block block;
protected String modelName;
Expand Down Expand Up @@ -214,6 +217,14 @@ protected void setModelDir(Path modelDir) {
this.modelDir = modelDir.toAbsolutePath();
}

protected Block loadFromBlockFactory() {
BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null);
if (factory == null) {
return null;
}
return factory.newBlock(manager);
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
146 changes: 146 additions & 0 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A utility class that load classes from specific URLs. */
public final class ClassLoaderUtils {

private static final Logger logger = LoggerFactory.getLogger(ClassLoaderUtils.class);

private ClassLoaderUtils() {}

/**
* scan classes files from a path to see if there is a matching implementation for a class.
*
* <p>For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
* @param className the name of the classes, pass null if name is unknown
* @param <T> the Template T for the output Class
* @return the Class implementation
*/
public static <T> T findImplementation(Path path, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
List<Path> jarFiles =
Files.list(path)
.filter(p -> p.toString().endsWith(".jar"))
.collect(Collectors.toList());
final URL[] urls = new URL[jarFiles.size() + 1];
urls[0] = classesDir.toUri().toURL();
int index = 1;
for (Path p : jarFiles) {
urls[index++] = p.toUri().toURL();
}

ClassLoader cl =
AccessController.doPrivileged(
(PrivilegedAction<ClassLoader>)
() ->
new URLClassLoader(
urls,
Thread.currentThread()
.getContextClassLoader()));
if (className != null && !className.isEmpty()) {
return initClass(cl, className);
}

T implemented = scanDirectory(cl, classesDir);
if (implemented != null) {
return implemented;
}

for (Path p : jarFiles) {
implemented = scanJarFile(cl, p);
if (implemented != null) {
return implemented;
}
}
} catch (IOException e) {
logger.debug("Failed to find Translator", e);
}
return null;
}

private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return null;
}
Collection<Path> files =
Files.walk(dir)
.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, className);
if (implemented != null) {
return implemented;
}
}
return null;
}

private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
JarEntry entry = en.nextElement();
String fileName = entry.getName();
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
T implemented = initClass(cl, fileName);
if (implemented != null) {
return implemented;
}
}
}
}
return null;
}

@SuppressWarnings("unchecked")
private static <T> T initClass(ClassLoader cl, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Constructor<T> constructor = (Constructor<T>) clazz.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,29 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.training.ParameterStore;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.testng.Assert;
import org.testng.annotations.Test;

public class BlockFactoryTest {
Expand Down Expand Up @@ -77,15 +83,62 @@ public void testBlockLoadingSaving()
}
}

static class TestBlockFactory implements BlockFactory {
@Test
public void testBlockFactoryLoadingFromZip()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
Path savedDir = Paths.get("build/testBlockFactory");
Utils.deleteQuietly(savedDir);
Path zipPath = prepareModel(savedDir);
// load model from here
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optModelPath(zipPath)
.optModelName("exported")
.build();
try (NDManager manager = NDManager.newBaseManager()) {
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> pred = model.newPredictor()) {
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
}
}
}

private Path prepareModel(Path savedDir)
throws IOException, ModelNotFoundException, MalformedModelException {
TestBlockFactory factory = new TestBlockFactory();
Model model = factory.getRemoveLastBlockModel();
try (NDManager manager = NDManager.newBaseManager()) {
Block block = model.getBlock();
block.forward(
new ParameterStore(manager, true),
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
true);
model.save(savedDir, "exported");
}
Path classDir = savedDir.resolve("classes/ai/djl/integration/tests/nn");
Files.createDirectories(classDir);
Files.copy(
Paths.get(
"build/classes/java/main/ai/djl/integration/tests/nn/BlockFactoryTest$TestBlockFactory.class"),
classDir.resolve("BlockFactoryTest$TestBlockFactory.class"));
Path zipPath =
Paths.get("build/testBlockFactory" + Engine.getInstance().getEngineName() + ".zip");
Files.deleteIfExists(zipPath);
ZipUtils.zip(savedDir, zipPath, false);
return zipPath;
}

public static class TestBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1234567L;

@Override
public Block newBlock(NDManager manager) {
SequentialBlock newBlock = new SequentialBlock();
newBlock.add(SymbolBlock.newInstance(manager));
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
return newBlock;
}
Expand All @@ -105,9 +158,7 @@ public Model getRemoveLastBlockModel()
Model model = ModelZoo.loadModel(builder.build());
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);
return model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
// load MxSymbolBlock
Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
if (prefix == null) {
prefix = modelName;
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
Path modelFile = findModelFile(prefix);
if (modelFile == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,15 @@ public PairList<String, Shape> describeOutput() {
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
try (NDManager manager = NDManager.newBaseManager()) {
NDList list = new NDList();
// TODO: Only tested for float32
for (Shape shape : inputShapes) {
list.add(manager.ones(shape));
}
NDList result = forwardInternal(new ParameterStore(manager, false), list, false, null);
return result.stream().map(NDArray::getShape).toArray(Shape[]::new);
}
}

/** {@inheritDoc} */
Expand Down

0 comments on commit ab491ce

Please sign in to comment.