diff --git a/api/src/main/java/ai/djl/training/loss/ElasticNetWeightDecay.java b/api/src/main/java/ai/djl/training/loss/ElasticNetWeightDecay.java new file mode 100644 index 00000000000..4a20f16c132 --- /dev/null +++ b/api/src/main/java/ai/djl/training/loss/ElasticNetWeightDecay.java @@ -0,0 +1,101 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.training.loss; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; + +/** + * {@code ElasticWeightDecay} calculates L1+L2 penalty of a set of parameters. Used for + * regularization. + * + *
L loss is defined as \(L = \lambda_1 \sum_i \vert W_i\vert + \lambda_2 \sum_i {W_i}^2\). + */ +public class ElasticNetWeightDecay extends Loss { + + private float lambda1; + private float lambda2; + private NDList parameters; + + /** + * Calculates Elastic Net weight decay for regularization. + * + * @param parameters holds the model weights that will be penalized + */ + public ElasticNetWeightDecay(NDList parameters) { + this("ElasticNetWeightDecay", parameters); + } + + /** + * Calculates Elastic Net weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + */ + public ElasticNetWeightDecay(String name, NDList parameters) { + this(name, parameters, 1); + } + + /** + * Calculates Elastic Net weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + * @param lambda the weight to apply to the penalty value, default 1 (both L1 and L2) + */ + public ElasticNetWeightDecay(String name, NDList parameters, float lambda) { + super(name); + this.lambda1 = lambda; + this.lambda2 = lambda; + this.parameters = parameters; + } + + /** + * Calculates Elastic Net weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + * @param lambda1 the weight to apply to the L1 penalty value, default 1 + * @param lambda2 the weight to apply to the L2 penalty value, default 1 + */ + public ElasticNetWeightDecay(String name, NDList parameters, float lambda1, float lambda2) { + super(name); + this.lambda1 = lambda1; + this.lambda2 = lambda2; + this.parameters = parameters; + } + + private NDArray l1(NDArray w) { + return ((w.abs()).sum()); + } + + private NDArray l2(NDArray w) { + return ((w.square()).sum()); + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList label, NDList prediction) { + + NDManager manager = parameters.getManager(); + NDArray sum1 = manager.create(0.0f); + NDArray sum2 = manager.create(0.0f); + for (NDArray wi : parameters) { + sum1.addi(l1(wi)); + sum2.addi(l2(wi)); + } + return sum1.muli(lambda1).addi(sum2.muli(lambda2)); + } +} diff --git a/api/src/main/java/ai/djl/training/loss/L1WeightDecay.java b/api/src/main/java/ai/djl/training/loss/L1WeightDecay.java new file mode 100644 index 00000000000..1851303b3b9 --- /dev/null +++ b/api/src/main/java/ai/djl/training/loss/L1WeightDecay.java @@ -0,0 +1,77 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.training.loss; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; + +/** + * {@code L1WeightDecay} calculates L1 penalty of a set of parameters. Used for regularization. + * + *
L1 loss is defined as \(L1 = \lambda \sum_i \vert W_i\vert\). + */ +public class L1WeightDecay extends Loss { + + private float lambda; + private NDList parameters; + + /** + * Calculates L1 weight decay for regularization. + * + * @param parameters holds the model weights that will be penalized + */ + public L1WeightDecay(NDList parameters) { + this("L1WeightDecay", parameters); + } + + /** + * Calculates L1 weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + */ + public L1WeightDecay(String name, NDList parameters) { + this(name, parameters, 1); + } + + /** + * Calculates L1 weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + * @param lambda the weight to apply to the penalty value, default 1 + */ + public L1WeightDecay(String name, NDList parameters, float lambda) { + super(name); + this.lambda = lambda; + this.parameters = parameters; + } + + private NDArray l1(NDArray w) { + return ((w.abs()).sum()); + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList label, NDList prediction) { + + NDManager manager = parameters.getManager(); + NDArray sum = manager.create(0.0f); + for (NDArray wi : parameters) { + sum.addi(l1(wi)); + } + return sum.muli(lambda); + } +} diff --git a/api/src/main/java/ai/djl/training/loss/L2WeightDecay.java b/api/src/main/java/ai/djl/training/loss/L2WeightDecay.java new file mode 100644 index 00000000000..496d70b32bb --- /dev/null +++ b/api/src/main/java/ai/djl/training/loss/L2WeightDecay.java @@ -0,0 +1,77 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.training.loss; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; + +/** + * {@code L2WeightDecay} calculates L2 penalty of a set of parameters. Used for regularization. + * + *
L2 loss is defined by \(L2 = \lambda \sum_i {W_i}^2\). + */ +public class L2WeightDecay extends Loss { + + private float lambda; + private NDList parameters; + + /** + * Calculates L2 weight decay for regularization. + * + * @param parameters holds the model weights that will be penalized + */ + public L2WeightDecay(NDList parameters) { + this("L2WeightDecay", parameters); + } + + /** + * Calculates L2 weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + */ + public L2WeightDecay(String name, NDList parameters) { + this(name, parameters, 1); + } + + /** + * Calculates L2 weight decay for regularization. + * + * @param name the name of the penalty + * @param parameters holds the model weights that will be penalized + * @param lambda the weight to apply to the penalty value, default 1 + */ + public L2WeightDecay(String name, NDList parameters, float lambda) { + super(name); + this.lambda = lambda; + this.parameters = parameters; + } + + private NDArray l2(NDArray w) { + return ((w.square()).sum()); + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList label, NDList prediction) { + + NDManager manager = parameters.getManager(); + NDArray sum = manager.create(0.0f); + for (NDArray wi : parameters) { + sum.addi(l2(wi)); + } + return sum.muli(lambda); + } +} diff --git a/api/src/main/java/ai/djl/training/loss/Loss.java b/api/src/main/java/ai/djl/training/loss/Loss.java index 6185e98ede4..e0079f0a5d2 100644 --- a/api/src/main/java/ai/djl/training/loss/Loss.java +++ b/api/src/main/java/ai/djl/training/loss/Loss.java @@ -239,6 +239,120 @@ public static HingeLoss hingeLoss(String name, int margin, float weight) { return new HingeLoss(name, margin, weight); } + /** + * Returns a new instance of {@link L1WeightDecay} with default weight and name. + * + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L1WeightDecay} + */ + public static L1WeightDecay l1WeightedDecay(NDList parameters) { + return new L1WeightDecay(parameters); + } + + /** + * Returns a new instance of {@link L1WeightDecay} with default weight. + * + * @param name the name of the weight decay + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L1WeightDecay} + */ + public static L1WeightDecay l1WeightedDecay(String name, NDList parameters) { + return new L1WeightDecay(name, parameters); + } + + /** + * Returns a new instance of {@link L1WeightDecay}. + * + * @param name the name of the weight decay + * @param weight the weight to apply on weight decay value, default 1 + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L1WeightDecay} + */ + public static L1WeightDecay l1WeightedDecay(String name, float weight, NDList parameters) { + return new L1WeightDecay(name, parameters, weight); + } + + /** + * Returns a new instance of {@link L2WeightDecay} with default weight and name. + * + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L2WeightDecay} + */ + public static L2WeightDecay l2WeightedDecay(NDList parameters) { + return new L2WeightDecay(parameters); + } + + /** + * Returns a new instance of {@link L2WeightDecay} with default weight. + * + * @param name the name of the weight decay + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L2WeightDecay} + */ + public static L2WeightDecay l2WeightedDecay(String name, NDList parameters) { + return new L2WeightDecay(name, parameters); + } + + /** + * Returns a new instance of {@link L2WeightDecay}. + * + * @param name the name of the weight decay + * @param weight the weight to apply on weight decay value, default 1 + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link L2WeightDecay} + */ + public static L2WeightDecay l2WeightedDecay(String name, float weight, NDList parameters) { + return new L2WeightDecay(name, parameters, weight); + } + + /** + * Returns a new instance of {@link ElasticNetWeightDecay} with default weight and name. + * + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link ElasticNetWeightDecay} + */ + public static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters) { + return new ElasticNetWeightDecay(parameters); + } + + /** + * Returns a new instance of {@link ElasticNetWeightDecay} with default weight. + * + * @param name the name of the weight decay + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link ElasticNetWeightDecay} + */ + public static ElasticNetWeightDecay elasticNetWeightedDecay(String name, NDList parameters) { + return new ElasticNetWeightDecay(name, parameters); + } + + /** + * Returns a new instance of {@link ElasticNetWeightDecay}. + * + * @param name the name of the weight decay + * @param weight the weight to apply on weight decay values, default 1 + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link ElasticNetWeightDecay} + */ + public static ElasticNetWeightDecay elasticNetWeightedDecay( + String name, float weight, NDList parameters) { + return new ElasticNetWeightDecay(name, parameters, weight); + } + + /** + * Returns a new instance of {@link ElasticNetWeightDecay}. + * + * @param name the name of the weight decay + * @param weight1 the weight to apply on weight decay L1 value, default 1 + * @param weight2 the weight to apply on weight decay L2 value, default 1 + * @param parameters holds the model weights that will be penalized + * @return a new instance of {@link ElasticNetWeightDecay} + */ + public static ElasticNetWeightDecay elasticNetWeightedDecay( + String name, float weight1, float weight2, NDList parameters) { + return new ElasticNetWeightDecay(name, parameters, weight1, weight2); + } + /** {@inheritDoc} */ @Override public void addAccumulator(String key) { diff --git a/integration/src/main/java/ai/djl/integration/tests/training/WeightDecayTest.java b/integration/src/main/java/ai/djl/integration/tests/training/WeightDecayTest.java new file mode 100644 index 00000000000..1be78244e0c --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/WeightDecayTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.integration.tests.training; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.training.loss.ElasticNetWeightDecay; +import ai.djl.training.loss.L1WeightDecay; +import ai.djl.training.loss.L2WeightDecay; +import ai.djl.training.loss.Loss; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class WeightDecayTest { + + @Test + public void l1DecayTest() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5}); // 15 + NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1}); // 5 + // Not used + NDArray pred = manager.create(new float[] {}); + NDArray label = manager.create(new float[] {}); + // r = 2*(15 + 5) = 40 + L1WeightDecay decay = + Loss.l1WeightedDecay("", 2.0f, new NDList(parameters1, parameters2)); + Assert.assertEquals( + decay.evaluate(new NDList(label), new NDList(pred)).getFloat(), 40.0f); + } + } + + @Test + public void l2DecayTest() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5}); // 55 + NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1}); // 5 + // Not used + NDArray pred = manager.create(new float[] {}); + NDArray label = manager.create(new float[] {}); + // r = 2*(55 + 5) = 120 + L2WeightDecay decay = + Loss.l2WeightedDecay("", 2.0f, new NDList(parameters1, parameters2)); + Assert.assertEquals( + decay.evaluate(new NDList(label), new NDList(pred)).getFloat(), 120.0f); + } + } + + @Test + public void elasticNetDecayTest() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5}); + NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1}); + // Not used + NDArray pred = manager.create(new float[] {}); + NDArray label = manager.create(new float[] {}); + // r = L1 + L2 = 2*20 + 1*60 = 100 + ElasticNetWeightDecay decay = + Loss.elasticNetWeightedDecay( + "", 2.0f, 1.0f, new NDList(parameters1, parameters2)); + Assert.assertEquals( + decay.evaluate(new NDList(label), new NDList(pred)).getFloat(), 100.0f); + } + } +}