From caa576242d3fc3c0321f8c406fd5cfe511ff302d Mon Sep 17 00:00:00 2001 From: Jake Lee Date: Fri, 19 Feb 2021 10:22:20 -0800 Subject: [PATCH] Make XavierInitializer default value & Improve setInitializer (#664) --- .../main/java/ai/djl/nn/AbstractBlock.java | 24 ++++++--- api/src/main/java/ai/djl/nn/Block.java | 14 ++++- api/src/main/java/ai/djl/nn/Parameter.java | 14 +++-- api/src/main/java/ai/djl/nn/core/Prelu.java | 2 +- .../djl/training/DefaultTrainingConfig.java | 51 ++++++++++++++----- .../java/ai/djl/training/TrainingConfig.java | 7 ++- .../initializer/XavierInitializer.java | 2 +- .../basicdataset/AirfoilRandomAccessTest.java | 3 +- .../basicdataset/AmesRandomAccessTest.java | 3 +- .../java/ai/djl/basicdataset/PikachuTest.java | 3 +- .../examples/training/TrainBertOnCode.java | 4 +- .../examples/training/TrainMnistWithLSTM.java | 2 - .../ai/djl/fasttext/FtTrainingConfig.java | 5 +- .../modality/cv/SingleShotDetectionTest.java | 2 - .../modality/nlp/SimpleTextEncoderTest.java | 2 - .../model_zoo/classification/AlexNetTest.java | 4 +- .../classification/GoogLeNetTest.java | 4 +- .../model_zoo/classification/LeNetTest.java | 4 +- .../model_zoo/classification/NiNTest.java | 4 +- .../model_zoo/classification/ResnetTest.java | 4 +- .../classification/SqueezenetTest.java | 2 +- .../model_zoo/classification/VGGTest.java | 5 +- .../NDArrayElementArithmeticOpTest.java | 5 +- .../integration/tests/nn/BlockCoreTest.java | 44 ++++++++++------ .../tests/nn/PoolingOperationsTest.java | 4 +- .../ScaledDotProductAttentionBlockTest.java | 3 +- .../tests/training/ActivationTest.java | 4 +- .../tests/training/BlocksTest.java | 4 +- .../tests/training/DatasetTest.java | 4 +- .../GradientCollectorIntegrationTest.java | 5 +- .../integration/tests/training/ModelTest.java | 2 - .../tests/training/OptimizerTest.java | 17 ++++--- .../java/ai/djl/mxnet/engine/MxModel.java | 11 +++- .../ai/djl/mxnet/engine/MxSymbolBlock.java | 2 + .../MxGradientCollectorIntegrationTest.java | 3 +- 
.../mxnet/integration/MxSymbolBlockTest.java | 8 +-- .../java/ai/djl/pytorch/engine/PtModel.java | 12 ++++- 37 files changed, 190 insertions(+), 103 deletions(-) diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 902f1ef7870..96dc366901e 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Locale; import java.util.function.Function; +import java.util.function.Predicate; /** * {@code AbstractBlock} is an abstract implementation of {@link Block}. @@ -242,13 +243,9 @@ public PairList describeInput() { /** {@inheritDoc} */ @Override - public void setInitializer(Initializer initializer) { - for (Parameter parameter : parameters.values()) { - parameter.setInitializer(initializer, false); - } - for (Block child : children.values()) { - child.setInitializer(initializer); - } + public void setInitializer(Initializer initializer, Parameter.Type params) { + Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params); + setInitializer(initializer, predicate); } /** {@inheritDoc} */ @@ -258,7 +255,18 @@ public void setInitializer(Initializer initializer, String paramName) { if (parameter == null) { throw new IllegalArgumentException("Could not find parameter " + paramName); } - parameter.setInitializer(initializer, true); + parameter.setInitializer(initializer); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) { + List<Parameter> params = getParameters().values(); + for (Parameter param : params) { + if (predicate.test(param)) { + param.setInitializer(initializer); + } + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/nn/Block.java b/api/src/main/java/ai/djl/nn/Block.java index 6febeff5f6f..2b817ac9a92 100644 --- a/api/src/main/java/ai/djl/nn/Block.java +++ b/api/src/main/java/ai/djl/nn/Block.java @@ -25,6 +25,7 @@
import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.function.Predicate; /** * A {@code Block} is a composable function that forms a neural network. @@ -158,11 +159,12 @@ default NDList forward( } /** - * Sets an {@link Initializer} to the block. + * Sets an {@link Initializer} to all the parameters that match parameter type in the block. * * @param initializer the initializer to set + * @param type the Parameter Type we want to setInitializer */ - void setInitializer(Initializer initializer); + void setInitializer(Initializer initializer, Parameter.Type type); /** * Sets an {@link Initializer} to the specified direct parameter of the block, overriding the @@ -173,6 +175,14 @@ default NDList forward( */ void setInitializer(Initializer initializer, String paramName); + /** + * Sets an {@link Initializer} to all the parameters that match Predicate in the block. + * + * @param initializer the initializer to be set + * @param predicate predicate function to indicate parameters you want to set + */ + void setInitializer(Initializer initializer, Predicate<Parameter> predicate); + /** * Initializes the parameters of the block. This method must be called before calling `forward`. * diff --git a/api/src/main/java/ai/djl/nn/Parameter.java b/api/src/main/java/ai/djl/nn/Parameter.java index 5b6a13ecb01..b85136f69b5 100644 --- a/api/src/main/java/ai/djl/nn/Parameter.java +++ b/api/src/main/java/ai/djl/nn/Parameter.java @@ -19,10 +19,10 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.training.initializer.Initializer; +import ai.djl.training.initializer.XavierInitializer; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; -import java.util.Objects; import java.util.UUID; /** @@ -133,12 +133,9 @@ public boolean isInitialized() { * flag is true, sets the initializer regardless.
* * @param initializer the initializer to be set - * @param overwrite if true, set the initializer regardless of whether its already set or not */ - public void setInitializer(Initializer initializer, boolean overwrite) { - if (overwrite || this.initializer == null) { - this.initializer = initializer; - } + public void setInitializer(Initializer initializer) { + this.initializer = initializer; } /** @@ -150,7 +147,6 @@ public void setInitializer(Initializer initializer, boolean overwrite) { * @param inputShapes the expected input shapes */ public void initialize(NDManager manager, DataType dataType, Shape[] inputShapes) { - Objects.requireNonNull(initializer, "No initializer has been set"); if (!isInitialized()) { Shape shape = block.getParameterShape(name, inputShapes); array = initializer.initialize(manager, shape, dataType); @@ -239,7 +235,9 @@ public static Parameter.Builder builder() { /** Enumerates the types of {@link Parameter}. */ public enum Type { - WEIGHT(null), + WEIGHT( + new XavierInitializer( + XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2)), BIAS(Initializer.ZEROS), GAMMA(Initializer.ONES), BETA(Initializer.ZEROS), diff --git a/api/src/main/java/ai/djl/nn/core/Prelu.java b/api/src/main/java/ai/djl/nn/core/Prelu.java index de125d6d648..01437bc470f 100644 --- a/api/src/main/java/ai/djl/nn/core/Prelu.java +++ b/api/src/main/java/ai/djl/nn/core/Prelu.java @@ -48,7 +48,7 @@ public Prelu() { Parameter.builder() .setName("alpha") .setBlock(this) - .setType(Parameter.Type.OTHER) + .setType(Parameter.Type.WEIGHT) .build(), new Shape()); } diff --git a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java index 7905d0c0363..afc42564907 100644 --- a/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java +++ b/api/src/main/java/ai/djl/training/DefaultTrainingConfig.java @@ -13,23 +13,23 @@ package ai.djl.training; import ai.djl.Device; +import 
ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; -import ai.djl.training.initializer.XavierInitializer; -import ai.djl.training.initializer.XavierInitializer.FactorType; -import ai.djl.training.initializer.XavierInitializer.RandomType; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Adam; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Predicate; /** {@code DefaultTrainingConfig} is an implementation of the {@link TrainingConfig} interface. */ public class DefaultTrainingConfig implements TrainingConfig { - private Initializer initializer; + private PairList<Initializer, Predicate<Parameter>> initializers = new PairList<>(); private Optimizer optimizer; private Device[] devices; private Loss loss; @@ -39,15 +39,12 @@ public class DefaultTrainingConfig implements TrainingConfig { /** * Creates an instance of {@code DefaultTrainingConfig} with the given {@link Loss}. {@code - * DefaultTrainingConfig} creates a default {@link TrainingConfig} with the {@link - * XavierInitializer} as initialiser, {@link Adam} as optimiser, and the given {@link Loss}. The - * evaluators and listeners are left to the user's discretion. + * DefaultTrainingConfig} creates a default {@link TrainingConfig}, {@link Adam} as optimiser, + * and the given {@link Loss}. The evaluators and listeners are left to the user's discretion.
* * @param loss the loss to use for training */ public DefaultTrainingConfig(Loss loss) { - // Defaults to initializer defined in https://arxiv.org/abs/1502.01852 - this.initializer = new XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2); optimizer = Adam.builder().build(); this.loss = loss; dataManager = DataManager.DEFAULT_DATA_MANAGER; @@ -60,10 +57,38 @@ public DefaultTrainingConfig(Loss loss) { * href="https://arxiv.org/abs/1502.01852">paper). * * @param initializer the initialer to use for the parameters + * @param type the {@link Parameter.Type} of the parameters * @return this {@code DefaultTrainingConfig} */ - public DefaultTrainingConfig optInitializer(Initializer initializer) { - this.initializer = initializer; + public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type) { + initializers.add(initializer, parameter -> parameter.getType().equals(type)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameters (default from paper). + * + * @param initializer the initialer to use for the parameters + * @param name the name of the parameter + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer(Initializer initializer, String name) { + initializers.add(initializer, parameter -> parameter.getName().equals(name)); + return this; + } + + /** + * Sets the {@link Initializer} to use for the parameters (default from paper). 
+ * + * @param initializer the initialer to use for the parameters + * @param predicate the predicate to identify parameter + * @return this {@code DefaultTrainingConfig} + */ + public DefaultTrainingConfig optInitializer( + Initializer initializer, Predicate<Parameter> predicate) { + initializers.add(initializer, predicate); return this; } @@ -133,8 +158,8 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { - return initializer; + public PairList<Initializer, Predicate<Parameter>> getInitializers() { + return initializers; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/TrainingConfig.java b/api/src/main/java/ai/djl/training/TrainingConfig.java index 03be4d33a60..583176bd01c 100644 --- a/api/src/main/java/ai/djl/training/TrainingConfig.java +++ b/api/src/main/java/ai/djl/training/TrainingConfig.java @@ -13,12 +13,15 @@ package ai.djl.training; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.evaluator.Evaluator; import ai.djl.training.initializer.Initializer; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.util.List; +import java.util.function.Predicate; /** * An interface that is responsible for holding the configuration required by {@link Trainer}. @@ -64,11 +67,11 @@ public interface TrainingConfig { Device[] getDevices(); /** - * Gets the {@link Initializer} to initialize the parameters of the model. + * Gets a list of {@link Initializer} and Predicate to initialize the parameters of the model. * * @return an {@link Initializer} */ - Initializer getInitializer(); + PairList<Initializer, Predicate<Parameter>> getInitializers(); /** * Gets the {@link Optimizer} to use during training.
diff --git a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java index 64ee6279f71..8cf76c7b8f0 100644 --- a/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java +++ b/api/src/main/java/ai/djl/training/initializer/XavierInitializer.java @@ -80,7 +80,7 @@ public XavierInitializer(RandomType randomType, FactorType factorType, float mag /** Creates a new instance of {@code XavierInitializer}. */ public XavierInitializer() { - this(RandomType.UNIFORM, FactorType.AVG, 3f); + this(RandomType.UNIFORM, FactorType.AVG, 6f); } /** {@inheritDoc} */ diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java index 3efdd00af77..dc8028c5c92 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AirfoilRandomAccessTest { public void testAirfoilRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java index 20bf7f8ac56..716f04c578f 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java +++ 
b/basicdataset/src/test/java/ai/djl/basicdataset/AmesRandomAccessTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -48,7 +49,7 @@ public class AmesRandomAccessTest { public void testAmesRandomAccessRemote() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java index 75bcb3e914a..063ec07fb13 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PikachuTest.java @@ -16,6 +16,7 @@ import ai.djl.basicdataset.cv.PikachuDetection; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -42,7 +43,7 @@ public void testPikachuRemote() throws IOException, TranslateException { .build(); TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(new NormalInitializer(0.01f)); + .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT); try (Model model = Model.newInstance("model")) { model.setBlock(Blocks.identityBlock()); try (Trainer trainer = model.newTrainer(config)) { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index 53eca2e14b4..7d1cda2adc2 100644 --- 
a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.BertBlock; import ai.djl.nn.transformer.BertPretrainingBlock; import ai.djl.nn.transformer.BertPretrainingLoss; @@ -135,7 +136,8 @@ private static Model createBertPretrainingModel(Dictionary dictionary) { model.setBlock( new BertPretrainingBlock( BERT_BUILDER.setTokenDictionarySize(dictionary.tokens.size()))); - model.getBlock().setInitializer(new TruncatedNormalInitializer(0.02f)); + model.getBlock() + .setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT); return model; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index 13ce0f2f742..76078106a22 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -31,7 +31,6 @@ import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.training.evaluator.Accuracy; -import ai.djl.training.initializer.XavierInitializer; import ai.djl.training.listener.SaveModelTrainingListener; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; @@ -119,7 +118,6 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optInitializer(new XavierInitializer()) .optDevices(Device.getDevices(arguments.getMaxGpus())) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git 
a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java index b665106607e..8c8f928e95d 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtTrainingConfig.java @@ -13,6 +13,7 @@ package ai.djl.fasttext; import ai.djl.Device; +import ai.djl.nn.Parameter; import ai.djl.training.DataManager; import ai.djl.training.TrainingConfig; import ai.djl.training.evaluator.Evaluator; @@ -20,9 +21,11 @@ import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.optimizer.Optimizer; +import ai.djl.util.PairList; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; /** An interface that is responsible for holding the configuration required by fastText training. */ public class FtTrainingConfig implements TrainingConfig { @@ -248,7 +251,7 @@ public Device[] getDevices() { /** {@inheritDoc} */ @Override - public Initializer getInitializer() { + public PairList<Initializer, Predicate<Parameter>> getInitializers() { return null; } diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java index f105d46c8cd..2c3ccda8400 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/cv/SingleShotDetectionTest.java @@ -20,7 +20,6 @@ import ai.djl.nn.Block; import ai.djl.nn.SequentialBlock; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -97,7 +96,6 @@ public void testSingleShotDetectionShape() { .setSizes(sizes) .setBaseNetwork(block) .build(); - ssd.setInitializer(new
XavierInitializer()); ssd.initialize(manager, DataType.FLOAT32, new Shape(32, 3, 256, 256)); ParameterStore ps = new ParameterStore(manager, false); NDList output = diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java index 2c64703212a..910804a2e03 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java @@ -23,7 +23,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.recurrent.LSTM; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.XavierInitializer; import java.util.Arrays; import org.testng.Assert; import org.testng.annotations.Test; @@ -50,7 +49,6 @@ public void testEncoder() { .optReturnState(true) .build()); try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) { - encoder.setInitializer(new XavierInitializer()); encoder.initialize(manager, DataType.FLOAT32, new Shape(4, 7)); NDList output = encoder.forward( diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java index f956bfb46bc..64ce263dc2b 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/AlexNetTest.java @@ -159,7 +159,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block alexNet = AlexNet.builder().build(); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -188,7 +188,7 @@ public void 
testForwardMethod() { Block alexNet = AlexNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - alexNet.setInitializer(Initializer.ONES); + alexNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); alexNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = alexNet.forward(new ParameterStore(manager, true), new NDList(x), false) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java index a149b7e40e3..5484fe2ff00 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/GoogLeNetTest.java @@ -100,7 +100,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block googLeNet = GoogLeNet.builder().build(); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -130,7 +130,7 @@ public void testForwardMethod() { Block googLeNet = GoogLeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - googLeNet.setInitializer(Initializer.ONES); + googLeNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); googLeNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = googLeNet diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java index 39744eeefe6..6c93d49e63c 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java +++ 
b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/LeNetTest.java @@ -137,7 +137,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block leNet = LeNet.builder().build(); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -165,7 +165,7 @@ public void testForwardMethod() { Block leNet = LeNet.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 28, 28)); - leNet.setInitializer(Initializer.ONES); + leNet.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); leNet.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = leNet.forward(new ParameterStore(manager, true), new NDList(x), true) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java index cac77221d69..cc2716573d2 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/NiNTest.java @@ -152,7 +152,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block nin = NiN.builder().build(); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -180,7 +180,7 @@ public void testForwardMethod() { Block nin = NiN.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - nin.setInitializer(Initializer.ONES); + nin.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); nin.initialize(manager, DataType.FLOAT32, x.getShape()); NDArray xHat = nin.forward(new ParameterStore(manager, true), new 
NDList(x), false) diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java index fd50e39943f..dfaf752901f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java @@ -57,7 +57,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block resNet50 = ResNetV1.builder() @@ -123,7 +123,7 @@ public void testLoadTrain() TrainingConfig config = new DefaultTrainingConfig(Loss.l1Loss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Trainer trainer = model.newTrainer(config)) { int batchSize = 2; Shape inputShape = new Shape(batchSize, 3, 32, 32); diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java index 267bd41cf80..114907144aa 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/SqueezenetTest.java @@ -40,7 +40,7 @@ public void testTrain() { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .optDevices(Device.getDevices(2)) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block squeezeNet = SqueezeNet.squeezenet(10); try (Model model = Model.newInstance("squeezenet")) { model.setBlock(squeezeNet); diff --git 
a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java index b8520edd8f9..a0b677bc815 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/VGGTest.java @@ -107,7 +107,7 @@ public void testOutputShapes() { Shape currentShape = x.getShape(); Block vgg = VGG.builder().build(); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, currentShape); Map shapeMap = new ConcurrentHashMap<>(); @@ -137,8 +137,9 @@ public void testForwardMethod() { Block vgg = VGG.builder().build(); int batchSize = 1; NDArray x = manager.ones(new Shape(batchSize, 1, 224, 224)); - vgg.setInitializer(Initializer.ONES); + vgg.setInitializer(Initializer.ONES, Parameter.Type.WEIGHT); vgg.initialize(manager, DataType.FLOAT32, x.getShape()); + NDArray xHat = vgg.forward(new ParameterStore(manager, true), new NDList(x), false) .singletonOrThrow(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java index 8e97ed38b87..2dbfac8ef2e 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayElementArithmeticOpTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@ -139,7 +140,7 @@ public void testAddScalar() { try (Trainer trainer = model.newTrainer( new 
DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.add(lhs, 2); @@ -360,7 +361,7 @@ public void testDot() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { lhs.attachGradient(); result = NDArrays.dot(lhs, rhs); diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index 46767e14057..32faa52ffa8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -63,7 +63,8 @@ public class BlockCoreTest { @Test public void testLinear() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block block = Linear.builder().setUnits(outSize).build(); @@ -125,7 +126,8 @@ public void testLinear() throws IOException, MalformedModelException { @Test public void testLinearWithDefinedLayout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); long outSize = 3; Block block = Linear.builder().setUnits(outSize).build(); @@ -177,7 +179,8 @@ public void testLinearWithDefinedLayout() throws IOException, MalformedModelExce @Test public void testBatchNorm() throws IOException, 
MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = BatchNorm.builder().build(); try (Model model = Model.newInstance("model")) { @@ -204,7 +207,8 @@ public void testBatchNorm() throws IOException, MalformedModelException { @Test public void testDropout() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Dropout.builder().optRate(.5f).build(); try (Model model = Model.newInstance("model")) { @@ -230,7 +234,8 @@ public void testDropout() throws IOException, MalformedModelException { @Test public void testEmbedding() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); TrainableWordEmbedding block = TrainableWordEmbedding.builder() @@ -264,7 +269,8 @@ public void testEmbedding() throws IOException, MalformedModelException { @Test public void testConv1d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1d.builder().setKernelShape(new Shape(2)).setFilters(1).optBias(false).build(); @@ -296,7 +302,8 @@ public void testConv1d() throws IOException, MalformedModelException { @Test public void testConv1dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new 
DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv1dTranspose.builder() @@ -330,7 +337,8 @@ public void testConv1dTranspose() throws IOException, MalformedModelException { @Test public void testConv2d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2d.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -361,7 +369,8 @@ public void testConv2d() throws IOException, MalformedModelException { @Test public void testConv2dTranspose() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv2dTranspose.builder().setKernelShape(new Shape(2, 2)).setFilters(1).build(); @@ -398,7 +407,8 @@ public void testConv2dTranspose() throws IOException, MalformedModelException { @Test public void testConv3d() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); Block block = Conv3d.builder().setKernelShape(new Shape(2, 2, 2)).setFilters(1).build(); try (Model model = Model.newInstance("model")) { @@ -439,7 +449,7 @@ public void testRNNTanh() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) 
.optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -486,7 +496,7 @@ public void testRNNRelu() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = RNN.builder() @@ -536,7 +546,7 @@ public void testLstm() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); Block block = LSTM.builder() @@ -587,7 +597,7 @@ public void testGRU() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); TrainingConfig config = new DefaultTrainingConfig(loss) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optDevices(TestUtils.getDevices()); GRU block = GRU.builder() @@ -640,7 +650,8 @@ public void testGRU() throws IOException, MalformedModelException { @Test public void testSequentialBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); SequentialBlock block = new SequentialBlock(); block.addSingleton(x -> x.mul(6.5f)); block.add(Linear.builder().setUnits(10).build()); @@ -680,7 +691,8 @@ public void testSequentialBlock() throws IOException, MalformedModelException { @Test public void testParallelBlock() throws IOException, MalformedModelException { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + 
new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); ParallelBlock block = new ParallelBlock( list -> diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java index 6f5f4e1ebc3..44a4e9816dd 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/PoolingOperationsTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.pooling.Pool; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -29,7 +30,8 @@ public class PoolingOperationsTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testMaxPool1d() { diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java index 9595e88ef15..df8a64f4803 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/ScaledDotProductAttentionBlockTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; import ai.djl.nn.transformer.ScaledDotProductAttentionBlock; import ai.djl.training.GradientCollector; import ai.djl.training.ParameterStore; @@ -752,7 +753,7 @@ public void testMaskedAttention() { .optAttentionProbsDropoutProb(0.0f) .build(); - block.setInitializer(new NormalInitializer()); + 
block.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT); block.getKeyProjection().setInitializer(keyKernelInitializer, "weight"); block.getValueProjection().setInitializer(valueKernelInitializer, "weight"); block.getQueryProjection().setInitializer(queryKernelInitializer, "weight"); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java index 443370b10d8..b3e8502a9ca 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ActivationTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Activation; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; @@ -30,7 +31,8 @@ public class ActivationTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testRelu() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java index a86705162ef..630ad57d3e6 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/BlocksTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.ParameterStore; @@ -30,7 +31,8 @@ public class BlocksTest { TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new 
DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testFlattenBlock() { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java index 4b5dd466c31..8d06c85efd8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/DatasetTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; @@ -49,7 +50,8 @@ public class DatasetTest { private TrainingConfig config = - new DefaultTrainingConfig(Loss.l2Loss()).optInitializer(Initializer.ONES); + new DefaultTrainingConfig(Loss.l2Loss()) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); @Test public void testSequenceSampler() throws IOException, TranslateException { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index 154583e59e3..96680659f78 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -49,7 +50,7 @@ public void testAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + 
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); @@ -87,7 +88,7 @@ public void testTrain() throws IOException, TranslateException { TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) .addTrainingListeners(new EvaluatorTrainingListener()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optimizer); try (Model model = Model.newInstance("linear")) { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java index 2e919ca78d5..6023b1b054f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/ModelTest.java @@ -21,7 +21,6 @@ import ai.djl.nn.convolutional.Conv2d; import ai.djl.nn.norm.BatchNorm; import ai.djl.testing.Assertions; -import ai.djl.training.initializer.XavierInitializer; import java.io.IOException; import java.nio.file.Paths; import org.testng.Assert; @@ -36,7 +35,6 @@ public void testModelSaveAndLoad() throws IOException, MalformedModelException { block.add(BatchNorm.builder().build()); try (Model saveModel = Model.newInstance("saveModel"); Model loadModel = Model.newInstance("loadModel")) { - block.setInitializer(new XavierInitializer()); block.initialize(saveModel.getNDManager(), DataType.FLOAT32, new Shape(1, 3, 32, 32)); ParameterList savedParameters = block.getParameters(); saveModel.setBlock(block); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java index 8e472c626f2..5657f584805 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java +++ 
b/integration/src/main/java/ai/djl/integration/tests/training/OptimizerTest.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.core.Linear; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; @@ -46,7 +47,7 @@ public void testSgd() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(sgd) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -78,7 +79,7 @@ public void testSgdWithMomentum() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -118,7 +119,7 @@ public void testNag() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -147,7 +148,7 @@ public void testAdam() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -176,7 +177,7 @@ public void testAdagrad() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) 
.optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -209,7 +210,7 @@ public void testRMSProp() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -243,7 +244,7 @@ public void testRMSPropAlex() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); Block block = Linear.builder().setUnits(CHANNELS).build(); @@ -273,7 +274,7 @@ public void testAdadelta() { Device[] devices = Device.getDevices(1); TrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES) + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT) .optOptimizer(optim) .optDevices(devices); diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index ee8afeba541..54c58bf1a01 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -24,6 +24,8 @@ import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Files; @@ -33,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -122,12 +125,16 @@ public void load(Path modelPath, String 
prefix, Map options) /** {@inheritDoc} */ @Override public Trainer newTrainer(TrainingConfig trainingConfig) { - Initializer initializer = trainingConfig.getInitializer(); + PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers(); if (block == null) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } - block.setInitializer(initializer); + for (Pair<Initializer, Predicate<Parameter>> pair : initializer) { + if (pair.getKey() != null && pair.getValue() != null) { + block.setInitializer(pair.getKey(), pair.getValue()); + } + } return new Trainer(this, trainingConfig); } diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 50254346689..12adaf04b5c 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -299,6 +299,8 @@ private static Parameter.Type inferType(String name) { return Parameter.Type.RUNNING_MEAN; } else if (name.endsWith("moving_var") || name.endsWith("running_var")) { return Parameter.Type.RUNNING_VAR; + } else if (name.endsWith("weight")) { + return Parameter.Type.WEIGHT; } return Parameter.Type.OTHER; } diff --git a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java index 0ba9177404f..299136587c7 100644 --- a/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java +++ b/mxnet/mxnet-engine/src/test/java/ai/djl/mxnet/integration/MxGradientCollectorIntegrationTest.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Blocks; +import ai.djl.nn.Parameter; import ai.djl.testing.Assertions; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.GradientCollector; @@
-38,7 +39,7 @@ public void testMxAutograd() { try (Trainer trainer = model.newTrainer( new DefaultTrainingConfig(Loss.l2Loss()) - .optInitializer(Initializer.ONES))) { + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT))) { try (GradientCollector gradCol = trainer.newGradientCollector()) { NDArray lhs = manager.create(new float[] {6, -9, -12, 15, 0, 4}, new Shape(2, 3)); diff --git a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java index d855bb90ced..74fde4bf89a 100644 --- a/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java +++ b/mxnet/mxnet-model-zoo/src/test/java/ai/djl/mxnet/integration/MxSymbolBlockTest.java @@ -25,6 +25,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; import ai.djl.nn.SequentialBlock; import ai.djl.nn.SymbolBlock; import ai.djl.nn.core.Linear; @@ -85,7 +86,7 @@ public void trainWithNewParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { model.getBlock().clear(); try (Trainer trainer = model.newTrainer(config)) { @@ -113,7 +114,7 @@ public void trainWithExistParam() throws IOException, ModelNotFoundException, MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { try (Trainer trainer = model.newTrainer(config)) { NDManager manager = trainer.getManager(); @@ -140,7 +141,7 @@ public void trainWithCustomLayer() throws IOException, ModelNotFoundException, 
MalformedModelException { TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) - .optInitializer(Initializer.ONES); + .optInitializer(Initializer.ONES, Parameter.Type.WEIGHT); try (Model model = MxModelZoo.MLP.loadModel()) { NDManager manager = model.getNDManager(); SymbolBlock mlp = (SymbolBlock) model.getBlock(); @@ -149,7 +150,6 @@ public void trainWithCustomLayer() newMlp.add(mlp); Linear linear = Linear.builder().setUnits(10).build(); - linear.setInitializer(Initializer.ONES); newMlp.add(linear); model.setBlock(newMlp); diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 51ec90497ac..3d56947e34d 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -17,10 +17,13 @@ import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.ndarray.types.DataType; +import ai.djl.nn.Parameter; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.file.Files; @@ -28,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Predicate; import java.util.stream.Collectors; /** @@ -125,12 +129,16 @@ private Path findModelFile(String prefix) { /** {@inheritDoc} */ @Override public Trainer newTrainer(TrainingConfig trainingConfig) { - Initializer initializer = trainingConfig.getInitializer(); + PairList<Initializer, Predicate<Parameter>> initializer = trainingConfig.getInitializers(); if (block == null) { throw new IllegalStateException( "You must set a block for the model before creating a new trainer"); } - block.setInitializer(initializer); + for (Pair<Initializer, Predicate<Parameter>> pair :
initializer) { + if (pair.getKey() != null && pair.getValue() != null) { + block.setInitializer(pair.getKey(), pair.getValue()); + } + } return new Trainer(this, trainingConfig); }