diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
index 8866b9de4cd..3cb5ef09bb3 100644
--- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
+++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
@@ -103,13 +103,12 @@ public NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
- * @return the shapes of the outputs of the block
*/
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
encoder.initialize(manager, dataType, inputShapes[0]);
- return decoder.initialize(manager, dataType, inputShapes[1]);
+ decoder.initialize(manager, dataType, inputShapes[1]);
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java
index 96dc366901e..83a063e45f1 100644
--- a/api/src/main/java/ai/djl/nn/AbstractBlock.java
+++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java
@@ -30,7 +30,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
-import java.util.function.Function;
import java.util.function.Predicate;
/**
@@ -44,8 +43,7 @@
*
* - Define a version for serializing parameter and metadata and pass it to the parent
* constructor
- *
- Use {@link AbstractBlock#addParameter(Parameter, Shape)} or {@link
- * AbstractBlock#addParameter(Parameter, Function)} to add parameters to your block in the
+ *
- Use {@link AbstractBlock#addParameter(Parameter)} to add parameters to your block in the
* constructor if necessary.
*
- Use {@link AbstractBlock#addChildBlock(String, Block)} to add child blocks if necessary.
*
- Override {@link AbstractBlock#getOutputShapes(NDManager, Shape[])} to determine the shape
@@ -62,9 +60,9 @@
*
*
* If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take
- * care of parameter initialization yourself. In this case, you need to override {@link
- * AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If
- * you use the other variants of {@code addParameter} this is done for you.
+ * care of parameter initialization yourself. In this case, either call {@code setShape} on the
+ * parameter when its shape is known up front, or override {@code prepare} to set it once the
+ * input shapes are available.
*/
// Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers
// of this API know the children and parameters are always iterated over in insertion order.
@@ -100,14 +98,6 @@ public abstract class AbstractBlock implements Block {
*/
protected LinkedHashMap parameters = new LinkedHashMap<>();
- /**
- * Callbacks to determine the shape of a parameter. Values may be null in which case extending
- * classes need to override {@link Block#getParameterShape(String, Shape[])} and implement
- * parameter shape resolution manually.
- */
- protected LinkedHashMap> parameterShapeCallbacks =
- new LinkedHashMap<>();
-
/**
* Builds an empty block with the given version for parameter serialization.
*
@@ -153,73 +143,20 @@ protected final B addChildBlock(String name, B block) {
return block;
}
- /**
- * Adds a parameter to this block. If parameters are added with this method, subclasses need to
- * override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters
- * themselves.
- *
- * @param parameter the parameter to add, not null
- * @param the specific parameter subclass
- * @return the parameter passed as arguments to make it easier to create and assign paramters in
- * one line
- */
- protected final
P addParameter(P parameter) {
- return addParameter(parameter, (Function) null);
- }
-
/**
* Adds a parameter to this block. If parameters are added with this method, intialization of
* the parameter works out of the box
*
- * @param parameter the parameter to add, not null
- * @param shape the shape of the parameter
* @param the specific parameter subclass
- * @return the parameter passed as arguments to make it easier to create and assign paramters in
- * one line
- */
- protected final
P addParameter(P parameter, Shape shape) {
- return addParameter(parameter, (inputShapes) -> shape);
- }
-
- /**
- * Adds a parameter to this block. If parameters are added with this method, intialization of
- * the parameter works out of the box
- *
* @param parameter the parameter to add, not null
- * @param shapeCallback the method to call once the input shape of this block is known to
- * determine the shape of the given parameter
- * @param
the specific parameter subclass
* @return the parameter passed as arguments to make it easier to create and assign parameters
* in one line
*/
- protected final
P addParameter(
- P parameter, Function shapeCallback) {
+ protected final P addParameter(P parameter) {
parameters.put(parameter.getName(), parameter);
- parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
return parameter;
}
- /** {@inheritDoc} */
- @Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
- Function callback = parameterShapeCallbacks.get(name);
- if (callback == null) {
- Parameter parameter = parameters.get(name);
- if (parameter == null) {
- throw new IllegalArgumentException(
- "No parameter named " + name + " found in this block.");
- } else {
- throw new IllegalStateException(
- "No shape initializer for parameter "
- + name
- + "found. "
- + "Either pass an initializer for the shape when adding the "
- + "parameter or override getParameterShape in the subclass.");
- }
- }
- return callback.apply(inputShapes);
- }
-
/** {@inheritDoc} */
@Override
public BlockList getChildren() {
@@ -271,13 +208,34 @@ public void setInitializer(Initializer initializer, Predicate predica
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
+ // if parameters are initialized, skip it
+ if (!isInitialized()) {
+ // setShape for all params
+ prepare(inputShapes);
+ }
for (Parameter parameter : parameters.values()) {
- parameter.initialize(manager, dataType, inputShapes);
+ parameter.initialize(manager, dataType);
}
initializeChildBlocks(manager, dataType, inputShapes);
- return getOutputShapes(manager, inputShapes);
+ }
+
+ /**
+ * Performs any action necessary before initialization. For example, keep the input information
+ * or verify the layout.
+ *
+ * @param inputShapes the expected shapes of the input
+ */
+ protected void beforeInitialize(Shape... inputShapes) {
+ if (inputNames.isEmpty()) {
+ // automatically assign input names
+ inputNames = new ArrayList<>();
+ for (int i = 0; i < inputShapes.length; ++i) {
+ inputNames.add("data" + i);
+ }
+ }
+ this.inputShapes = inputShapes;
}
/**
@@ -320,20 +278,11 @@ public ParameterList getDirectParameters() {
}
/**
- * Performs any action necessary before initialization.
+ * Sets the shape of {@link Parameter}s.
*
- * @param inputShapes the expected shapes of the input
+ * @param inputShapes the shapes of inputs
*/
- protected void beforeInitialize(Shape[] inputShapes) {
- if (inputNames.isEmpty()) {
- // automatically assign input names
- inputNames = new ArrayList<>();
- for (int i = 0; i < inputShapes.length; ++i) {
- inputNames.add("data" + i);
- }
- }
- this.inputShapes = inputShapes;
- }
+ protected void prepare(Shape[] inputShapes) {}
/** {@inheritDoc} */
@Override
diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java
index b001db2722b..a4f8ac351cd 100644
--- a/api/src/main/java/ai/djl/nn/Block.java
+++ b/api/src/main/java/ai/djl/nn/Block.java
@@ -185,9 +185,8 @@ default NDList forward(
* @param manager the NDManager to initialize the parameters
* @param dataType the datatype of the parameters
* @param inputShapes the shapes of the inputs to the block
- * @return the shapes of the outputs of the block
*/
- Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes);
+ void initialize(NDManager manager, DataType dataType, Shape... inputShapes);
/**
* Returns a boolean whether the block is initialized.
@@ -238,17 +237,6 @@ default NDList forward(
*/
ParameterList getParameters();
- /**
- * Returns the shape of the specified direct parameter of this block given the shapes of the
- * input to the block.
- *
- * @param name the name of the parameter
- * @param inputShapes the shapes of the input to the block
- * @return the shape of the parameter specified
- * @throws IllegalArgumentException if the parameter name specified is invalid
- */
- Shape getParameterShape(String name, Shape[] inputShapes);
-
/**
* Returns the expected output shapes of the block for the specified input shapes.
*
diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java
index 8f2293a8128..3d6270f3f89 100644
--- a/api/src/main/java/ai/djl/nn/Parameter.java
+++ b/api/src/main/java/ai/djl/nn/Parameter.java
@@ -23,6 +23,7 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
+import java.util.Objects;
import java.util.UUID;
/**
@@ -39,7 +40,7 @@ public class Parameter implements AutoCloseable {
private String id;
private String name;
- private Block block;
+ private Shape shape;
private Type type;
private Initializer initializer;
private NDArray array;
@@ -49,7 +50,7 @@ public class Parameter implements AutoCloseable {
Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = builder.name;
- this.block = builder.block;
+ this.shape = builder.shape;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
@@ -91,10 +92,26 @@ public Type getType() {
* @param array the {@link NDArray} that contains values of this {@code Parameter}
*/
public void setArray(NDArray array) {
+ if (shape != null) {
+ throw new IllegalStateException("shape has been set! Use either setArray or setShape");
+ }
this.array = array;
+ shape = array.getShape();
array.setName(name);
}
+ /**
+ * Sets the shape of this {@code Parameter}.
+ *
+ * @param shape the shape of this {@code Parameter}
+ */
+ public void setShape(Shape shape) {
+ if (array != null) {
+ throw new IllegalStateException("array has been set! Use either setArray or setShape");
+ }
+ this.shape = shape;
+ }
+
/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
@@ -141,11 +158,11 @@ public void setInitializer(Initializer initializer) {
*
* @param manager an NDManager to create the arrays
* @param dataType the datatype of the {@code Parameter}
- * @param inputShapes the expected input shapes
*/
- public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) {
+ public void initialize(NDManager manager, DataType dataType) {
+ Objects.requireNonNull(initializer, "No initializer has been set");
+ Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
- Shape shape = block.getParameterShape(name, inputShapes);
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
@@ -207,6 +224,8 @@ public void load(NDManager manager, DataInputStream dis)
}
array = manager.decode(dis);
+ // set the shape of the parameter and prepare() can be skipped
+ shape = array.getShape();
}
/** {@inheritDoc} */
@@ -261,7 +280,7 @@ public Initializer getInitializer() {
/** A Builder to construct a {@code Parameter}. */
public static final class Builder {
String name;
- Block block;
+ Shape shape;
Type type;
Initializer initializer;
NDArray array;
@@ -280,24 +299,24 @@ public Builder setName(String name) {
}
/**
- * Sets the block of the {@code Parameter}.
+ * Sets the {@code Type} of the {@code Parameter}.
*
- * @param block the block of the {@code Parameter}
+ * @param type the {@code Type} of the {@code Parameter}
* @return this {@code Parameter}
*/
- public Builder setBlock(Block block) {
- this.block = block;
+ public Builder setType(Type type) {
+ this.type = type;
return this;
}
/**
- * Sets the {@code Type} of the {@code Parameter}.
+ * Sets the shape of the {@code Parameter}.
*
- * @param type the {@code Type} of the {@code Parameter}
+ * @param shape the shape of the {@code Parameter}
* @return this {@code Parameter}
*/
- public Builder setType(Type type) {
- this.type = type;
+ public Builder optShape(Shape shape) {
+ this.shape = shape;
return this;
}
diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java
index f9578130d32..19b7486e978 100644
--- a/api/src/main/java/ai/djl/nn/SequentialBlock.java
+++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java
@@ -140,7 +140,8 @@ protected NDList forwardInternal(
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape[] shapes = inputShapes;
for (Block child : getChildren().values()) {
- shapes = child.initialize(manager, dataType, shapes);
+ child.initialize(manager, dataType, shapes);
+ shapes = child.getOutputShapes(manager, shapes);
}
}
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
index 48bae1ccbe6..fb065673623 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java
@@ -100,20 +100,15 @@ public Convolution(ConvolutionBuilder> builder) {
addParameter(
Parameter.builder()
.setName("weight")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- (inputShapes) ->
- new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
+ .build());
if (includeBias) {
bias =
addParameter(
Parameter.builder()
.setName("bias")
- .setBlock(this)
.setType(Parameter.Type.BIAS)
- .build(),
- new Shape(filters));
+ .build());
}
}
@@ -154,10 +149,19 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- protected void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(getExpectedLayout(), inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout());
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected void prepare(Shape[] inputs) {
+ long inputChannel = inputs[0].get(1);
+ weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape));
+ if (bias != null) {
+ bias.setShape(new Shape(filters));
+ }
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
index 30d348cebc9..c8ad86454e1 100644
--- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
+++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
@@ -79,20 +79,15 @@ public Deconvolution(DeconvolutionBuilder> builder) {
addParameter(
Parameter.builder()
.setName("weight")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- (inputShapes) ->
- new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
+ .build());
if (includeBias) {
bias =
addParameter(
Parameter.builder()
.setName("bias")
- .setBlock(this)
.setType(Parameter.Type.BIAS)
- .build(),
- new Shape(filters));
+ .build());
}
}
@@ -134,10 +129,19 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- protected void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(getExpectedLayout(), inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout());
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected void prepare(Shape[] inputs) {
+ long inputChannel = inputs[0].get(1);
+ weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape));
+ if (bias != null) {
+ bias.setShape(new Shape(filters));
+ }
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java
index c38682ffd4b..fd48f707cb3 100644
--- a/api/src/main/java/ai/djl/nn/core/Embedding.java
+++ b/api/src/main/java/ai/djl/nn/core/Embedding.java
@@ -61,13 +61,10 @@ protected Embedding(BaseBuilder baseBuilder) {
addParameter(
Parameter.builder()
.setName("embedding")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .optRequiresGrad(true)
.optGradientFormat(
sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE)
- .build(),
- (inputShapes) -> new Shape(numItems, embeddingSize));
+ .build());
if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
throw new IllegalArgumentException(
"You can not specify both a fallthrough and a defaultItem");
@@ -106,18 +103,22 @@ public Embedding(NDArray embedding, boolean sparseGrad) {
addParameter(
Parameter.builder()
.setName("embedding")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .optRequiresGrad(true)
.optGradientFormat(
sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE)
- .build(),
- (inputShapes) -> new Shape(numItems, embeddingSize));
+ .build());
this.embedding.setArray(embedding);
numItems = Math.toIntExact(embedding.getShape().size(0));
inputShapes = new Shape[] {new Shape(-1)};
}
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
+ // numItems will be adjusted by embedding array or fallthroughEmbedding
+ embedding.setShape(new Shape(numItems, embeddingSize));
+ }
+
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java
index 9041b624984..7054e64aade 100644
--- a/api/src/main/java/ai/djl/nn/core/Linear.java
+++ b/api/src/main/java/ai/djl/nn/core/Linear.java
@@ -58,25 +58,19 @@ public class Linear extends AbstractBlock {
Linear(Builder builder) {
super(VERSION);
units = builder.units;
- // "inputFeatures" is only known after "beforeInitialize" is called, hence we need
- // a callback, even if we do not used the callback parameter
weight =
addParameter(
Parameter.builder()
.setName("weight")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- inputShapes -> new Shape(units, inputFeatures));
+ .build());
if (builder.bias) {
bias =
addParameter(
Parameter.builder()
.setName("bias")
- .setBlock(this)
.setType(Parameter.Type.BIAS)
- .build(),
- new Shape(units));
+ .build());
}
}
@@ -97,7 +91,7 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
- return new Shape[] {inputShape.addAll(new Shape(units))};
+ return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)};
}
/** {@inheritDoc} */
@@ -109,13 +103,24 @@ public PairList describeInput() {
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputShapes) {
+ protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
+ Preconditions.checkArgument(inputShapes.length == 1, "Linear block only support 1 input");
Shape input = inputShapes[0];
inputFeatures = input.get(input.dimension() - 1);
inputShape = input.slice(0, input.dimension() - 1);
}
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
+ Shape input = inputShapes[0];
+ weight.setShape(new Shape(units, input.get(input.dimension() - 1)));
+ if (bias != null) {
+ bias.setShape(new Shape(units));
+ }
+ }
+
/** {@inheritDoc} */
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java
index 01437bc470f..77f380d328e 100644
--- a/api/src/main/java/ai/djl/nn/core/Prelu.java
+++ b/api/src/main/java/ai/djl/nn/core/Prelu.java
@@ -47,10 +47,9 @@ public Prelu() {
addParameter(
Parameter.builder()
.setName("alpha")
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- new Shape());
+ .optShape(new Shape())
+ .build());
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
index f8324bb209a..7e41e533fae 100644
--- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
+++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java
@@ -82,46 +82,37 @@ public class BatchNorm extends AbstractBlock {
momentum = builder.momentum;
center = builder.center;
scale = builder.scale;
- // When creating parameters we use a callback as "inChannels" is set before initialization,
- // it is not known yet.
+
// make gamma trainable if scale
gamma =
addParameter(
Parameter.builder()
.setName("gamma")
- .setBlock(this)
.setType(Parameter.Type.GAMMA)
.optRequiresGrad(scale)
- .build(),
- (inputShapes) -> new Shape(inChannels));
+ .build());
// make beta trainable if center
beta =
addParameter(
Parameter.builder()
.setName("beta")
- .setBlock(this)
.setType(Parameter.Type.BETA)
.optRequiresGrad(center)
- .build(),
- (inputShapes) -> new Shape(inChannels));
+ .build());
runningMean =
addParameter(
Parameter.builder()
.setName("runningMean")
- .setBlock(this)
.setType(Parameter.Type.RUNNING_MEAN)
.optRequiresGrad(false)
- .build(),
- (inputShapes) -> new Shape(inChannels));
+ .build());
runningVar =
addParameter(
Parameter.builder()
.setName("runningVar")
- .setBlock(this)
.setType(Parameter.Type.RUNNING_VAR)
.optRequiresGrad(false)
- .build(),
- (inputShapes) -> new Shape(inChannels));
+ .build());
}
/** {@inheritDoc} */
@@ -157,11 +148,20 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputShapes) {
+ protected void beforeInitialize(Shape... inputShapes) {
super.beforeInitialize(inputShapes);
inChannels = inputShapes[0].size(axis);
}
+ /** {@inheritDoc} */
+ @Override
+ public void prepare(Shape[] inputShapes) {
+ gamma.setShape(new Shape(inChannels));
+ beta.setShape(new Shape(inChannels));
+ runningMean.setShape(new Shape(inChannels));
+ runningVar.setShape(new Shape(inChannels));
+ }
+
/** {@inheritDoc} */
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
index b474c9708e0..75c338a45fc 100644
--- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
+++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
@@ -19,6 +19,8 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
+import ai.djl.nn.ParameterList;
+import ai.djl.util.Pair;
import java.io.DataInputStream;
import java.io.IOException;
@@ -80,11 +82,7 @@ public RecurrentBlock(BaseBuilder> builder) {
String name =
direction + '_' + i + '_' + gateString + '_' + parameterType.name();
addParameter(
- Parameter.builder()
- .setName(name)
- .setBlock(this)
- .setType(parameterType)
- .build());
+ Parameter.builder().setName(name).setType(parameterType).build());
}
}
}
@@ -113,31 +111,34 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
/** {@inheritDoc} */
@Override
- public void beforeInitialize(Shape[] inputs) {
- super.beforeInitialize(inputs);
- Shape inputShape = inputs[0];
- Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
+ protected void beforeInitialize(Shape... inputShapes) {
+ super.beforeInitialize(inputShapes);
+ Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout());
}
/** {@inheritDoc} */
@Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
- int layer = Integer.parseInt(name.split("_")[1]);
- Shape shape = inputShapes[0];
- long inputs = shape.get(2);
- if (layer > 0) {
- inputs = stateSize * getNumDirections();
- }
- if (name.contains("BIAS")) {
- return new Shape(gates * stateSize);
- }
- if (name.contains("i2h")) {
- return new Shape(gates * stateSize, inputs);
- }
- if (name.contains("h2h")) {
- return new Shape(gates * stateSize, stateSize);
+ public void prepare(Shape[] inputs) {
+ Shape inputShape = inputs[0];
+ ParameterList parameters = getDirectParameters();
+ for (Pair pair : parameters) {
+ String name = pair.getKey();
+ Parameter parameter = pair.getValue();
+ int layer = Integer.parseInt(name.split("_")[1]);
+ long inputSize = inputShape.get(2);
+ if (layer > 0) {
+ inputSize = stateSize * getNumDirections();
+ }
+ if (name.contains("BIAS")) {
+ parameter.setShape(new Shape(gates * stateSize));
+ } else if (name.contains("i2h")) {
+ parameter.setShape(new Shape(gates * stateSize, inputSize));
+ } else if (name.contains("h2h")) {
+ parameter.setShape(new Shape(gates * stateSize, stateSize));
+ } else {
+ throw new IllegalArgumentException("Invalid parameter name");
+ }
}
- throw new IllegalArgumentException("Invalid parameter name");
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
index 4deabda9964..b6b8fe1c8bc 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java
@@ -78,10 +78,10 @@ private BertBlock(Builder builder) {
addParameter(
Parameter.builder()
.setName(PARAM_POSITION_EMBEDDING)
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- new Shape(builder.maxSequenceLength, builder.embeddingSize));
+ .optShape(
+ new Shape(builder.maxSequenceLength, builder.embeddingSize))
+ .build());
// embedding for the input types
this.typeEmbedding =
addChildBlock(
@@ -165,11 +165,13 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
/** {@inheritDoc} */
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
- beforeInitialize(inputShapes);
+ super.beforeInitialize(inputShapes);
inputNames = Arrays.asList("tokenIds", "typeIds", "masks");
Shape[] tokenShape = {inputShapes[0]};
Shape[] typeShape = {inputShapes[1]};
- Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape);
+ this.tokenEmbedding.initialize(manager, dataType, tokenShape);
+ // output shapes of the token embedding itself, not of this whole block
+ Shape[] embeddingOutput = this.tokenEmbedding.getOutputShapes(manager, tokenShape);
this.typeEmbedding.initialize(manager, dataType, typeShape);
this.embeddingNorm.initialize(manager, dataType, embeddingOutput);
this.embeddingDropout.initialize(manager, dataType, embeddingOutput);
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
index 7d7c97e5288..38583ffb5ba 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java
@@ -60,10 +60,9 @@ public BertMaskedLanguageModelBlock(
addParameter(
Parameter.builder()
.setName("dictionaryBias")
- .setBlock(this)
.setType(Parameter.Type.BIAS)
- .build(),
- new Shape(bertBlock.getTokenDictionarySize()));
+ .optShape(new Shape(bertBlock.getTokenDictionarySize()))
+ .build());
this.hiddenActivation = hiddenActivation;
}
diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
index f31814a651f..9afd98ca07e 100644
--- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java
@@ -52,7 +52,9 @@ public BertPretrainingBlock(final BertBlock.Builder builder) {
public void initializeChildBlocks(
final NDManager manager, final DataType dataType, final Shape... inputShapes) {
inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices");
- Shape[] bertOutputShapes = bertBlock.initialize(manager, dataType, inputShapes);
+ bertBlock.initialize(manager, dataType, inputShapes);
+ // output shapes of the wrapped BERT block, not of this pretraining block
+ Shape[] bertOutputShapes = bertBlock.getOutputShapes(manager, inputShapes);
Shape embeddedSequence = bertOutputShapes[0];
Shape pooledOutput = bertOutputShapes[1];
Shape maskedIndices = inputShapes[2];
diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
index b01c359bdaa..a8ea57182cd 100644
--- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
+++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
@@ -48,10 +48,9 @@ private IdEmbedding(Builder builder) {
addParameter(
Parameter.builder()
.setName(EMBEDDING_PARAM_NAME)
- .setBlock(this)
.setType(Parameter.Type.WEIGHT)
- .build(),
- new Shape(dictionarySize, embeddingSize));
+ .optShape(new Shape(dictionarySize, embeddingSize))
+ .build());
}
/** {@inheritDoc} */
diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
index 7e2a2403b8e..562bf319d9c 100644
--- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
+++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java
@@ -82,7 +82,9 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
}
Shape lastShape = inputShape;
for (final Block child : children.values()) {
- lastShape = child.initialize(manager, dataType, lastShape)[0];
+ child.initialize(manager, dataType, lastShape);
+ // chain shapes child by child: each child's output feeds the next child
+ lastShape = child.getOutputShapes(manager, new Shape[] {lastShape})[0];
}
outputShape = lastShape;
}
diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
index 4d9486c70cd..966b96a390c 100644
--- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
+++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java
@@ -177,15 +177,22 @@ private Shape concatShape(Shape shape, Shape concat, int axis) {
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
beforeInitialize(inputShapes);
- Shape[] shapes = inputShapes;
- for (int i = 0; i < features.size(); i++) {
- shapes = features.get(i).initialize(manager, dataType, shapes);
- classPredictionBlocks.get(i).initialize(manager, dataType, shapes);
- anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes);
+ Shape outputShape;
+ try (NDManager tempManager = manager.newSubManager()) {
+ ParameterStore store = new ParameterStore(tempManager, false);
+ NDArray array = tempManager.create(inputShapes[0]);
+ for (int i = 0; i < features.size(); i++) {
+ features.get(i).initialize(manager, dataType, array.getShape());
+ // getOutputShapes does not give the right shape here,
+ // so run forward and read the actual output shape
+ array = features.get(i).forward(store, new NDList(array), false).singletonOrThrow();
+ outputShape = array.getShape();
+ classPredictionBlocks.get(i).initialize(manager, dataType, outputShape);
+ anchorPredictionBlocks.get(i).initialize(manager, dataType, outputShape);
+ }
}
- return getOutputShapes(manager, inputShapes);
}
/** {@inheritDoc} */
diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
index 12adaf04b5c..2ff9a71f9c8 100644
--- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
+++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java
@@ -83,7 +83,6 @@ public MxSymbolBlock(NDManager manager, Symbol symbol) {
mxNetParams.add(
Parameter.builder()
.setName(name)
- .setBlock(this)
.setType(type)
.optRequiresGrad(requireGrad)
.build());
@@ -237,9 +236,7 @@ public void removeLastBlock() {
}
}
- /** {@inheritDoc} */
- @Override
- public Shape getParameterShape(String name, Shape[] inputShapes) {
+ private Shape getParameterShape(String name, Shape[] inputShapes) {
if (paramShapes == null) {
PairList pairs = new PairList<>();
for (int i = 0; i < inputNames.size(); i++) {
diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
index 43f667d2dc2..5e4b8e0fdc0 100644
--- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
+++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java
@@ -131,8 +131,8 @@ protected NDList forwardInternal(
/** {@inheritDoc} */
@Override
- public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
- return new Shape[0];
+ public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
+ throw new IllegalStateException("TfSymbolBlock can't be initialized");
}
/** {@inheritDoc} */