Skip to content

Commit

Permalink
Refactor initialize (deepjavalibrary#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Feb 26, 2021
1 parent 280693a commit 9fc3f65
Show file tree
Hide file tree
Showing 20 changed files with 198 additions and 223 deletions.
5 changes: 2 additions & 3 deletions api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,12 @@ public NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
encoder.initialize(manager, dataType, inputShapes[0]);
return decoder.initialize(manager, dataType, inputShapes[1]);
decoder.initialize(manager, dataType, inputShapes[1]);
}

/** {@inheritDoc} */
Expand Down
115 changes: 32 additions & 83 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;

/**
Expand All @@ -44,8 +43,7 @@
* <ul>
* <li>Define a version for serializing parameter and metadata and pass it to the parent
* constructor
* <li>Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link
* AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the
* <li>Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
* <li>Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
* <li>Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
Expand All @@ -62,9 +60,9 @@
* </ul>
*
* <p>If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take
* care of parameter initialization yourself. In this case, you need to override {@link
* AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If
* you use the other variants of {@code addParameter} this is done for you.
* care of parameter initialization yourself. In this case, either call {@code setShape} on your
* parameters if their shapes are known up front, or override {@code prepare} to set the shapes
* once the input shapes are known.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children and parameters are always iterated over in insertion order.
Expand Down Expand Up @@ -100,14 +98,6 @@ public abstract class AbstractBlock implements Block {
*/
protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

/**
* Callbacks to determine the shape of a parameter. Values may be null in which case extending
* classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
* parameter shape resolution manually.
*/
protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
new LinkedHashMap<>();

/**
* Builds an empty block with the given version for parameter serialization.
*
Expand Down Expand Up @@ -153,73 +143,20 @@ protected final <B extends Block> B addChildBlock(String name, B block) {
return block;
}

/**
* Adds a parameter to this block. If parameters are added with this method, subclasses need to
* override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
* themselves.
*
* @param parameter the parameter to add, not null
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters in
* one line
*/
protected final <P extends Parameter> P addParameter(P parameter) {
return addParameter(parameter, (Function<Shape[], Shape>) null);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box.
*
* @param parameter the parameter to add, not null
* @param shape the shape of the parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters in
* one line
*/
protected final <P extends Parameter> P addParameter(P parameter, Shape shape) {
return addParameter(parameter, (inputShapes) -> shape);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box.
*
* @param parameter the parameter to add, not null
* @param shapeCallback the method to call once the input shape of this block is known to
* determine the shape of the given parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
protected final <P extends Parameter> P addParameter(
P parameter, Function<Shape[], Shape> shapeCallback) {
protected final <P extends Parameter> P addParameter(P parameter) {
parameters.put(parameter.getName(), parameter);
parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name);
if (callback == null) {
Parameter parameter = parameters.get(name);
if (parameter == null) {
throw new IllegalArgumentException(
"No parameter named " + name + " found in this block.");
} else {
throw new IllegalStateException(
"No shape initializer for parameter "
+ name
+ "found. "
+ "Either pass an initializer for the shape when adding the "
+ "parameter or override getParameterShape in the subclass.");
}
}
return callback.apply(inputShapes);
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
Expand Down Expand Up @@ -271,13 +208,34 @@ public void setInitializer(Initializer initializer, Predicate<Parameter> predica

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
// if parameters are initialized, skip it
if (!isInitialized()) {
// setShape for all params
prepare(inputShapes);
}
for (Parameter parameter : parameters.values()) {
parameter.initialize(manager, dataType, inputShapes);
parameter.initialize(manager, dataType);
}
initializeChildBlocks(manager, dataType, inputShapes);
return getOutputShapes(manager, inputShapes);
}

/**
* Performs any action necessary before initialization. For example, keep the input information
* or verify the layout.
*
* <p>This implementation assigns default input names ({@code data0}, {@code data1}, ...), one
* per input shape, when no input names are present yet, and then stores the given input shapes
* on this block.
*
* @param inputShapes the expected shapes of the input
*/
protected void beforeInitialize(Shape... inputShapes) {
// only generate names when none exist, so existing input names are left untouched
if (inputNames.isEmpty()) {
// automatically assign input names
inputNames = new ArrayList<>();
for (int i = 0; i < inputShapes.length; ++i) {
// the generated name is "data" followed by the position of the input
inputNames.add("data" + i);
}
}
// remember the shapes this block is being initialized with
this.inputShapes = inputShapes;
}

/**
Expand Down Expand Up @@ -320,20 +278,11 @@ public ParameterList getDirectParameters() {
}

/**
* Performs any action necessary before initialization.
* Sets the shape of {@link Parameter}s.
*
* @param inputShapes the expected shapes of the input
* @param inputShapes the shapes of inputs
*/
protected void beforeInitialize(Shape[] inputShapes) {
if (inputNames.isEmpty()) {
// automatically assign input names
inputNames = new ArrayList<>();
for (int i = 0; i < inputShapes.length; ++i) {
inputNames.add("data" + i);
}
}
this.inputShapes = inputShapes;
}
protected void prepare(Shape[] inputShapes) {}

/** {@inheritDoc} */
@Override
Expand Down
14 changes: 1 addition & 13 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ default NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes);
void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

/**
* Returns a boolean whether the block is initialized.
Expand Down Expand Up @@ -242,17 +241,6 @@ default NDList forward(
*/
ParameterList getParameters();

/**
* Returns the shape of the specified direct parameter of this block given the shapes of the
* input to the block.
*
* @param name the name of the parameter
* @param inputShapes the shapes of the input to the block
* @return the shape of the parameter specified
* @throws IllegalArgumentException if the parameter name specified is invalid
*/
Shape getParameterShape(String name, Shape[] inputShapes);

/**
* Returns the expected output shapes of the block for the specified input shapes.
*
Expand Down
47 changes: 33 additions & 14 deletions api/src/main/java/ai/djl/nn/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
Expand All @@ -42,7 +43,7 @@ public class Parameter implements AutoCloseable {

private String id;
private String name;
private Block block;
private Shape shape;
private Type type;
private Initializer initializer;
private NDArray array;
Expand All @@ -52,7 +53,7 @@ public class Parameter implements AutoCloseable {
Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = builder.name;
this.block = builder.block;
this.shape = builder.shape;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
Expand Down Expand Up @@ -94,10 +95,26 @@ public Type getType() {
* @param array the {@link NDArray} that contains values of this {@code Parameter}
*/
public void setArray(NDArray array) {
// shape is assigned by setShape as well as by this method, so this guard rejects the call
// whenever a shape is already recorded for this Parameter
// NOTE(review): the message says "array has been set" although the guard tests shape, which
// setShape also assigns -- confirm the wording is intended
if (shape != null) {
throw new IllegalStateException("array has been set! Use either setArray or setShape");
}
this.array = array;
// keep the recorded shape in sync with the array that was just attached
shape = array.getShape();
// give the array the name of this Parameter
array.setName(name);
}

/**
* Sets the shape of this {@code Parameter}.
*
* <p>The shape can only be set while no array is attached to this {@code Parameter}.
*
* @param shape the shape of this {@code Parameter}
* @throws IllegalStateException if an array has already been set on this {@code Parameter}
*/
public void setShape(Shape shape) {
// setArray records the shape of the array it attaches, so a separate shape is refused once
// an array is present
if (array != null) {
throw new IllegalStateException("array has been set! Use either setArray or setShape");
}
this.shape = shape;
}

/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
Expand Down Expand Up @@ -144,11 +161,11 @@ public void setInitializer(Initializer initializer) {
*
* @param manager an NDManager to create the arrays
* @param dataType the datatype of the {@code Parameter}
* @param inputShapes the expected input shapes
*/
public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
public void initialize(NDManager manager, DataType dataType) {
Objects.requireNonNull(initializer, "No initializer has been set");
Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
Shape shape = block.getParameterShape(name, inputShapes);
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
Expand Down Expand Up @@ -210,6 +227,8 @@ public void load(NDManager manager, DataInputStream dis)
}

array = manager.decode(dis);
// set the shape of the parameter so that prepare() can be skipped
shape = array.getShape();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -264,7 +283,7 @@ public Initializer getInitializer() {
/** A Builder to construct a {@code Parameter}. */
public static final class Builder {
String name;
Block block;
Shape shape;
Type type;
Initializer initializer;
NDArray array;
Expand All @@ -283,24 +302,24 @@ public Builder setName(String name) {
}

/**
* Sets the block of the {@code Parameter}.
* Sets the {@code Type} of the {@code Parameter}.
*
* @param block the block of the {@code Parameter}
* @param type the {@code Type} of the {@code Parameter}
* @return this {@code Parameter}
*/
public Builder setBlock(Block block) {
this.block = block;
public Builder setType(Type type) {
this.type = type;
return this;
}

/**
* Sets the {@code Type} of the {@code Parameter}.
* Sets the shape of the {@code Parameter}.
*
* @param type the {@code Type} of the {@code Parameter}
* @param shape the shape of the {@code Parameter}
* @return this {@code Parameter}
*/
public Builder setType(Type type) {
this.type = type;
public Builder optShape(Shape shape) {
this.shape = shape;
return this;
}

Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/nn/SequentialBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ protected NDList forwardInternal(
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape[] shapes = inputShapes;
for (Block child : getChildren().values()) {
shapes = child.initialize(manager, dataType, shapes);
child.initialize(manager, dataType, shapes);
shapes = child.getOutputShapes(manager, shapes);
}
}

Expand Down
Loading

0 comments on commit 9fc3f65

Please sign in to comment.