
Commit

[api] Adds NDArray normalize() operator (deepjavalibrary#1924)
Change-Id: I877cd924ecada008a53b852092dc224fa1cfa70b
frankfliu authored and patins1 committed Aug 26, 2022
1 parent dac3485 commit 770cd3f
Showing 9 changed files with 144 additions and 0 deletions.
78 changes: 78 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
@@ -2717,6 +2717,84 @@ default NDArray mean(int[] axes) {
*/
NDArray mean(int[] axes, boolean keepDims);

/**
* Performs L2 normalization of the array over dimension 1, equivalent to {@code normalize(2, 1, 1e-12)}.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
* jshell&gt; array;
* ND: (2, 3) cpu() float32
* [[1., 2., 3.],
*  [4., 5., 6.],
* ]
* jshell&gt; array.normalize();
* ND: (2, 3) cpu() float32
* [[0.2673, 0.5345, 0.8018],
*  [0.4558, 0.5698, 0.6838],
* ]
* </pre>
*
* @return the normalized {@code NDArray}
*/
default NDArray normalize() {
return normalize(2, 1, 1e-12);
}

/**
* Performs Lp normalization of the array over the specified dimension, using a default eps of 1e-12.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
* jshell&gt; array;
* ND: (2, 3) cpu() float32
* [[1., 2., 3.],
*  [4., 5., 6.],
* ]
* jshell&gt; array.normalize(2, 1);
* ND: (2, 3) cpu() float32
* [[0.2673, 0.5345, 0.8018],
*  [0.4558, 0.5698, 0.6838],
* ]
* </pre>
*
* @param exponent the exponent value in the norm formulation
* @param dim the dimension to reduce
* @return the normalized {@code NDArray}
*/
default NDArray normalize(double exponent, long dim) {
return normalize(exponent, dim, 1e-12);
}

/**
* Performs Lp normalization of the array over the specified dimension.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
* jshell&gt; array;
* ND: (2, 3) cpu() float32
* [[1., 2., 3.],
*  [4., 5., 6.],
* ]
* jshell&gt; array.normalize(2, 1, 1e-12);
* ND: (2, 3) cpu() float32
* [[0.2673, 0.5345, 0.8018],
*  [0.4558, 0.5698, 0.6838],
* ]
* </pre>
*
* @param exponent the exponent value in the norm formulation
* @param dim the dimension to reduce
* @param eps the small value to avoid division by zero
* @return the normalized {@code NDArray}
*/
NDArray normalize(double exponent, long dim, double eps);

/**
* Rotates an array by 90 degrees in the plane specified by axes.
*
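For readers of the API change above: normalize(exponent, dim, eps) computes x / max(||x||_p, eps) along the given dimension. As a minimal, self-contained sketch (not part of this commit; the class name is illustrative and it assumes only the standard NDArray operators abs, pow, sum, maximum, and div), the same result can be reproduced like this:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class NormalizeSketch {

    /** Reference computation: x / max(||x||_p along dim, eps). */
    static NDArray normalizeReference(NDArray x, double p, int dim, double eps) {
        // Lp norm along the chosen dimension, keeping it so the division broadcasts
        NDArray norm = x.abs().pow(p).sum(new int[] {dim}, true).pow(1.0 / p);
        // clamp the norm from below so all-zero rows do not divide by zero
        return x.div(norm.maximum(eps));
    }

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
            // should match array.normalize(), i.e. p = 2 along dim 1 with eps = 1e-12
            System.out.println(normalizeReference(array, 2, 1, 1e-12));
        }
    }
}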
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
@@ -768,6 +768,12 @@ public NDArray mean(int[] axes, boolean keepDims) {
return getAlternativeArray().mean(axes, keepDims);
}

/** {@inheritDoc} */
@Override
public NDArray normalize(double p, long dim, double eps) {
return getAlternativeArray().normalize(p, dim, eps);
}

/** {@inheritDoc} */
@Override
public NDArray rotate90(int times, int[] axes) {
@@ -1019,6 +1019,12 @@ public NDArray mean(int[] axes, boolean keepDims) {
return manager.invoke("_npi_mean", this, params);
}

/** {@inheritDoc} */
@Override
public NDArray normalize(double p, long dim, double eps) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray rotate90(int times, int[] axes) {
@@ -932,6 +932,12 @@ public PtNDArray mean(int[] axes, boolean keepDims) {
return JniUtils.mean(this, axes[0], keepDims);
}

/** {@inheritDoc} */
@Override
public PtNDArray normalize(double p, long dim, double eps) {
return JniUtils.normalize(this, p, dim, eps);
}

/** {@inheritDoc} */
@Override
public PtNDArray rotate90(int times, int[] axes) {
@@ -1289,6 +1289,12 @@ public static PtNDArray layerNorm(
eps));
}

public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps));
}

public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) {
return new PtNDArray(
ndArray.getManager(),
@@ -400,6 +400,8 @@ native long torchNNConvNd(

native long torchNNDropout(long inputHandle, double probability, boolean isTrain);

native long torchNNNormalize(long inputHandle, double p, long dim, double eps);

native long torchNNLayerNorm(
long inputHandle,
long[] normalizedShape,
@@ -168,6 +168,24 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNNormalize(
JNIEnv* env, jobject jthis, jlong jinput, jdouble jp, jlong jdim, jdouble jeps) {
API_BEGIN()
#if defined(__ANDROID__)
env->ThrowNew(ENGINE_EXCEPTION_CLASS, "Normalize is not supported on Android.");
return 0;
#else
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jinput);
auto options = torch::nn::functional::NormalizeFuncOptions();
options.p(jp);
options.dim(jdim);
options.eps(jeps);
const auto* result_ptr = new torch::Tensor(torch::nn::functional::normalize(*tensor_ptr, options));
return reinterpret_cast<uintptr_t>(result_ptr);
#endif
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNDropout(
JNIEnv* env, jobject jthis, jlong jinput, jdouble probability, jboolean jtraining) {
API_BEGIN()
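Taken together with the PtNDArray change above, the PyTorch engine wires the new operator through its usual delegation chain: PtNDArray.normalize calls JniUtils.normalize, which invokes the native PyTorchLibrary.torchNNNormalize binding, and the JNI shim above forwards to torch::nn::functional::normalize with NormalizeFuncOptions (throwing an engine exception on Android, where the operator is not supported).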
@@ -1034,6 +1034,12 @@ public NDArray mean(int[] axes, boolean keepDims) {
}
}

/** {@inheritDoc} */
@Override
public NDArray normalize(double p, long dim, double eps) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray rotate90(int times, int[] axes) {
@@ -12,6 +12,7 @@
*/
package ai.djl.integration.tests.ndarray;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.LazyNDArray;
@@ -946,6 +947,21 @@ public void testOneHot() {
}
}

@Test
public void testNormalize() {
try (NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) {
float[][] buf = {
{0.2673f, 0.5345f, 0.8018f},
{0.4558f, 0.5698f, 0.6838f}
};
float[] data = {1, 2, 3, 4, 5, 6};
NDArray x = manager.create(data, new Shape(2, 3));
NDArray expected = manager.create(buf);
NDArray ret = x.normalize();
Assertions.assertAlmostEquals(ret, expected);
}
}

@Test
public void testStopGradient() {
try (NDManager manager = NDManager.newBaseManager()) {
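The expected values in testNormalize follow directly from the per-row L2 norms of the input: for the row [1, 2, 3] the norm is sqrt(1 + 4 + 9) = sqrt(14) ≈ 3.7417, giving 1/3.7417 ≈ 0.2673, 2/3.7417 ≈ 0.5345, and 3/3.7417 ≈ 0.8018; for [4, 5, 6] the norm is sqrt(16 + 25 + 36) = sqrt(77) ≈ 8.7750, giving ≈ 0.4558, 0.5698, and 0.6838.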
