
Commit

remove duplicate tokenizer file, fix import problem and add comment for tokenizer model

Signed-off-by: xinyual <[email protected]>
xinyual committed Sep 27, 2023
1 parent dca433a commit 73b3f02
Showing 4 changed files with 30 additions and 26 deletions.
@@ -5,26 +5,20 @@

package org.opensearch.ml.engine.algorithms.sparse_encoding;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.ServingTranslator;
import ai.djl.translate.TranslatorContext;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;

@@ -16,7 +16,6 @@
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
@@ -36,18 +35,26 @@
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.common.utils.StringUtils.gson;

/**
 * The tokenizer model loads from two files: a tokenizer file and an IDF file.
 * The IDF file contains a predefined weight for each token, calculated like the BM25 IDF for a given dataset.
 * Decoupling IDF from model inference gives a tighter upper bound on the predicted token weight, which helps accelerate the WAND algorithm in Lucene 9.7.
 * IDF introduces global token information and boosts search relevance.
 * Our pretrained tokenizer model provides a general IDF computed from MSMARCO; customers can recalculate the IDF on their own dataset.
 * Without an IDF file, the weight of each token in the result is set to 1.0.
 * Since we regard the tokenizer as a model, customers need to keep the tokenizer and the model consistent themselves.
 */
@Log4j2
@Function(FunctionName.SPARSE_TOKENIZE)
public class SparseTokenizerModel extends DLModel {
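The new Javadoc above describes how the IDF file rescales predicted token weights. Below is a minimal sketch of that idea, not the actual SparseTokenizerModel implementation; the helper names (bm25Idf, rescale) and the Map-based weight representation are illustrative assumptions.

import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only, not the SparseTokenizerModel API.
class IdfRescaleSketch {
    // BM25-style IDF for a token appearing in docFreq out of docCount documents.
    static float bm25Idf(long docCount, long docFreq) {
        return (float) Math.log(1 + (docCount - docFreq + 0.5) / (docFreq + 0.5));
    }

    // Multiply each predicted token weight by its predefined IDF weight.
    // Tokens with no IDF entry keep weight 1.0, matching the default used
    // when no IDF file is provided.
    static Map<String, Float> rescale(Map<String, Float> tokenWeights, Map<String, Float> idf) {
        Map<String, Float> rescaled = new HashMap<>();
        tokenWeights.forEach((token, weight) ->
                rescaled.put(token, weight * idf.getOrDefault(token, 1.0f)));
        return rescaled;
    }
}

Because the IDF table is applied outside inference, the per-token upper bounds stay tight, which is what lets Lucene's WAND scoring skip more documents.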
@@ -1,6 +1,8 @@
package org.opensearch.ml.engine.algorithms.sparse_encoding;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
@@ -27,22 +29,24 @@
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;
import ai.djl.modality.Input;

import java.io.File;
import java.io.IOException;
import ai.djl.modality.Output;

import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.engine.algorithms.DLModel.*;
import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE;

public class TextEmbeddingSparseEncodingModelTest {
@Rule
@@ -98,7 +102,7 @@ public void test_SparseEncoding_Translator_ProcessInput() throws URISyntaxException
TranslatorContext translatorContext = mock(TranslatorContext.class);
Model mlModel = mock(Model.class);
when(translatorContext.getModel()).thenReturn(mlModel);
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("tokenizer.json").toURI()).getParent());
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent());
sparseEncodingTranslator.prepare(translatorContext);

NDManager manager = mock(NDManager.class);
@@ -126,7 +130,7 @@ public void test_SparseEncoding_Translator_ProcessOutput() throws URISyntaxException
TranslatorContext translatorContext = mock(TranslatorContext.class);
Model mlModel = mock(Model.class);
when(translatorContext.getModel()).thenReturn(mlModel);
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("tokenizer.json").toURI()).getParent());
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent());
sparseEncodingTranslator.prepare(translatorContext);

NDArray ndArray = mock(NDArray.class);
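Both test path changes above work because Class.getResource resolves a relative name against the class's own package before delegating to the class loader. A minimal sketch of the resolution, assuming the test class sits in the sparse_encoding package and the surviving tokenizer.json lives in a sibling tokenize package (the package layout is inferred from the diff, not verified):

import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;

// Illustrative sketch; package layout assumed from the diff above.
class TokenizerPathSketch {
    Path tokenizerDir() throws URISyntaxException {
        // "../tokenize/tokenizer.json" is resolved relative to this class's
        // package (e.g. .../algorithms/sparse_encoding), so it selects the
        // copy in the sibling tokenize package now that the duplicate in
        // sparse_encoding has been deleted.
        return Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent();
    }
}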

This file was deleted.
