Skip to content

Commit

Permalink
[api] Adds IdentityBlockFactory for demo/test purpose (#1854)
Browse files Browse the repository at this point in the history
Change-Id: I51eaedafec628d326af66f61ffc1f7c65a7c6b52
  • Loading branch information
frankfliu authored Aug 2, 2022
1 parent f1ebbe8 commit 564d710
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 3 deletions.
30 changes: 30 additions & 0 deletions api/src/main/java/ai/djl/nn/IdentityBlockFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.Model;

import java.nio.file.Path;
import java.util.Map;

/** A {@link BlockFactory} class that creates IdentityBlock. */
public class IdentityBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
return Blocks.identityBlock();
}
}
34 changes: 34 additions & 0 deletions api/src/test/java/ai/djl/nn/BlockFactoryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.Model;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.nio.file.Path;
import java.nio.file.Paths;

public class BlockFactoryTest {

@Test
public void testIdentityBlockFactory() {
IdentityBlockFactory factory = new IdentityBlockFactory();
try (Model model = Model.newInstance("identity")) {
Path path = Paths.get("build");
Block block = factory.newBlock(model, path, null);
Assert.assertEquals(((LambdaBlock) block).getName(), "identity");
}
}
}
15 changes: 15 additions & 0 deletions api/src/test/java/ai/djl/nn/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains tests for {@link ai.djl.nn}. */
package ai.djl.nn;
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.translate.ArgumentsUtil;

import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

Expand All @@ -34,9 +35,16 @@ public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
int height = ArgumentsUtil.intValue(arguments, "height", 28);
int output = ArgumentsUtil.intValue(arguments, "output", 10);
int input = width * height;
int[] hidden =
((List<Double>) arguments.get("hidden"))
.stream().mapToInt(Double::intValue).toArray();
Object hiddenValue = arguments.get("hidden");
int[] hidden;
if (hiddenValue == null) {
hidden = new int[] {128, 64};
} else if (hiddenValue instanceof List) {
hidden = ((List<Double>) hiddenValue).stream().mapToInt(Double::intValue).toArray();
} else {
String[] v = ((String) hiddenValue).split(",");
hidden = Arrays.stream(v).mapToInt(Integer::parseInt).toArray();
}

return new Mlp(input, output, hidden);
}
Expand Down

0 comments on commit 564d710

Please sign in to comment.