Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add finding BlockFactory feature in model loading #805

Merged
merged 4 commits into from
Apr 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -51,6 +53,7 @@ public abstract class BaseModel implements Model {

private static final Logger logger = LoggerFactory.getLogger(BaseModel.class);
private static final int MODEL_VERSION = 1;

protected Path modelDir;
protected Block block;
protected String modelName;
Expand Down Expand Up @@ -214,6 +217,14 @@ protected void setModelDir(Path modelDir) {
this.modelDir = modelDir.toAbsolutePath();
}

protected Block loadFromBlockFactory() {
BlockFactory factory = ClassLoaderUtils.findImplementation(modelDir, null);
if (factory == null) {
return null;
}
return factory.newBlock(manager);
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
146 changes: 146 additions & 0 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A utility class that load classes from specific URLs. */
public final class ClassLoaderUtils {

private static final Logger logger = LoggerFactory.getLogger(ClassLoaderUtils.class);

private ClassLoaderUtils() {}

/**
* scan classes files from a path to see if there is a matching implementation for a class.
*
* <p>For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
* @param className the name of the classes, pass null if name is unknown
* @param <T> the Template T for the output Class
* @return the Class implementation
*/
public static <T> T findImplementation(Path path, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
List<Path> jarFiles =
Files.list(path)
.filter(p -> p.toString().endsWith(".jar"))
.collect(Collectors.toList());
final URL[] urls = new URL[jarFiles.size() + 1];
urls[0] = classesDir.toUri().toURL();
int index = 1;
for (Path p : jarFiles) {
urls[index++] = p.toUri().toURL();
}

ClassLoader cl =
AccessController.doPrivileged(
(PrivilegedAction<ClassLoader>)
() ->
new URLClassLoader(
urls,
Thread.currentThread()
.getContextClassLoader()));
if (className != null && !className.isEmpty()) {
return initClass(cl, className);
}

T implemented = scanDirectory(cl, classesDir);
if (implemented != null) {
return implemented;
}

for (Path p : jarFiles) {
implemented = scanJarFile(cl, p);
if (implemented != null) {
return implemented;
}
}
} catch (IOException e) {
logger.debug("Failed to find Translator", e);
}
return null;
}

private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return null;
}
Collection<Path> files =
Files.walk(dir)
.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, className);
if (implemented != null) {
return implemented;
}
}
return null;
}

private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
JarEntry entry = en.nextElement();
String fileName = entry.getName();
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
T implemented = initClass(cl, fileName);
if (implemented != null) {
return implemented;
}
}
}
}
return null;
}

@SuppressWarnings("unchecked")
private static <T> T initClass(ClassLoader cl, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Constructor<T> constructor = (Constructor<T>) clazz.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,29 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.training.ParameterStore;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.util.Utils;
import ai.djl.util.ZipUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.testng.Assert;
import org.testng.annotations.Test;

public class BlockFactoryTest {
Expand Down Expand Up @@ -77,15 +83,62 @@ public void testBlockLoadingSaving()
}
}

static class TestBlockFactory implements BlockFactory {
@Test
public void testBlockFactoryLoadingFromZip()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
Path savedDir = Paths.get("build/testBlockFactory");
Utils.deleteQuietly(savedDir);
Path zipPath = prepareModel(savedDir);
// load model from here
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optModelPath(zipPath)
.optModelName("exported")
.build();
try (NDManager manager = NDManager.newBaseManager()) {
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> pred = model.newPredictor()) {
NDList destOut = pred.predict(new NDList(manager.ones(new Shape(1, 3, 32, 32))));
Assert.assertEquals(destOut.singletonOrThrow().getShape(), new Shape(1, 10));
}
}
}

private Path prepareModel(Path savedDir)
throws IOException, ModelNotFoundException, MalformedModelException {
TestBlockFactory factory = new TestBlockFactory();
Model model = factory.getRemoveLastBlockModel();
try (NDManager manager = NDManager.newBaseManager()) {
Block block = model.getBlock();
block.forward(
new ParameterStore(manager, true),
new NDList(manager.ones(new Shape(1, 3, 32, 32))),
true);
model.save(savedDir, "exported");
}
Path classDir = savedDir.resolve("classes/ai/djl/integration/tests/nn");
Files.createDirectories(classDir);
Files.copy(
Paths.get(
"build/classes/java/main/ai/djl/integration/tests/nn/BlockFactoryTest$TestBlockFactory.class"),
classDir.resolve("BlockFactoryTest$TestBlockFactory.class"));
Path zipPath =
Paths.get("build/testBlockFactory" + Engine.getInstance().getEngineName() + ".zip");
Files.deleteIfExists(zipPath);
ZipUtils.zip(savedDir, zipPath, false);
return zipPath;
}

public static class TestBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1234567L;

@Override
public Block newBlock(NDManager manager) {
SequentialBlock newBlock = new SequentialBlock();
newBlock.add(SymbolBlock.newInstance(manager));
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
return newBlock;
}
Expand All @@ -105,9 +158,7 @@ public Model getRemoveLastBlockModel()
Model model = ModelZoo.loadModel(builder.build());
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);
return model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
// load MxSymbolBlock
Path symbolFile = modelDir.resolve(prefix + "-symbol.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
if (prefix == null) {
prefix = modelName;
}

if (block == null) {
block = loadFromBlockFactory();
}

if (block == null) {
Path modelFile = findModelFile(prefix);
if (modelFile == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,15 @@ public PairList<String, Shape> describeOutput() {
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
try (NDManager manager = NDManager.newBaseManager()) {
NDList list = new NDList();
// TODO: Only tested for float32
for (Shape shape : inputShapes) {
list.add(manager.ones(shape));
}
NDList result = forwardInternal(new ParameterStore(manager, false), list, false, null);
return result.stream().map(NDArray::getShape).toArray(Shape[]::new);
}
}

/** {@inheritDoc} */
Expand Down