Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[huggingface] Adds Huggingface ModelZoo #1984

Merged
merged 2 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,24 @@ public static Application of(String path) {
return CV.IMAGE_ENHANCEMENT;
case "nlp":
return NLP.ANY;
case "nlp/question_answer":
return NLP.QUESTION_ANSWER;
case "nlp/text_classification":
return NLP.TEXT_CLASSIFICATION;
case "nlp/sentiment_analysis":
return NLP.SENTIMENT_ANALYSIS;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/fill_mask":
return NLP.FILL_MASK;
case "nlp/machine_translation":
return NLP.MACHINE_TRANSLATION;
case "nlp/multiple_choice":
return NLP.MULTIPLE_CHOICE;
case "nlp/question_answer":
return NLP.QUESTION_ANSWER;
case "nlp/sentiment_analysis":
return NLP.SENTIMENT_ANALYSIS;
case "nlp/text_classification":
return NLP.TEXT_CLASSIFICATION;
case "nlp/text_embedding":
return NLP.TEXT_EMBEDDING;
case "nlp/token_classification":
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -210,6 +214,12 @@ public interface NLP {
/** Any NLP application, including those in {@link NLP}. */
Application ANY = new Application("nlp");

/**
* An application that masks some words in a sentence and predicts which words should
* replace those masks.
*/
Application FILL_MASK = new Application("nlp/fill_mask");

/**
* An application that takes a reference document and a question about the document and
* returns text answering the question.
Expand All @@ -228,11 +238,17 @@ public interface NLP {
Application TEXT_CLASSIFICATION = new Application("nlp/text_classification");

/**
* An application that classifies text into positive or negative, an specific case of {@link
* An application that classifies text into positive or negative, a specific case of {@link
* #TEXT_CLASSIFICATION}.
*/
Application SENTIMENT_ANALYSIS = new Application("nlp/sentiment_analysis");

/**
* A natural language understanding application that assigns a label to some tokens in a
* text.
*/
Application TOKEN_CLASSIFICATION = new Application("nlp/token_classification");

/**
* An application that takes a word and returns a feature vector that represents the word.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.Application;
import ai.djl.Application.NLP;
import ai.djl.engine.Engine;
import ai.djl.repository.Repository;
import ai.djl.repository.Version;
import ai.djl.repository.VersionRange;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;

import com.google.gson.reflect.TypeToken;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.Writer;
import java.lang.reflect.Type;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;

/** HfModelZoo is a repository that contains HuggingFace models. */
public class HfModelZoo extends ModelZoo {

private static final Logger logger = LoggerFactory.getLogger(HfModelZoo.class);

private static final String REPO = "https://mlrepo.djl.ai/";
private static final Repository REPOSITORY = Repository.newInstance("Huggingface", REPO);
private static final String GROUP_ID = "ai.djl.huggingface.pytorch";

private static final long ONE_DAY = Duration.ofDays(1).toMillis();

HfModelZoo() {
Version version = new Version(Engine.class.getPackage().getSpecificationVersion());
addModels(NLP.FILL_MASK, version);
addModels(NLP.QUESTION_ANSWER, version);
addModels(NLP.TEXT_CLASSIFICATION, version);
addModels(NLP.TEXT_EMBEDDING, version);
addModels(NLP.TOKEN_CLASSIFICATION, version);
}

/** {@inheritDoc} */
@Override
public String getGroupId() {
return GROUP_ID;
}

/** {@inheritDoc} */
@Override
public Set<String> getSupportedEngines() {
return Collections.singleton("PyTorch");
}

private void addModels(Application app, Version version) {
Map<String, Map<String, Object>> map = listModels(app);
for (Map.Entry<String, Map<String, Object>> entry : map.entrySet()) {
Map<String, Object> model = entry.getValue();
if ("failed".equals(model.get("result"))) {
continue;
}
String requires = (String) model.get("requires");
if (requires != null) {
// the model requires specific DJL version
VersionRange range = VersionRange.parse(requires);
if (!range.contains(version)) {
continue;
}
}
String artifactId = entry.getKey();
addModel(REPOSITORY.model(app, GROUP_ID, artifactId, "0.0.1"));
}
}

private Map<String, Map<String, Object>> listModels(Application app) {
try {
String path = "model/" + app.getPath() + "/ai/djl/huggingface/pytorch/";
Path dir = Utils.getCacheDir().resolve("cache/repo/" + path);
if (Files.notExists(dir)) {
Files.createDirectories(dir);
} else if (!Files.isDirectory(dir)) {
logger.warn("Failed initialize cache directory: " + dir);
return Collections.emptyMap();
}
Type type = new TypeToken<Map<String, Map<String, Object>>>() {}.getType();

Path file = dir.resolve("models.json");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may want a models.json for every DJL release version. Otherwise, we may be suggested new models that are not supported for older versions of DJL

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently models.json is per application, assume the same application will use the same Translator. So we don't really need each file per version. The risk would be some model may require enhancement in current Translator. What we can do is add a requires="0.20.0+" field, so we can filter out those future models that is not compatible with old Translator

if (Files.exists(file)) {
long lastModified = Files.getLastModifiedTime(file).toMillis();
if (Boolean.getBoolean("offline")
|| System.currentTimeMillis() - lastModified < ONE_DAY) {
try (Reader reader = Files.newBufferedReader(file)) {
return JsonUtils.GSON.fromJson(reader, type);
}
}
}

String url = REPO + path + "models.json.gz";
Path tmp = Files.createTempFile(dir, "models", ".tmp");
try (GZIPInputStream gis = new GZIPInputStream(new URL(url).openStream())) {
String json = Utils.toString(gis);
try (Writer writer = Files.newBufferedWriter(tmp)) {
writer.write(json);
}
Utils.moveQuietly(tmp, file);
return JsonUtils.GSON.fromJson(json, type);
} finally {
Utils.deleteQuietly(tmp);
}
} catch (IOException e) {
logger.warn("Failed load index of models: " + app, e);
}

return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooProvider;

/**
 * A {@link ZooProvider} implementation that supplies the Huggingface model zoo.
 *
 * <p>Every call to {@link #getModelZoo()} creates a new {@link HfModelZoo} instance.
 */
public class HfZooProvider implements ZooProvider {

    /** {@inheritDoc} */
    @Override
    public ModelZoo getModelZoo() {
        ModelZoo zoo = new HfModelZoo();
        return zoo;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/** Contains the built-in {@link ai.djl.huggingface.zoo.HfModelZoo}. */
package ai.djl.huggingface.zoo;
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ai.djl.huggingface.zoo.HfZooProvider
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;

import ai.djl.Application.NLP;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Tests for {@link HfModelZoo}. */
public class ModelZooTest {

    /**
     * Loads a question answering model through a {@code djl://} URL and checks the predicted
     * answer.
     *
     * @throws ModelException if the model cannot be loaded
     * @throws IOException if the model files cannot be read
     * @throws TranslateException if the prediction fails
     */
    @Test
    public void testModelZoo() throws ModelException, IOException, TranslateException {
        // NOTE(review): presumably restricts this test to nightly runs; confirm in TestRequirements
        TestRequirements.nightly();

        String question = "When did BBC Japan start broadcasting?";
        String paragraph =
                "BBC Japan was a general entertainment Channel. "
                        + "Which operated between December 2004 and April 2006. "
                        + "It ceased operations after its Japanese distributor folded.";

        String url = "djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2";
        Criteria<QAInput, String> criteria =
                Criteria.builder().setTypes(QAInput.class, String.class).optModelUrls(url).build();

        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            QAInput input = new QAInput(question, paragraph);
            String res = predictor.predict(input);
            Assert.assertEquals(res, "december 2004");
        }
    }

    /**
     * Writes a hand-made {@code models.json} into the cache directory and checks which of its
     * entries the zoo exposes.
     *
     * <p>The fake index is removed in a {@code finally} block: a leftover file would have a
     * fresh modification time and would be picked up as the model index by later runs.
     *
     * @throws IOException if the fake index cannot be written
     */
    @Test
    public void testFutureVersion() throws IOException {
        Map<String, Map<String, Object>> map = new ConcurrentHashMap<>();

        // model1: conversion is marked as failed
        Map<String, Object> model = new ConcurrentHashMap<>();
        model.put("result", "failed");
        map.put("model1", model);

        // model2: requires a DJL version far in the future
        model = new ConcurrentHashMap<>();
        model.put("requires", "10.100.0+");
        map.put("model2", model);

        // model3: requires 0.19.0 or later
        // NOTE(review): the expectation for model3 depends on the DJL version seen at test time
        model = new ConcurrentHashMap<>();
        model.put("requires", "0.19.0+");
        map.put("model3", model);

        // model4: no properties, hence no restriction
        map.put("model4", new ConcurrentHashMap<>());

        String path = "model/" + NLP.QUESTION_ANSWER.getPath() + "/ai/djl/huggingface/pytorch/";
        Path dir = Utils.getCacheDir().resolve("cache/repo/" + path);
        Files.createDirectories(dir);
        Path file = dir.resolve("models.json");
        try {
            try (Writer writer = Files.newBufferedWriter(file)) {
                writer.write(JsonUtils.GSON_PRETTY.toJson(map));
            }
            HfModelZoo zoo = new HfModelZoo();

            Assert.assertNull(zoo.getModelLoader("model1"));
            Assert.assertNull(zoo.getModelLoader("model2"));
            Assert.assertNull(zoo.getModelLoader("model3"));
            Assert.assertNotNull(zoo.getModelLoader("model4"));
        } finally {
            // always remove the fake index, even when an assertion fails
            Utils.deleteQuietly(file);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.huggingface.zoo;