From 770cd3fbdc4e98f67d04da03eb44b861cc016104 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 17 Aug 2022 15:11:33 -0700 Subject: [PATCH] [api] Adds NDArray normalize() operator (#1924) Change-Id: I877cd924ecada008a53b852092dc224fa1cfa70b --- api/src/main/java/ai/djl/ndarray/NDArray.java | 78 +++++++++++++++++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 6 ++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 6 ++ .../java/ai/djl/pytorch/jni/JniUtils.java | 6 ++ .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 + ...ytorch_jni_PyTorchLibrary_nn_functional.cc | 18 +++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 6 ++ .../tests/ndarray/NDArrayOtherOpTest.java | 16 ++++ 9 files changed, 144 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index b4f202a4ec74..0b5494840cfd 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -2717,6 +2717,84 @@ default NDArray mean(int[] axes) { */ NDArray mean(int[] axes, boolean keepDims); + /** + * Performs Lp normalization of the array over specified dimension. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
+     * jshell> array;
+     * ND: (2, 2) cpu() float32
+     * [[1., 2., 3.],
+     *  [4., 5., 6.],
+     * ]
+     * jshell> array.normalize();
+     * ND: (2, 3) cpu() float32
+     * [[0.2673, 0.5345, 0.8018],
+     *  [0.4558, 0.5698, 0.6838],
+     * ]
+     * 
+ * + * @return the normalized {@code NDArray} + */ + default NDArray normalize() { + return normalize(2, 1, 1e-12); + } + + /** + * Performs Lp normalization of the array over specified dimension. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
+     * jshell> array;
+     * ND: (2, 2) cpu() float32
+     * [[1., 2., 3.],
+     *  [4., 5., 6.],
+     * ]
+     * jshell> array.normalize(2, 1);
+     * ND: (2, 3) cpu() float32
+     * [[0.2673, 0.5345, 0.8018],
+     *  [0.4558, 0.5698, 0.6838],
+     * ]
+     * 
+ * + * @param exponent the exponent value in the norm formulation + * @param dim the dimension to reduce + * @return the normalized {@code NDArray} + */ + default NDArray normalize(double exponent, long dim) { + return normalize(exponent, dim, 1e-12); + } + + /** + * Performs Lp normalization of the array over specified dimension. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
+     * jshell> array;
+     * ND: (2, 2) cpu() float32
+     * [[1., 2., 3.],
+     *  [4., 5., 6.],
+     * ]
+     * jshell> array.normalize(2, 1, 1e-12);
+     * ND: (2, 3) cpu() float32
+     * [[0.2673, 0.5345, 0.8018],
+     *  [0.4558, 0.5698, 0.6838],
+     * ]
+     * 
+ * + * @param exponent the exponent value in the norm formulation + * @param dim the dimension to reduce + * @param eps the small value to avoid division by zero + * @return the normalized {@code NDArray} + */ + NDArray normalize(double exponent, long dim, double eps); + /** * Rotates an array by 90 degrees in the plane specified by axes. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 23b788cb39f8..519e1cafc9bf 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -768,6 +768,12 @@ public NDArray mean(int[] axes, boolean keepDims) { return getAlternativeArray().mean(axes, keepDims); } + /** {@inheritDoc} */ + @Override + public NDArray normalize(double p, long dim, double eps) { + return getAlternativeArray().normalize(p, dim, eps); + } + /** {@inheritDoc} */ @Override public NDArray rotate90(int times, int[] axes) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 2f75d565f0ee..0974876b2a92 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -1019,6 +1019,12 @@ public NDArray mean(int[] axes, boolean keepDims) { return manager.invoke("_npi_mean", this, params); } + /** {@inheritDoc} */ + @Override + public NDArray normalize(double p, long dim, double eps) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray rotate90(int times, int[] axes) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 01cb989911e2..562d1d38c6c9 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -932,6 +932,12 @@ public PtNDArray mean(int[] axes, boolean keepDims) { return JniUtils.mean(this, axes[0], keepDims); } + /** {@inheritDoc} */ + @Override + public PtNDArray normalize(double p, long dim, double eps) { + return JniUtils.normalize(this, p, dim, eps); + } + /** {@inheritDoc} */ @Override public PtNDArray rotate90(int times, int[] axes) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 101eb448e124..69b73518eacc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1289,6 +1289,12 @@ public static PtNDArray layerNorm( eps)); } + public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) { + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); + } + public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { return new PtNDArray( ndArray.getManager(), diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 10da7843d540..57ba028e44a0 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -400,6 +400,8 @@ native long torchNNConvNd( native long torchNNDropout(long inputHandle, double probability, boolean isTrain); + native long torchNNNormalize(long inputHandle, double p, long dim, double eps); + native long torchNNLayerNorm( long inputHandle, long[] normalizedShape, diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc index 503342b0aa3f..3f1bcfd585e3 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc @@ -168,6 +168,24 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm( API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNNormalize( + JNIEnv* env, jobject jthis, jlong jinput, jdouble jp, jlong jdim, jdouble jeps) { + API_BEGIN() +#if defined(__ANDROID__) + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "Normalize is not supported on Android."); + return 0; +#else + const auto* tensor_ptr = reinterpret_cast(jinput); + auto options = torch::nn::functional::NormalizeFuncOptions(); + options.p(jp); + options.dim(jdim); + options.eps(jeps); + const auto* result_ptr = new torch::Tensor(torch::nn::functional::normalize(*tensor_ptr, options)); + return reinterpret_cast(result_ptr); +#endif + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNDropout( JNIEnv* env, jobject jthis, jlong jinput, jdouble probability, jboolean jtraining) { API_BEGIN() diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 9afcf7575830..55db30e84034 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -1034,6 +1034,12 @@ public NDArray mean(int[] axes, boolean keepDims) { } } + /** {@inheritDoc} */ + @Override + public NDArray normalize(double p, long dim, double eps) { + throw new UnsupportedOperationException("Not implemented yet."); + } + /** {@inheritDoc} */ @Override public NDArray rotate90(int times, int[] axes) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 35606c8482a5..6c7f174a46fe 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; import ai.djl.ndarray.LazyNDArray; @@ -946,6 +947,21 @@ public void testOneHot() { } } + @Test + public void testNormalize() { + try (NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) { + float[][] buf = { + {0.2673f, 0.5345f, 0.8018f}, + {0.4558f, 0.5698f, 0.6838f} + }; + float[] data = {1, 2, 3, 4, 5, 6}; + NDArray x = manager.create(data, new Shape(2, 3)); + NDArray expected = manager.create(buf); + NDArray ret = x.normalize(); + Assertions.assertAlmostEquals(ret, expected); + } + } + @Test public void testStopGradient() { try (NDManager manager = NDManager.newBaseManager()) {