Skip to content

Commit

Permalink
Remove erroneous random forest application
Browse files Browse the repository at this point in the history
The application was changed to the more accurate softmax_regression (matching
the terminology from the D2L book).

Change-Id: I1f69f005bbe38b125f2709c2988d06c14eebb765
  • Loading branch information
zachgk committed Mar 9, 2021
1 parent a6a2232 commit f8d3f2a
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 18 deletions.
11 changes: 1 addition & 10 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,6 @@ public interface Tabular {
* @see <a href="https://d2l.djl.ai/chapter_linear-networks/softmax-regression.html">The D2L
* chapter introducing this application</a>
*/
Application SOFTMAX_REGRESSION = new Application("tabular/linear_regression");

/**
* This is erroneous because random forest is a technique (not deep learning), not an
* application.
*
* <p>The actual application is likely to be in {@link Tabular}, especially {@link
* #SOFTMAX_REGRESSION}.
*/
Application RANDOM_FOREST = new Application("tabular/random_forest");
Application SOFTMAX_REGRESSION = new Application("tabular/softmax_regression");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
"metadata": {},
"outputs": [],
"source": [
"String modelUrl = \"https://mlrepo.djl.ai/model/tabular/random_forest/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n",
"String modelUrl = \"https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip\";\n",
"Criteria<IrisFlower, Classifications> criteria = Criteria.builder()\n",
" .setTypes(IrisFlower.class, Classifications.class)\n",
" .optModelUrls(modelUrl)\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
package ai.djl.onnxruntime.zoo;

import ai.djl.onnxruntime.engine.OrtEngine;
import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisClassificationModelLoader;
import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisClassificationModelLoader;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelZoo;
import java.util.Collections;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.onnxruntime.zoo.tabular.randomforest;
package ai.djl.onnxruntime.zoo.tabular.softmax_regression;

import ai.djl.Application;
import ai.djl.Application.Tabular;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.Classifications;
Expand All @@ -39,7 +40,7 @@
/** Model loader for onnx iris_flowers models. */
public class IrisClassificationModelLoader extends BaseModelLoader {

private static final Application APPLICATION = Application.Tabular.RANDOM_FOREST;
private static final Application APPLICATION = Tabular.SOFTMAX_REGRESSION;
private static final String GROUP_ID = OrtModelZoo.GROUP_ID;
private static final String ARTIFACT_ID = "iris_flowers";
private static final String VERSION = "0.0.1";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.onnxruntime.zoo.tabular.randomforest;
package ai.djl.onnxruntime.zoo.tabular.softmax_regression;

/** A class holds the iris flower features. */
public class IrisFlower {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
/**
* Contains classes for the classification models in the {@link ai.djl.onnxruntime.zoo.OrtModelZoo}.
*/
package ai.djl.onnxruntime.zoo.tabular.randomforest;
package ai.djl.onnxruntime.zoo.tabular.softmax_regression;
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.tabular.randomforest.IrisFlower;
import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisFlower;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "tabular/random_forest",
"application": "tabular/softmax_regression",
"groupId": "ai.djl.onnxruntime",
"artifactId": "iris_flowers",
"name": "iris_flowers",
Expand Down

0 comments on commit f8d3f2a

Please sign in to comment.