Skip to content

Commit

Permalink
[huggingface] Adds Huggingface ModelZoo
Browse files Browse the repository at this point in the history
Change-Id: Ief99e4b5449c95a75582ded9a9e2549e582387c5
  • Loading branch information
frankfliu committed Sep 2, 2022
1 parent ae05f94 commit 4faa4a7
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 9 deletions.
32 changes: 23 additions & 9 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,24 @@ public static Application of(String path) {
return CV.IMAGE_ENHANCEMENT;
case "nlp":
return NLP.ANY;
case "nlp/question_answer":
return NLP.QUESTION_ANSWER;
case "nlp/text_classification":
return NLP.TEXT_CLASSIFICATION;
case "nlp/sentiment_analysis":
return NLP.SENTIMENT_ANALYSIS;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/fill_mask":
return NLP.FILL_MASK;
case "nlp/machine_translation":
return NLP.MACHINE_TRANSLATION;
case "nlp/multiple_choice":
return NLP.MULTIPLE_CHOICE;
case "nlp/question_answer":
return NLP.QUESTION_ANSWER;
case "nlp/sentiment_analysis":
return NLP.SENTIMENT_ANALYSIS;
case "nlp/text_classification":
return NLP.TEXT_CLASSIFICATION;
case "nlp/text_embedding":
return NLP.TEXT_EMBEDDING;
case "nlp/token_classification":
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -210,6 +214,11 @@ public interface NLP {
/** Any NLP application, including those in {@link NLP}. */
Application ANY = new Application("nlp");

/**
* An application that masks some words in a sentence and predicts which words should replace those masks.
*/
Application FILL_MASK = new Application("nlp/fill_mask");

/**
* An application that takes a reference document and a question about the document and returns
* text answering the question.
Expand All @@ -228,11 +237,16 @@ public interface NLP {
Application TEXT_CLASSIFICATION = new Application("nlp/text_classification");

/**
* An application that classifies text into positive or negative, an specific case of {@link
* An application that classifies text into positive or negative, a specific case of {@link
* #TEXT_CLASSIFICATION}.
*/
Application SENTIMENT_ANALYSIS = new Application("nlp/sentiment_analysis");

/**
* A natural language understanding application that assigns a label to some tokens in a text.
*/
Application TOKEN_CLASSIFICATION = new Application("nlp/token_classification");

/**
* An application that takes a word and returns a feature vector that represents the word.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.Application;
import ai.djl.Application.NLP;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;

import com.google.gson.reflect.TypeToken;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.Writer;
import java.lang.reflect.Type;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;

/**
 * HfModelZoo is a repository that contains HuggingFace models.
 *
 * <p>The list of available models is read from a per-application {@code models.json} index. The
 * index is downloaded (gzip compressed) from the remote repository and cached under the DJL cache
 * directory. The cached copy is used as-is when it is less than one day old or when the {@code
 * offline} system property is {@code true}; otherwise the index is downloaded again. If the
 * download fails, a previously cached (stale) index is used instead of an empty model list.
 */
public class HfModelZoo extends ModelZoo {

    private static final Logger logger = LoggerFactory.getLogger(HfModelZoo.class);

    private static final String REPO = "https://mlrepo.djl.ai/";
    private static final Repository REPOSITORY = Repository.newInstance("Huggingface", REPO);
    public static final String GROUP_ID = "ai.djl.huggingface.pytorch";

    /** Maximum age of the cached index before it is refreshed from the remote repository. */
    private static final long ONE_DAY = Duration.ofDays(1).toMillis();

    /** Shape of the index file: artifact id mapped to that artifact's metadata. */
    private static final Type INDEX_TYPE =
            new TypeToken<Map<String, Map<String, Object>>>() {}.getType();

    /** Registers every model listed in the index of each supported NLP application. */
    HfModelZoo() {
        addModels(NLP.FILL_MASK);
        addModels(NLP.QUESTION_ANSWER);
        addModels(NLP.TEXT_CLASSIFICATION);
        addModels(NLP.TEXT_EMBEDDING);
        addModels(NLP.TOKEN_CLASSIFICATION);
    }

    /** {@inheritDoc} */
    @Override
    public String getGroupId() {
        return GROUP_ID;
    }

    /** {@inheritDoc} */
    @Override
    public Set<String> getSupportedEngines() {
        return Collections.singleton("PyTorch");
    }

    /**
     * Adds all models of the given application to this zoo.
     *
     * <p>Entries whose metadata is missing or whose {@code result} field is {@code "failed"} are
     * skipped.
     *
     * @param app the application whose model index should be loaded
     */
    private void addModels(Application app) {
        for (Map.Entry<String, Map<String, Object>> entry : listModels(app).entrySet()) {
            Map<String, Object> metadata = entry.getValue();
            // A null value (e.g. JSON "name": null) carries no metadata to inspect; skip it
            // rather than fail the whole zoo with a NullPointerException.
            if (metadata == null || "failed".equals(metadata.get("result"))) {
                continue;
            }
            String artifactId = entry.getKey();
            addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1"));
        }
    }

    /**
     * Returns the model index of the given application.
     *
     * @param app the application whose model index should be loaded
     * @return the index keyed by artifact id, or an empty map if no index could be obtained; never
     *     {@code null}
     */
    private Map<String, Map<String, Object>> listModels(Application app) {
        String path = "model/" + app.getPath() + "/ai/djl/huggingface/pytorch/";
        Path dir = Utils.getCacheDir().resolve("cache/repo/" + path);
        try {
            if (Files.notExists(dir)) {
                Files.createDirectories(dir);
            } else if (!Files.isDirectory(dir)) {
                logger.warn("Failed initialize cache directory: {}", dir);
                return Collections.emptyMap();
            }
        } catch (IOException e) {
            logger.warn("Failed load index of models: {}", app, e);
            return Collections.emptyMap();
        }

        Path file = dir.resolve("models.json");
        boolean cacheTried = false;
        if (isCacheUsable(file)) {
            cacheTried = true;
            Map<String, Map<String, Object>> cached = readIndex(file);
            if (cached != null) {
                return cached;
            }
            // An empty or unreadable cache is treated like a missing one: try to download.
        }

        Map<String, Map<String, Object>> downloaded =
                downloadIndex(REPO + path + "models.json.gz", dir, file, app);
        if (downloaded != null) {
            return downloaded;
        }

        // The download failed; an outdated index is still better than no models at all.
        if (!cacheTried && Files.exists(file)) {
            logger.warn("Using stale cached index of models: {}", file);
            Map<String, Map<String, Object>> stale = readIndex(file);
            if (stale != null) {
                return stale;
            }
        }
        return Collections.emptyMap();
    }

    /**
     * Tells whether the cached index may be used without contacting the remote repository.
     *
     * @param file the cached index file
     * @return {@code true} if the file exists and either the {@code offline} system property is
     *     set or the file was modified less than one day ago
     */
    private static boolean isCacheUsable(Path file) {
        if (!Files.exists(file)) {
            return false;
        }
        if (Boolean.getBoolean("offline")) {
            return true;
        }
        try {
            long lastModified = Files.getLastModifiedTime(file).toMillis();
            return System.currentTimeMillis() - lastModified < ONE_DAY;
        } catch (IOException e) {
            logger.warn("Failed to read last modified time: {}", file, e);
            return false;
        }
    }

    /**
     * Reads the cached index file.
     *
     * @param file the cached index file
     * @return the parsed index, or {@code null} if the file cannot be read or holds no JSON value
     */
    private static Map<String, Map<String, Object>> readIndex(Path file) {
        try (Reader reader = Files.newBufferedReader(file)) {
            // Gson returns null when the input is empty.
            return JsonUtils.GSON.fromJson(reader, INDEX_TYPE);
        } catch (IOException e) {
            logger.warn("Failed to read cached index of models: {}", file, e);
            return null;
        }
    }

    /**
     * Downloads the gzip compressed index and stores the decompressed JSON as the new cache file.
     *
     * @param url the location of the compressed index
     * @param dir the cache directory, used for the temporary download file
     * @param file the cache file to replace on success
     * @param app the application being loaded, used for logging only
     * @return the parsed index, or {@code null} if the download failed or the payload was empty
     */
    private static Map<String, Map<String, Object>> downloadIndex(
            String url, Path dir, Path file, Application app) {
        Path tmp = null;
        try {
            tmp = Files.createTempFile(dir, "models", ".tmp");
            // Both streams are declared as resources so the connection is closed even when the
            // GZIPInputStream constructor rejects the payload.
            try (InputStream is = new URL(url).openStream();
                    GZIPInputStream gis = new GZIPInputStream(is)) {
                String json = Utils.toString(gis);
                // Parse before caching so that an unusable payload never replaces the cache.
                Map<String, Map<String, Object>> index =
                        JsonUtils.GSON.fromJson(json, INDEX_TYPE);
                if (index == null) {
                    logger.warn("Empty index of models: {}", app);
                    return null;
                }
                try (Writer writer = Files.newBufferedWriter(tmp)) {
                    writer.write(json);
                }
                Utils.moveQuietly(tmp, file);
                return index;
            }
        } catch (IOException e) {
            logger.warn("Failed load index of models: {}", app, e);
            return null;
        } finally {
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooProvider;

/**
 * A Huggingface model zoo provider that implements the {@link
 * ai.djl.repository.zoo.ZooProvider} interface.
 */
public class HfZooProvider implements ZooProvider {

    /**
     * {@inheritDoc}
     *
     * <p>Every call creates a new {@link HfModelZoo} instance.
     */
    @Override
    public ModelZoo getModelZoo() {
        ModelZoo zoo = new HfModelZoo();
        return zoo;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/** Contains the built-in {@link ai.djl.huggingface.zoo.HfModelZoo}. */
package ai.djl.huggingface.zoo;
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ai.djl.huggingface.zoo.HfZooProvider
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;

/** Tests loading and running a model that is resolved through the Huggingface model zoo. */
public class ModelZooTest {

    /**
     * Loads the {@code deepset/minilm-uncased-squad2} model by its {@code djl://} URL and checks
     * the answer it extracts for a sample question.
     *
     * <p>The test is gated by {@code TestRequirements.nightly()}.
     *
     * @throws ModelException if the model cannot be found or loaded
     * @throws IOException if the model files cannot be read
     * @throws TranslateException if the prediction fails
     */
    @Test
    public void testModelZoo() throws ModelException, IOException, TranslateException {
        TestRequirements.nightly();

        String modelUrl = "djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2";
        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .setTypes(QAInput.class, String.class)
                        .optModelUrls(modelUrl)
                        .build();

        String context =
                "BBC Japan was a general entertainment Channel. "
                        + "Which operated between December 2004 and April 2006. "
                        + "It ceased operations after its Japanese distributor folded.";
        QAInput input = new QAInput("When did BBC Japan start broadcasting?", context);

        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            String answer = predictor.predict(input);
            Assert.assertEquals(answer, "december 2004");
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

0 comments on commit 4faa4a7

Please sign in to comment.