From 9fc3f65e3e1f4ef6c07dec91445515da41966319 Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Wed, 24 Feb 2021 13:13:39 -0800 Subject: [PATCH] Refactor initialize (#675) --- .../ai/djl/modality/nlp/EncoderDecoder.java | 5 +- .../main/java/ai/djl/nn/AbstractBlock.java | 115 +++++------------- api/src/main/java/ai/djl/nn/Block.java | 14 +-- api/src/main/java/ai/djl/nn/Parameter.java | 47 ++++--- .../main/java/ai/djl/nn/SequentialBlock.java | 3 +- .../ai/djl/nn/convolutional/Convolution.java | 26 ++-- .../djl/nn/convolutional/Deconvolution.java | 26 ++-- .../main/java/ai/djl/nn/core/Embedding.java | 17 +-- api/src/main/java/ai/djl/nn/core/Linear.java | 25 ++-- api/src/main/java/ai/djl/nn/core/Prelu.java | 5 +- .../main/java/ai/djl/nn/norm/BatchNorm.java | 30 ++--- .../ai/djl/nn/recurrent/RecurrentBlock.java | 51 ++++---- .../java/ai/djl/nn/transformer/BertBlock.java | 11 +- .../BertMaskedLanguageModelBlock.java | 5 +- .../nn/transformer/BertPretrainingBlock.java | 3 +- .../ai/djl/nn/transformer/IdEmbedding.java | 5 +- .../PointwiseFeedForwardBlock.java | 3 +- .../ssd/SingleShotDetection.java | 21 ++-- .../ai/djl/mxnet/engine/MxSymbolBlock.java | 5 +- .../djl/tensorflow/engine/TfSymbolBlock.java | 4 +- 20 files changed, 198 insertions(+), 223 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index afc0958ace9d..0b1f0d868ba9 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -97,13 +97,12 @@ public NDList forward( * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); encoder.initialize(manager, dataType, inputShapes[0]); - return decoder.initialize(manager, dataType, inputShapes[1]); + decoder.initialize(manager, dataType, inputShapes[1]); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 96dc366901e4..83a063e45f18 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -30,7 +30,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; -import java.util.function.Function; import java.util.function.Predicate; /** @@ -44,8 +43,7 @@ * * *

If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take - * care of parameter initialization yourself. In this case, you need to override {@link - * AbstractBlock#getParameterShape(String, Shape[])} to determine the shape of your parameters. If - * you use the other variants of {@code addParameter} this is done for you. + * care of parameter initialization yourself. In this case, you need to setShape to your parameters + * if you know the shape of Parameter or you can implement prepare to setShape when you see the + * input shape. */ // Using LinkedHashMap instead of Map is intentional: we want to make sure that consumers // of this API know the children and parameters are always iterated over in insertion order. @@ -100,14 +98,6 @@ public abstract class AbstractBlock implements Block { */ protected LinkedHashMap parameters = new LinkedHashMap<>(); - /** - * Callbacks to determine the shape of a parameter. Values may be null in which case extending - * classes need to override {@link Block#getParameterShape(String, Shape[])} and implement - * parameter shape resolution manually. - */ - protected LinkedHashMap> parameterShapeCallbacks = - new LinkedHashMap<>(); - /** * Builds an empty block with the given version for parameter serialization. * @@ -153,73 +143,20 @@ protected final B addChildBlock(String name, B block) { return block; } - /** - * Adds a parameter to this block. If parameters are added with this method, subclasses need to - * override {@link Block#getParameterShape(String, Shape[])} and return the shapes of parameters - * themselves. - * - * @param parameter the parameter to add, not null - * @param

the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final

P addParameter(P parameter) { - return addParameter(parameter, (Function) null); - } - /** * Adds a parameter to this block. If parameters are added with this method, intialization of * the parameter works out of the box * - * @param parameter the parameter to add, not null - * @param shape the shape of the parameter * @param

the specific parameter subclass - * @return the parameter passed as arguments to make it easier to create and assign paramters in - * one line - */ - protected final

P addParameter(P parameter, Shape shape) { - return addParameter(parameter, (inputShapes) -> shape); - } - - /** - * Adds a parameter to this block. If parameters are added with this method, intialization of - * the parameter works out of the box - * * @param parameter the parameter to add, not null - * @param shapeCallback the method to call once the input shape of this block is known to - * determine the shape of the given parameter - * @param

the specific parameter subclass * @return the parameter passed as arguments to make it easier to create and assign parameters * in one line */ - protected final

P addParameter( - P parameter, Function shapeCallback) { + protected final

P addParameter(P parameter) { parameters.put(parameter.getName(), parameter); - parameterShapeCallbacks.put(parameter.getName(), shapeCallback); return parameter; } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - Function callback = parameterShapeCallbacks.get(name); - if (callback == null) { - Parameter parameter = parameters.get(name); - if (parameter == null) { - throw new IllegalArgumentException( - "No parameter named " + name + " found in this block."); - } else { - throw new IllegalStateException( - "No shape initializer for parameter " - + name - + "found. " - + "Either pass an initializer for the shape when adding the " - + "parameter or override getParameterShape in the subclass."); - } - } - return callback.apply(inputShapes); - } - /** {@inheritDoc} */ @Override public BlockList getChildren() { @@ -271,13 +208,34 @@ public void setInitializer(Initializer initializer, Predicate predica /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); + // if parameters are initialized, skip it + if (!isInitialized()) { + // setShape for all params + prepare(inputShapes); + } for (Parameter parameter : parameters.values()) { - parameter.initialize(manager, dataType, inputShapes); + parameter.initialize(manager, dataType); } initializeChildBlocks(manager, dataType, inputShapes); - return getOutputShapes(manager, inputShapes); + } + + /** + * Performs any action necessary before initialization. For example, keep the input information + * or verify the layout. + * + * @param inputShapes the expected shapes of the input + */ + protected void beforeInitialize(Shape... inputShapes) { + if (inputNames.isEmpty()) { + // automatically assign input names + inputNames = new ArrayList<>(); + for (int i = 0; i < inputShapes.length; ++i) { + inputNames.add("data" + i); + } + } + this.inputShapes = inputShapes; } /** @@ -320,20 +278,11 @@ public ParameterList getDirectParameters() { } /** - * Performs any action necessary before initialization. + * Sets the shape of {@link Parameter}s. * - * @param inputShapes the expected shapes of the input + * @param inputShapes the shapes of inputs */ - protected void beforeInitialize(Shape[] inputShapes) { - if (inputNames.isEmpty()) { - // automatically assign input names - inputNames = new ArrayList<>(); - for (int i = 0; i < inputShapes.length; ++i) { - inputNames.add("data" + i); - } - } - this.inputShapes = inputShapes; - } + protected void prepare(Shape[] inputShapes) {} /** {@inheritDoc} */ @Override diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 2b817ac9a92a..4ae6650c9dec 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -189,9 +189,8 @@ default NDList forward( * @param manager the NDManager to initialize the parameters * @param dataType the datatype of the parameters * @param inputShapes the shapes of the inputs to the block - * @return the shapes of the outputs of the block */ - Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes); + void initialize(NDManager manager, DataType dataType, Shape... inputShapes); /** * Returns a boolean whether the block is initialized. @@ -242,17 +241,6 @@ default NDList forward( */ ParameterList getParameters(); - /** - * Returns the shape of the specified direct parameter of this block given the shapes of the - * input to the block. - * - * @param name the name of the parameter - * @param inputShapes the shapes of the input to the block - * @return the shape of the parameter specified - * @throws IllegalArgumentException if the parameter name specified is invalid - */ - Shape getParameterShape(String name, Shape[] inputShapes); - /** * Returns the expected output shapes of the block for the specified input shapes. * diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index b85136f69b55..369111057c7d 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -23,6 +23,7 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.Objects; import java.util.UUID; /** @@ -42,7 +43,7 @@ public class Parameter implements AutoCloseable { private String id; private String name; - private Block block; + private Shape shape; private Type type; private Initializer initializer; private NDArray array; @@ -52,7 +53,7 @@ public class Parameter implements AutoCloseable { Parameter(Builder builder) { this.id = UUID.randomUUID().toString(); this.name = builder.name; - this.block = builder.block; + this.shape = builder.shape; this.type = builder.type; this.array = builder.array; this.requiresGrad = builder.requiresGrad; @@ -94,10 +95,26 @@ public Type getType() { * @param array the {@link NDArray} that contains values of this {@code Parameter} */ public void setArray(NDArray array) { + if (shape != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } this.array = array; + shape = array.getShape(); array.setName(name); } + /** + * Sets the shape of this {@code Parameter}. + * + * @param shape the shape of this {@code Parameter} + */ + public void setShape(Shape shape) { + if (array != null) { + throw new IllegalStateException("array has been set! Use either setArray or setShape"); + } + this.shape = shape; + } + /** * Gets the values of this {@code Parameter} as an {@link NDArray}. * @@ -144,11 +161,11 @@ public void setInitializer(Initializer initializer) { * * @param manager an NDManager to create the arrays * @param dataType the datatype of the {@code Parameter} - * @param inputShapes the expected input shapes */ - public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) { + public void initialize(NDManager manager, DataType dataType) { + Objects.requireNonNull(initializer, "No initializer has been set"); + Objects.requireNonNull(shape, "No parameter shape has been set"); if (!isInitialized()) { - Shape shape = block.getParameterShape(name, inputShapes); array = initializer.initialize(manager, shape, dataType); array.setName(name); } @@ -210,6 +227,8 @@ public void load(NDManager manager, DataInputStream dis) } array = manager.decode(dis); + // set the shape of the parameter and prepare() can be skipped + shape = array.getShape(); } /** {@inheritDoc} */ @@ -264,7 +283,7 @@ public Initializer getInitializer() { /** A Builder to construct a {@code Parameter}. */ public static final class Builder { String name; - Block block; + Shape shape; Type type; Initializer initializer; NDArray array; @@ -283,24 +302,24 @@ public Builder setName(String name) { } /** - * Sets the block of the {@code Parameter}. + * Sets the {@code Type} of the {@code Parameter}. * - * @param block the block of the {@code Parameter} + * @param type the {@code Type} of the {@code Parameter} * @return this {@code Parameter} */ - public Builder setBlock(Block block) { - this.block = block; + public Builder setType(Type type) { + this.type = type; return this; } /** - * Sets the {@code Type} of the {@code Parameter}. + * Sets the shape of the {@code Parameter}. * - * @param type the {@code Type} of the {@code Parameter} + * @param shape the shape of the {@code Parameter} * @return this {@code Parameter} */ - public Builder setType(Type type) { - this.type = type; + public Builder optShape(Shape shape) { + this.shape = shape; return this; } diff --git a/api/src/main/java/ai/djl/nn/SequentialBlock.java b/api/src/main/java/ai/djl/nn/SequentialBlock.java index f9578130d32e..19b7486e978b 100644 --- a/api/src/main/java/ai/djl/nn/SequentialBlock.java +++ b/api/src/main/java/ai/djl/nn/SequentialBlock.java @@ -140,7 +140,8 @@ protected NDList forwardInternal( public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { Shape[] shapes = inputShapes; for (Block child : getChildren().values()) { - shapes = child.initialize(manager, dataType, shapes); + child.initialize(manager, dataType, shapes); + shapes = child.getOutputShapes(manager, shapes); } } diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index c2fa87f0a41f..347923110f12 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -103,20 +103,15 @@ public Convolution(ConvolutionBuilder builder) { addParameter( Parameter.builder() .setName("weight") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + .build()); if (includeBias) { bias = addParameter( Parameter.builder() .setName("bias") - .setBlock(this) .setType(Parameter.Type.BIAS) - .build(), - new Shape(filters)); + .build()); } } @@ -157,10 +152,19 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); + } + + /** {@inheritDoc} */ + @Override + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 30d348cebc9e..c8ad86454e19 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -79,20 +79,15 @@ public Deconvolution(DeconvolutionBuilder builder) { addParameter( Parameter.builder() .setName("weight") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - (inputShapes) -> - new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape)); + .build()); if (includeBias) { bias = addParameter( Parameter.builder() .setName("bias") - .setBlock(this) .setType(Parameter.Type.BIAS) - .build(), - new Shape(filters)); + .build()); } } @@ -134,10 +129,19 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - protected void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(getExpectedLayout(), inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(getExpectedLayout(), inputShapes[0].getLayout()); + } + + /** {@inheritDoc} */ + @Override + protected void prepare(Shape[] inputs) { + long inputChannel = inputs[0].get(1); + weight.setShape(new Shape(filters, inputChannel / groups).addAll(kernelShape)); + if (bias != null) { + bias.setShape(new Shape(filters)); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/core/Embedding.java b/api/src/main/java/ai/djl/nn/core/Embedding.java index 52ec3e2f306a..39a4bbf52929 100644 --- a/api/src/main/java/ai/djl/nn/core/Embedding.java +++ b/api/src/main/java/ai/djl/nn/core/Embedding.java @@ -57,12 +57,9 @@ protected Embedding(BaseBuilder baseBuilder) { addParameter( Parameter.builder() .setName("embedding") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .optRequiresGrad(true) .optGradientFormat(sparseFormat) - .build(), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + .build()); if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) { throw new IllegalArgumentException( "You can not specify both a fallthrough and a defaultItem"); @@ -100,16 +97,20 @@ public Embedding(NDArray embedding, SparseFormat format) { addParameter( Parameter.builder() .setName("embedding") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .optRequiresGrad(true) .optGradientFormat(sparseFormat) - .build(), - (inputShapes) -> new Shape(numEmbeddings, embeddingSize)); + .build()); this.embedding.setArray(embedding); inputShapes = new Shape[] {new Shape(-1)}; } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + // numItems will be adjusted by embedding array or fallthroughEmbedding + embedding.setShape(new Shape(numEmbeddings, embeddingSize)); + } + /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 9041b6249847..7054e64aadef 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -58,25 +58,19 @@ public class Linear extends AbstractBlock { Linear(Builder builder) { super(VERSION); units = builder.units; - // "inputFeatures" is only known after "beforeInitialize" is called, hence we need - // a callback, even if we do not used the callback parameter weight = addParameter( Parameter.builder() .setName("weight") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - inputShapes -> new Shape(units, inputFeatures)); + .build()); if (builder.bias) { bias = addParameter( Parameter.builder() .setName("bias") - .setBlock(this) .setType(Parameter.Type.BIAS) - .build(), - new Shape(units)); + .build()); } } @@ -97,7 +91,7 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { - return new Shape[] {inputShape.addAll(new Shape(units))}; + return new Shape[] {inputs[0].slice(0, inputs[0].dimension() - 1).add(units)}; } /** {@inheritDoc} */ @@ -109,13 +103,24 @@ public PairList describeInput() { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... inputShapes) { super.beforeInitialize(inputShapes); + Preconditions.checkArgument(inputShapes.length == 1, "Linear block only support 1 input"); Shape input = inputShapes[0]; inputFeatures = input.get(input.dimension() - 1); inputShape = input.slice(0, input.dimension() - 1); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + Shape input = inputShapes[0]; + weight.setShape(new Shape(units, input.get(input.dimension() - 1))); + if (bias != null) { + bias.setShape(new Shape(units)); + } + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index 01437bc470f3..77f380d328e9 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -47,10 +47,9 @@ public Prelu() { addParameter( Parameter.builder() .setName("alpha") - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - new Shape()); + .optShape(new Shape()) + .build()); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java index 54d0a6da13b5..a83ef09a2e3f 100644 --- a/api/src/main/java/ai/djl/nn/norm/BatchNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/BatchNorm.java @@ -85,46 +85,37 @@ public class BatchNorm extends AbstractBlock { momentum = builder.momentum; center = builder.center; scale = builder.scale; - // When creating parameters we use a callback as "inChannels" is set before initialization, - // it is not known yet. + // make gamma trainable if scale gamma = addParameter( Parameter.builder() .setName("gamma") - .setBlock(this) .setType(Parameter.Type.GAMMA) .optRequiresGrad(scale) - .build(), - (inputShapes) -> new Shape(inChannels)); + .build()); // make beta trainable if center beta = addParameter( Parameter.builder() .setName("beta") - .setBlock(this) .setType(Parameter.Type.BETA) .optRequiresGrad(center) - .build(), - (inputShapes) -> new Shape(inChannels)); + .build()); runningMean = addParameter( Parameter.builder() .setName("runningMean") - .setBlock(this) .setType(Parameter.Type.RUNNING_MEAN) .optRequiresGrad(false) - .build(), - (inputShapes) -> new Shape(inChannels)); + .build()); runningVar = addParameter( Parameter.builder() .setName("runningVar") - .setBlock(this) .setType(Parameter.Type.RUNNING_VAR) .optRequiresGrad(false) - .build(), - (inputShapes) -> new Shape(inChannels)); + .build()); } /** {@inheritDoc} */ @@ -160,11 +151,20 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputShapes) { + protected void beforeInitialize(Shape... inputShapes) { super.beforeInitialize(inputShapes); inChannels = inputShapes[0].size(axis); } + /** {@inheritDoc} */ + @Override + public void prepare(Shape[] inputShapes) { + gamma.setShape(new Shape(inChannels)); + beta.setShape(new Shape(inChannels)); + runningMean.setShape(new Shape(inChannels)); + runningVar.setShape(new Shape(inChannels)); + } + /** {@inheritDoc} */ @Override protected void saveMetadata(DataOutputStream os) throws IOException { diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index 620e40831aeb..1ba9664822de 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -19,6 +19,8 @@ import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; +import ai.djl.util.Pair; import java.io.DataInputStream; import java.io.IOException; @@ -80,11 +82,7 @@ public RecurrentBlock(BaseBuilder builder) { String name = direction + '_' + i + '_' + gateString + '_' + parameterType.name(); addParameter( - Parameter.builder() - .setName(name) - .setBlock(this) - .setType(parameterType) - .build()); + Parameter.builder().setName(name).setType(parameterType).build()); } } } @@ -113,31 +111,34 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { /** {@inheritDoc} */ @Override - public void beforeInitialize(Shape[] inputs) { - super.beforeInitialize(inputs); - Shape inputShape = inputs[0]; - Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout()); + protected void beforeInitialize(Shape... inputShapes) { + super.beforeInitialize(inputShapes); + Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout()); } /** {@inheritDoc} */ @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { - int layer = Integer.parseInt(name.split("_")[1]); - Shape shape = inputShapes[0]; - long inputs = shape.get(2); - if (layer > 0) { - inputs = stateSize * getNumDirections(); - } - if (name.contains("BIAS")) { - return new Shape(gates * stateSize); - } - if (name.contains("i2h")) { - return new Shape(gates * stateSize, inputs); - } - if (name.contains("h2h")) { - return new Shape(gates * stateSize, stateSize); + public void prepare(Shape[] inputs) { + Shape inputShape = inputs[0]; + ParameterList parameters = getDirectParameters(); + for (Pair pair : parameters) { + String name = pair.getKey(); + Parameter parameter = pair.getValue(); + int layer = Integer.parseInt(name.split("_")[1]); + long inputSize = inputShape.get(2); + if (layer > 0) { + inputSize = stateSize * getNumDirections(); + } + if (name.contains("BIAS")) { + parameter.setShape(new Shape(gates * stateSize)); + } else if (name.contains("i2h")) { + parameter.setShape(new Shape(gates * stateSize, inputSize)); + } else if (name.contains("h2h")) { + parameter.setShape(new Shape(gates * stateSize, stateSize)); + } else { + throw new IllegalArgumentException("Invalid parameter name"); + } } - throw new IllegalArgumentException("Invalid parameter name"); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index ce8112539251..a0629e33166b 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -78,10 +78,10 @@ private BertBlock(Builder builder) { addParameter( Parameter.builder() .setName(PARAM_POSITION_EMBEDDING) - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - new Shape(builder.maxSequenceLength, builder.embeddingSize)); + .optShape( + new Shape(builder.maxSequenceLength, builder.embeddingSize)) + .build()); // embedding for the input types this.typeEmbedding = addChildBlock( @@ -167,11 +167,12 @@ public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) { /** {@inheritDoc} */ @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { - beforeInitialize(inputShapes); + super.beforeInitialize(inputShapes); inputNames = Arrays.asList("tokenIds", "typeIds", "masks"); Shape[] tokenShape = {inputShapes[0]}; Shape[] typeShape = {inputShapes[1]}; - Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape); + this.tokenEmbedding.initialize(manager, dataType, tokenShape); + Shape[] embeddingOutput = getOutputShapes(manager, tokenShape); this.typeEmbedding.initialize(manager, dataType, typeShape); this.embeddingNorm.initialize(manager, dataType, embeddingOutput); this.embeddingDropout.initialize(manager, dataType, embeddingOutput); diff --git a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java index 76de486090ee..8142a1ae2d6b 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertMaskedLanguageModelBlock.java @@ -60,10 +60,9 @@ public BertMaskedLanguageModelBlock( addParameter( Parameter.builder() .setName("dictionaryBias") - .setBlock(this) .setType(Parameter.Type.BIAS) - .build(), - new Shape(bertBlock.getTokenDictionarySize())); + .optShape(new Shape(bertBlock.getTokenDictionarySize())) + .build()); this.hiddenActivation = hiddenActivation; } diff --git a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java index 2478811d0375..5372b7fd8ebc 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertPretrainingBlock.java @@ -51,7 +51,8 @@ public BertPretrainingBlock(final BertBlock.Builder builder) { @Override public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices"); - Shape[] bertOutputShapes = bertBlock.initialize(manager, dataType, inputShapes); + bertBlock.initialize(manager, dataType, inputShapes); + Shape[] bertOutputShapes = getOutputShapes(manager, inputShapes); Shape embeddedSequence = bertOutputShapes[0]; Shape pooledOutput = bertOutputShapes[1]; Shape maskedIndices = inputShapes[2]; diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java index e917f3e71635..890a76e008bf 100644 --- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java +++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java @@ -48,10 +48,9 @@ private IdEmbedding(Builder builder) { addParameter( Parameter.builder() .setName(EMBEDDING_PARAM_NAME) - .setBlock(this) .setType(Parameter.Type.WEIGHT) - .build(), - new Shape(dictionarySize, embeddingSize)); + .optShape(new Shape(dictionarySize, embeddingSize)) + .build()); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java index 10d4aa20b401..f4588e3ccaa0 100644 --- a/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/PointwiseFeedForwardBlock.java @@ -82,7 +82,8 @@ public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... } Shape lastShape = inputShape; for (Block child : children.values()) { - lastShape = child.initialize(manager, dataType, lastShape)[0]; + child.initialize(manager, dataType, lastShape); + lastShape = getOutputShapes(manager, new Shape[] {lastShape})[0]; } outputShape = lastShape; } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java index 4d9486c70cdb..966b96a390cf 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/ssd/SingleShotDetection.java @@ -177,15 +177,22 @@ private Shape concatShape(Shape shape, Shape concat, int axis) { /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { beforeInitialize(inputShapes); - Shape[] shapes = inputShapes; - for (int i = 0; i < features.size(); i++) { - shapes = features.get(i).initialize(manager, dataType, shapes); - classPredictionBlocks.get(i).initialize(manager, dataType, shapes); - anchorPredictionBlocks.get(i).initialize(manager, dataType, shapes); + Shape outputShape; + try (NDManager tempManager = manager.newSubManager()) { + ParameterStore store = new ParameterStore(tempManager, false); + NDArray array = tempManager.create(inputShapes[0]); + for (int i = 0; i < features.size(); i++) { + features.get(i).initialize(manager, dataType, array.getShape()); + // the getOutputShapes is wrong + // so manually call forward here + array = features.get(i).forward(store, new NDList(array), false).singletonOrThrow(); + outputShape = array.getShape(); + classPredictionBlocks.get(i).initialize(manager, dataType, outputShape); + anchorPredictionBlocks.get(i).initialize(manager, dataType, outputShape); + } } - return getOutputShapes(manager, inputShapes); } /** {@inheritDoc} */ diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 04cf8373c7bb..33689024c20f 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -231,9 +231,7 @@ public void removeLastBlock() { } } - /** {@inheritDoc} */ - @Override - public Shape getParameterShape(String name, Shape[] inputShapes) { + private Shape getParameterShape(String name, Shape[] inputShapes) { if (paramShapes == null) { PairList pairs = new PairList<>(); for (int i = 0; i < inputNames.size(); i++) { @@ -318,7 +316,6 @@ private void initBlock() { mxNetParams.add( Parameter.builder() .setName(name) - .setBlock(this) .setType(type) .optRequiresGrad(requireGrad) .build()); diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index 43f667d2dc2a..5e4b8e0fdc08 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -131,8 +131,8 @@ protected NDList forwardInternal( /** {@inheritDoc} */ @Override - public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) { - return new Shape[0]; + public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) { + throw new IllegalStateException("TfSymbolBlock can't be initialized"); } /** {@inheritDoc} */