Skip to content

Commit

Permalink
Use builder pattern for Parameter (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Feb 25, 2021
1 parent b4a93e4 commit 8e4211e
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 154 deletions.
211 changes: 147 additions & 64 deletions api/src/main/java/ai/djl/nn/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,61 +43,22 @@ public class Parameter implements AutoCloseable {
private String id;
private String name;
private Block block;
private ParameterType type;
private DataType mandatoryDataType;
private Type type;
private Initializer initializer;
private NDArray array;
private boolean requiresGrad;
private SparseFormat gradientFormat;

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
*/
public Parameter(String name, Block block, ParameterType type) {
this(name, block, type, true, SparseFormat.DENSE);
}

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
* @param requiresGrad whether this {@code Parameter} needs to compute gradients
*/
public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) {
this(name, block, type, requiresGrad, SparseFormat.DENSE);
}

/**
* Creates a {@code Parameter} with the given name, and parameter type, and associated with the
* given {@link Block}.
*
* @param name the name of the {@code Parameter}
* @param block the block with which this {@code Parameter} is associated
* @param type the type of this {@code Parameter}
* @param requireGrad whether this {@code Parameter} needs to compute gradients
* @param gradientFormat the {@link SparseFormat} of the gradient array
*/
public Parameter(
String name,
Block block,
ParameterType type,
boolean requireGrad,
SparseFormat gradientFormat) {
Parameter(Builder builder) {
this.id = UUID.randomUUID().toString();
this.name = name;
this.block = block;
this.type = type;
this.requiresGrad = requireGrad;
this.initializer = type.getInitializer();
this.gradientFormat = gradientFormat;
this.name = builder.name;
this.block = builder.block;
this.type = builder.type;
this.array = builder.array;
this.requiresGrad = builder.requiresGrad;
this.initializer =
(builder.initializer != null) ? builder.initializer : type.getInitializer();
this.gradientFormat = builder.gradientFormat;
}

/**
Expand All @@ -123,7 +84,7 @@ public String getName() {
*
* @return the type of this {@code Parameter}
*/
public ParameterType getType() {
public Type getType() {
return type;
}

Expand Down Expand Up @@ -158,15 +119,6 @@ public boolean requireGradient() {
return requiresGrad;
}

/**
* Sets the mandatory data type for this {@code Parameter}.
*
* @param mandatoryDataType the mandatory data type for this {@code Parameter}
*/
public void setMandatoryDataType(DataType mandatoryDataType) {
this.mandatoryDataType = mandatoryDataType;
}

/**
* Checks if this {@code Parameter} is initialized.
*
Expand Down Expand Up @@ -201,11 +153,7 @@ public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes
Objects.requireNonNull(initializer, "No initializer has been set");
if (!isInitialized()) {
Shape shape = block.getParameterShape(name, inputShapes);
array =
initializer.initialize(
manager,
shape,
mandatoryDataType == null ? dataType : mandatoryDataType);
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}

Expand Down Expand Up @@ -276,4 +224,139 @@ public void close() {
array = null;
}
}

/**
 * Creates a new {@link Builder} for assembling a {@code Parameter}.
 *
 * <p>Builder methods prefixed with {@code set} configure required fields; methods prefixed with
 * {@code opt} configure optional fields.
 *
 * @return a new, empty builder
 */
public static Parameter.Builder builder() {
    // Builder is a nested class of Parameter, so the simple name resolves to Parameter.Builder.
    return new Builder();
}

/**
 * Enumerates the types of {@link Parameter}.
 *
 * <p>Each type carries a default {@link Initializer}, which may be {@code null} when the type has
 * no built-in default.
 */
public enum Type {
    // Constant order is preserved; no default initializer is attached to WEIGHT or OTHER.
    WEIGHT(null),
    BIAS(Initializer.ZEROS),
    GAMMA(Initializer.ONES),
    BETA(Initializer.ZEROS),
    RUNNING_MEAN(Initializer.ZEROS),
    RUNNING_VAR(Initializer.ONES),
    OTHER(null);

    /** Default initializer associated with this type; {@code null} if there is none. */
    private final transient Initializer initializer;

    /**
     * Associates a default initializer with the enum constant.
     *
     * @param defaultInitializer the default {@link Initializer}, or {@code null} for none
     */
    Type(Initializer defaultInitializer) {
        initializer = defaultInitializer;
    }

    /**
     * Gets the default {@link Initializer} of this {@code Type}.
     *
     * @return the default {@link Initializer} of this {@code Type}, or {@code null} if this type
     *     does not define one
     */
    public Initializer getInitializer() {
        return this.initializer;
    }
}

/**
 * A Builder to construct a {@code Parameter}.
 *
 * <p>Methods prefixed with {@code set} configure required fields; methods prefixed with {@code
 * opt} configure optional fields. Unless overridden, the built {@code Parameter} requires
 * gradients and uses a {@link SparseFormat#DENSE} gradient format.
 */
public static final class Builder {
    String name;
    Block block;
    Type type;
    Initializer initializer;
    NDArray array;
    boolean requiresGrad = true;
    // Defaults to DENSE so that callers which never call optGradientFormat get the same
    // gradient format the previous Parameter constructors applied by default.
    SparseFormat gradientFormat = SparseFormat.DENSE;

    /**
     * Sets the name of the {@code Parameter}.
     *
     * @param name the name of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setName(String name) {
        this.name = name;
        return this;
    }

    /**
     * Sets the block of the {@code Parameter}.
     *
     * @param block the block of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setBlock(Block block) {
        this.block = block;
        return this;
    }

    /**
     * Sets the {@code Type} of the {@code Parameter}.
     *
     * @param type the {@code Type} of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder setType(Type type) {
        this.type = type;
        return this;
    }

    /**
     * Sets the {@link Initializer} of the {@code Parameter}.
     *
     * <p>When no initializer is supplied, the {@code Parameter} falls back to the default
     * initializer of its {@link Type}.
     *
     * @param initializer the {@link Initializer} of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the array of the {@code Parameter}.
     *
     * @param array the array of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder optArray(NDArray array) {
        this.array = array;
        return this;
    }

    /**
     * Sets if the {@code Parameter} requires gradient. Defaults to {@code true}.
     *
     * @param requiresGrad if the {@code Parameter} requires gradient
     * @return this {@code Builder}
     */
    public Builder optRequiresGrad(boolean requiresGrad) {
        this.requiresGrad = requiresGrad;
        return this;
    }

    /**
     * Sets the {@link SparseFormat} of the gradient of the {@code Parameter}. Defaults to
     * {@link SparseFormat#DENSE}.
     *
     * @param gradientFormat the {@link SparseFormat} of the gradient of the {@code Parameter}
     * @return this {@code Builder}
     */
    public Builder optGradientFormat(SparseFormat gradientFormat) {
        this.gradientFormat = gradientFormat;
        return this;
    }

    /**
     * Builds a {@code Parameter} instance from the values collected by this builder.
     *
     * @return the {@code Parameter} instance
     */
    public Parameter build() {
        return new Parameter(this);
    }
}
}
41 changes: 0 additions & 41 deletions api/src/main/java/ai/djl/nn/ParameterType.java

This file was deleted.

14 changes: 11 additions & 3 deletions api/src/main/java/ai/djl/nn/convolutional/Convolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
Expand Down Expand Up @@ -102,13 +101,22 @@ public Convolution(ConvolutionBuilder<?> builder) {

weight =
addParameter(
new Parameter("weight", this, ParameterType.WEIGHT),
Parameter.builder()
.setName("weight")
.setBlock(this)
.setType(Parameter.Type.WEIGHT)
.build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
Parameter.builder()
.setName("bias")
.setBlock(this)
.setType(Parameter.Type.BIAS)
.build(),
new Shape(filters));
}
}

Expand Down
14 changes: 11 additions & 3 deletions api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
Expand Down Expand Up @@ -78,13 +77,22 @@ public Deconvolution(DeconvolutionBuilder<?> builder) {

weight =
addParameter(
new Parameter("weight", this, ParameterType.WEIGHT),
Parameter.builder()
.setName("weight")
.setBlock(this)
.setType(Parameter.Type.WEIGHT)
.build(),
(inputShapes) ->
new Shape(filters, inputShapes[0].get(1)).addAll(kernelShape));
if (includeBias) {
bias =
addParameter(
new Parameter("bias", this, ParameterType.BIAS), new Shape(filters));
Parameter.builder()
.setName("bias")
.setBlock(this)
.setType(Parameter.Type.BIAS)
.build(),
new Shape(filters));
}
}

Expand Down
Loading

0 comments on commit 8e4211e

Please sign in to comment.