Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor initialize method #675

Merged
merged 1 commit into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,12 @@ public NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
encoder.initialize(manager, dataType, inputShapes[0]);
return decoder.initialize(manager, dataType, inputShapes[1]);
decoder.initialize(manager, dataType, inputShapes[1]);
}

/** {@inheritDoc} */
Expand Down
115 changes: 32 additions & 83 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;

/**
Expand All @@ -44,8 +43,7 @@
* <ul>
* <li>Define a version for serializing parameter and metadata and pass it to the parent
* constructor
* <li>Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link
* AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the
* <li>Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
* <li>Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
* <li>Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
Expand All @@ -62,9 +60,9 @@
* </ul>
*
* <p>If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take
* care of parameter initialization yourself. In this case, you need to override {@link
* AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If
* you use the other variants of {@code addParameter} this is done for you.
* care of parameter initialization yourself. In this case, call {@code setShape} on each parameter
* if its shape is known up front, or override {@code prepare} to call {@code setShape} once the
* input shapes are known.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children and parameters are always iterated over in insertion order.
Expand Down Expand Up @@ -100,14 +98,6 @@ public abstract class AbstractBlock implements Block {
*/
protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

/**
* Callbacks to determine the shape of a parameter. Values may be null in which case extending
* classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
* parameter shape resolution manually.
*/
protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
new LinkedHashMap<>();

/**
* Builds an empty block with the given version for parameter serialization.
*
Expand Down Expand Up @@ -153,73 +143,20 @@ protected final <B extends Block> B addChildBlock(String name, B block) {
return block;
}

/**
* Adds a parameter to this block. If parameters are added with this method, subclasses need to
* override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
* themselves.
*
* @param parameter the parameter to add, not null
* @param <P> the specific parameter subclass
* @return the parameter passed as an argument, to make it easier to create and assign parameters
* in one line
*/
protected final <P extends Parameter> P addParameter(P parameter) {
return addParameter(parameter, (Function<Shape[], Shape>) null);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box.
*
* @param parameter the parameter to add, not null
* @param shape the shape of the parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as an argument, to make it easier to create and assign parameters
* in one line
*/
protected final <P extends Parameter> P addParameter(P parameter, Shape shape) {
return addParameter(parameter, (inputShapes) -> shape);
}

/**
* Adds a parameter to this block. If parameters are added with this method, initialization of
* the parameter works out of the box.
*
* @param parameter the parameter to add, not null
* @param shapeCallback the method to call once the input shape of this block is known to
* determine the shape of the given parameter
* @param <P> the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
protected final <P extends Parameter> P addParameter(
P parameter, Function<Shape[], Shape> shapeCallback) {
protected final <P extends Parameter> P addParameter(P parameter) {
parameters.put(parameter.getName(), parameter);
parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
Function<Shape[], Shape> callback = parameterShapeCallbacks.get(name);
if (callback == null) {
Parameter parameter = parameters.get(name);
if (parameter == null) {
throw new IllegalArgumentException(
"No parameter named " + name + " found in this block.");
} else {
throw new IllegalStateException(
"No shape initializer for parameter "
+ name
+ "found. "
+ "Either pass an initializer for the shape when adding the "
+ "parameter or override getParameterShape in the subclass.");
}
}
return callback.apply(inputShapes);
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
Expand Down Expand Up @@ -271,13 +208,34 @@ public void setInitializer(Initializer initializer, Predicate<Parameter> predica

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
// if parameters are initialized, skip it
if (!isInitialized()) {
// setShape for all params
prepare(inputShapes);
}
for (Parameter parameter : parameters.values()) {
parameter.initialize(manager, dataType, inputShapes);
parameter.initialize(manager, dataType);
}
initializeChildBlocks(manager, dataType, inputShapes);
return getOutputShapes(manager, inputShapes);
}

/**
 * Hook invoked ahead of initialization, for example to keep the input information or verify the
 * layout.
 *
 * <p>If no input names exist yet, one default name per input shape is generated ({@code "data"}
 * followed by the index of the shape). The given shapes are then stored on this block.
 *
 * @param inputShapes the expected shapes of the input
 */
protected void beforeInitialize(Shape... inputShapes) {
    boolean needsDefaultNames = inputNames.isEmpty();
    if (needsDefaultNames) {
        // no names registered so far: generate data0, data1, ...
        inputNames = new ArrayList<>();
        int index = 0;
        while (index < inputShapes.length) {
            inputNames.add("data" + index);
            index++;
        }
    }
    this.inputShapes = inputShapes;
}

/**
Expand Down Expand Up @@ -320,20 +278,11 @@ public ParameterList getDirectParameters() {
}

/**
* Performs any action necessary before initialization.
* Sets the shape of {@link Parameter}s.
*
* @param inputShapes the expected shapes of the input
* @param inputShapes the shapes of inputs
*/
protected void beforeInitialize(Shape[] inputShapes) {
if (inputNames.isEmpty()) {
// automatically assign input names
inputNames = new ArrayList<>();
for (int i = 0; i < inputShapes.length; ++i) {
inputNames.add("data" + i);
}
}
this.inputShapes = inputShapes;
}
protected void prepare(Shape[] inputShapes) {}

/** {@inheritDoc} */
@Override
Expand Down
14 changes: 1 addition & 13 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ default NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
* @return the shapes of the outputs of the block
*/
Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes);
void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

/**
* Returns a boolean whether the block is initialized.
Expand Down Expand Up @@ -238,17 +237,6 @@ default NDList forward(
*/
ParameterList getParameters();

/**
* Returns the shape of the specified direct parameter of this block given the shapes of the
* input to the block.
*
* @param name the name of the parameter
* @param inputShapes the shapes of the input to the block
* @return the shape of the parameter specified
* @throws IllegalArgumentException if the parameter name specified is invalid
*/
Shape getParameterShape(String name, Shape[] inputShapes);

/**
* Returns the expected output shapes of the block for the specified input shapes.
*
Expand Down
47 changes: 33 additions & 14 deletions api/src/main/java/ai/djl/nn/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/**
Expand All @@ -39,7 +40,7 @@ public class Parameter implements AutoCloseable {

private String id;
private String name;
private Block block;
private Shape shape;
private Type type;
private Initializer initializer;
private NDArray array;
Expand All @@ -49,7 +50,7 @@ public class Parameter implements AutoCloseable {
Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = builder.name;
this.block = builder.block;
this.shape = builder.shape;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
Expand Down Expand Up @@ -91,10 +92,26 @@ public Type getType() {
* @param array the {@link NDArray} that contains values of this {@code Parameter}
*/
public void setArray(NDArray array) {
    // The guard tests the shape, not the array: a shape is present once setShape, the builder,
    // load, or an earlier setArray call has supplied one. The message therefore names the shape.
    if (shape != null) {
        throw new IllegalStateException("shape has been set! Use either setArray or setShape");
    }
    this.array = array;
    // The parameter shape is taken from the array itself.
    shape = array.getShape();
    array.setName(name);
}

/**
 * Assigns the shape this {@code Parameter} should have.
 *
 * @param shape the shape of this {@code Parameter}
 * @throws IllegalStateException if an array is already present on this {@code Parameter}
 */
public void setShape(Shape shape) {
    if (array == null) {
        this.shape = shape;
        return;
    }
    throw new IllegalStateException("array has been set! Use either setArray or setShape");
}

/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
Expand Down Expand Up @@ -141,11 +158,11 @@ public void setInitializer(Initializer initializer) {
*
* @param manager an NDManager to create the arrays
* @param dataType the datatype of the {@code Parameter}
* @param inputShapes the expected input shapes
*/
public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
public void initialize(NDManager manager, DataType dataType) {
Objects.requireNonNull(initializer, "No initializer has been set");
Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
Shape shape = block.getParameterShape(name, inputShapes);
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
Expand Down Expand Up @@ -207,6 +224,8 @@ public void load(NDManager manager, DataInputStream dis)
}

array = manager.decode(dis);
// set the shape of the parameter and prepare() can be skipped
shape = array.getShape();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -261,7 +280,7 @@ public Initializer getInitializer() {
/** A Builder to construct a {@code Parameter}. */
public static final class Builder {
String name;
Block block;
Shape shape;
Type type;
Initializer initializer;
NDArray array;
Expand All @@ -280,24 +299,24 @@ public Builder setName(String name) {
}

/**
* Sets the block of the {@code Parameter}.
* Sets the {@code Type} of the {@code Parameter}.
*
* @param block the block of the {@code Parameter}
* @param type the {@code Type} of the {@code Parameter}
* @return this {@code Builder}
*/
public Builder setBlock(Block block) {
this.block = block;
public Builder setType(Type type) {
this.type = type;
return this;
}

/**
* Sets the {@code Type} of the {@code Parameter}.
* Sets the shape of the {@code Parameter}.
*
* @param type the {@code Type} of the {@code Parameter}
* @param shape the shape of the {@code Parameter}
* @return this {@code Builder}
*/
public Builder setType(Type type) {
this.type = type;
public Builder optShape(Shape shape) {
this.shape = shape;
return this;
}

Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/nn/SequentialBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ protected NDList forwardInternal(
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape[] shapes = inputShapes;
for (Block child : getChildren().values()) {
shapes = child.initialize(manager, dataType, shapes);
child.initialize(manager, dataType, shapes);
shapes = child.getOutputShapes(manager, shapes);
}
}

Expand Down
Loading