Skip to content

Commit

Permalink
Ensure GradientCollector can clear gradients (deepjavalibrary#2101)
Browse files Browse the repository at this point in the history
It also adds a new function GradientCollector.zeroGradients().

As part of this, it adds a new feature to get the arrays managed by an NDManager
or an NDResource. This is so all arrays can be found and have their gradients
cleared when the PtGradientCollector is started.

It also means that the System NDManagers are modified to now begin tracking
resources. As most things created under them are BaseNDManagers, it shouldn't be
a big issue. Then, operations that don't make sense under the SystemNDManagers,
such as tempAttach or close, have been changed to throw an exception rather than
work silently.
  • Loading branch information
zachgk authored Oct 25, 2022
1 parent 0835d0d commit 0972264
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 13 deletions.
43 changes: 32 additions & 11 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
import java.nio.ShortBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** {@code BaseNDManager} is the default implementation of {@link NDManager}. */
public abstract class BaseNDManager implements NDManager {
Expand Down Expand Up @@ -298,6 +301,30 @@ public Device getDevice() {
return device;
}

/** {@inheritDoc} */
@Override
public List<NDArray> getManagedArrays() {
return Stream.concat(
// Main resources
resources.values().stream()
.flatMap(
r -> {
if (r instanceof NDResource) {
return ((NDResource) r)
.getResourceNDArrays().stream();
} else if (r instanceof NDManager) {
return ((NDManager) r).getManagedArrays().stream();
} else {
return Stream.empty();
}
}),

// Temp resouces
tempResources.values().stream()
.flatMap(tr -> tr.resource.getResourceNDArrays().stream()))
.collect(Collectors.toList());
}

/** {@inheritDoc} */
@Override
public String toString() {
Expand All @@ -315,9 +342,6 @@ public String toString() {
/** {@inheritDoc} */
@Override
public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
if (this instanceof SystemNDManager) {
return;
}
if (capped.get()) {
throw new IllegalStateException("NDManager is capped for addition of resources.");
}
Expand All @@ -327,9 +351,6 @@ public synchronized void attachInternal(String resourceId, AutoCloseable resourc
/** {@inheritDoc} */
@Override
public synchronized void attachUncappedInternal(String resourceId, AutoCloseable resource) {
if (this instanceof SystemNDManager) {
return;
}
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
}
Expand All @@ -356,7 +377,8 @@ public synchronized void attachUncappedInternal(String resourceId, AutoCloseable
public void tempAttachInternal(
NDManager originalManager, String resourceId, NDResource resource) {
if (this instanceof SystemNDManager) {
return;
throw new IllegalStateException(
"System manager cannot be temp attached because it can't be closed..");
}
if (closed.get()) {
throw new IllegalStateException("NDManager has been closed already.");
Expand All @@ -367,9 +389,6 @@ public void tempAttachInternal(
/** {@inheritDoc} */
@Override
public synchronized void detachInternal(String resourceId) {
if (this instanceof SystemNDManager) {
return;
}
if (closed.get()) {
// This may happen in the middle of BaseNDManager.close()
return;
Expand Down Expand Up @@ -400,7 +419,9 @@ public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
@Override
public void close() {
if (this instanceof SystemNDManager) {
return;
throw new IllegalStateException(
"The SystemNDManager can not be closed. It is global and lives for the duration"
+ " of the process");
}
if (!closed.getAndSet(true)) {
for (AutoCloseable closeable : resources.values()) {
Expand Down
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
Expand Down Expand Up @@ -4567,6 +4569,12 @@ default NDArray countNonzero(int axis) {
*/
NDArray erfinv();

    /** {@inheritDoc} */
    @Override
    default List<NDArray> getResourceNDArrays() {
        // An NDArray is itself the only array it contains
        return Collections.singletonList(this);
    }

/**
* Returns an internal representative of Native {@code NDArray}.
*
Expand Down
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
Expand Down Expand Up @@ -269,6 +270,12 @@ public NDManager getManager() {
return head().getManager();
}

    /** {@inheritDoc} */
    @Override
    public List<NDArray> getResourceNDArrays() {
        // An NDList is itself a List<NDArray>, so it can be returned directly
        return this;
    }

/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
Expand Down
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.List;

/**
* NDArray managers are used to create <I>NDArrays</I> (n-dimensional array on native engine).
Expand Down Expand Up @@ -1497,6 +1498,13 @@ default NDArray truncatedNormal(
*/
Device getDevice();

    /**
     * Returns all {@link NDArray}s managed by this manager (including recursively).
     *
     * <p>This includes arrays attached directly to this manager as well as arrays attached to any
     * of its descendant managers.
     *
     * @return all {@link NDArray}s managed by this manager (including recursively)
     */
    List<NDArray> getManagedArrays();

/**
* Attaches a resource to this {@code NDManager}.
*
Expand Down
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
*/
package ai.djl.ndarray;

import java.util.List;

/** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */
public interface NDResource extends AutoCloseable {

Expand All @@ -22,6 +24,13 @@ public interface NDResource extends AutoCloseable {
*/
NDManager getManager();

    /**
     * Returns the {@link NDArray} or {@link NDArray}s contained within this resource.
     *
     * <p>This allows every array tracked by an {@link NDManager} to be enumerated, e.g. so their
     * gradients can be cleared.
     *
     * @return the {@link NDArray} or {@link NDArray}s contained within this resource
     */
    List<NDArray> getResourceNDArrays();

/**
* Attaches this {@link NDResource} to the specified {@link NDManager}.
*
Expand Down
3 changes: 3 additions & 0 deletions api/src/main/java/ai/djl/training/GradientCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public interface GradientCollector extends AutoCloseable {
*/
void backward(NDArray target);

    /**
     * Sets all the gradients within the engine to zero.
     *
     * <p>Gradients computed by {@code backward} otherwise persist; zeroing them prevents values
     * from one backward pass accumulating into the next.
     */
    void zeroGradients();

/** {@inheritDoc} */
@Override
void close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;

/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */
public final class PassthroughNDManager implements NDManager {
Expand Down Expand Up @@ -241,6 +243,12 @@ public Device getDevice() {
return Device.cpu();
}

    /** {@inheritDoc} */
    @Override
    public List<NDArray> getManagedArrays() {
        // The passthrough manager does not track resources, so it never manages any arrays
        return Collections.emptyList();
    }

/** {@inheritDoc} */
@Override
public void attachInternal(String resourceId, AutoCloseable resource) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ai.djl.training.GradientCollector;

/** {@code MxGradientCollector} is the MXNet implementation of {@link GradientCollector}. */
public class MxGradientCollector implements GradientCollector {
public final class MxGradientCollector implements GradientCollector {

/**
* Constructs an {@code MxGradientCollector} and enables training data collection for
Expand Down Expand Up @@ -116,4 +116,15 @@ public void backward(NDArray array) {
private void backward(NDArray array, boolean retainGraph) {
JnaUtils.autogradBackward(new NDList(array), retainGraph ? 1 : 0);
}

/** {@inheritDoc} */
@Override
public void zeroGradients() {
NDManager systemManager = MxNDManager.getSystemManager();
for (NDArray array : systemManager.getManagedArrays()) {
if (array.hasGradient()) {
array.getGradient().subi(array.getGradient());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.GradientCollector;

/** {@code PtGradientCollector} is the PyTorch implementation of {@link GradientCollector}. */
public class PtGradientCollector implements GradientCollector {
public final class PtGradientCollector implements GradientCollector {

private boolean gradModel;

    /**
     * Constructs a new {@code PtGradientCollector} instance.
     *
     * <p>Saves the prior JNI gradient mode (presumably restored when this collector is closed —
     * close() is not fully visible here, TODO confirm), enables gradient tracking, and zeroes any
     * existing gradients so values from earlier collectors do not accumulate.
     */
    public PtGradientCollector() {
        gradModel = JniUtils.isGradMode();
        JniUtils.setGradMode(true);
        zeroGradients();
    }

/** {@inheritDoc} */
Expand Down Expand Up @@ -54,6 +56,17 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c
JniUtils.backward((PtNDArray) target, (PtNDArray) grad, keepGraph, createGraph);
}

/** {@inheritDoc} */
@Override
public void zeroGradients() {
NDManager systemManager = PtNDManager.getSystemManager();
for (NDArray array : systemManager.getManagedArrays()) {
if (array.hasGradient()) {
array.getGradient().subi(array.getGradient());
}
}
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -80,6 +81,46 @@ public void testAutograd() {
}
}

@Test
public void testZeroGradients() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray a = manager.create(0.0f);
a.setRequiresGradient(true);

try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray b = a.mul(2);

// Gradients are initially zero
Assert.assertEquals(a.getGradient().getFloat(), 0.0f);

// Gradients are updated by backwards
gc.backward(b);
Assert.assertEquals(a.getGradient().getFloat(), 2.0f);

// Gradients are cleared by zeroGradients
gc.zeroGradients();
Assert.assertEquals(a.getGradient().getFloat(), 0.0f);
}
}
}

/** Tests that the gradients do not accumulate when closing the gradient collector. */
@Test
public void testClearGradients() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray a = manager.create(0.0f);
a.setRequiresGradient(true);

for (int i = 0; i < 3; i++) {
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray b = a.mul(2);
gc.backward(b);
}
Assert.assertEquals(a.getGradient().getFloat(), 2.0f);
}
}
}

@Test
public void testFreezeParameters() {
try (Model model = Model.newInstance("model")) {
Expand Down

0 comments on commit 0972264

Please sign in to comment.