-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Init a sparse model auto tracing workflow. (#394)
* Init a sparse model auto tracing workflow. Signed-off-by: conggguan <[email protected]> * Change the minimum-approvals of sparse model uploader to 2. Add some test case. Remove some redundant lines. Signed-off-by: conggguan <[email protected]> * Fix some test cases. Signed-off-by: conggguan <[email protected]> * Remove the temp test jupyter notebook. Signed-off-by: conggguan <[email protected]> * Change the variable name of inner model, and optimize the license verification. Signed-off-by: conggguan <[email protected]> * Address some comments, and nox format. Signed-off-by: conggguan <[email protected]> * Fix a bug for NeuralSparseModel's init. And remove a redundant save_pretrained. Signed-off-by: conggguan <[email protected]> * [Fix] Deleted some redundant code caused a faiure test case, fixed it. Signed-off-by: conggguan <[email protected]> * [Style]:Run a nox -s format to make format identical. Signed-off-by: conggguan <[email protected]> * [Fix] Simplify the SparseEncodingModel and fix a bug for multiple texts embeddings. Signed-off-by: conggguan <[email protected]> * [Fix] Make register_and_deploy_sparse_encoding_model return proper list but not single map. Signed-off-by: conggguan <[email protected]> * [Fix] Fix a bug for register_and_deploy_sparse_encoding_model, it now generate correct list of embedding of input texts. Signed-off-by: conggguan <[email protected]> * [Fix] Fix sparse encoding mdoel's test_check_required_fields test case. Signed-off-by: conggguan <[email protected]> * [Fix] Renamed a unproper variable name. Signed-off-by: conggguan <[email protected]> * [Refactor] Add some comments and extract some constants to a new file. Signed-off-by: conggguan <[email protected]> * [Refactor] Simplify and reuse some code from model auto tracing. Signed-off-by: conggguan <[email protected]> * [Refactor] Simplify and reuse some code from model auto tracing. Signed-off-by: conggguan <[email protected]> * [Refactor] Add a function comments and merge the sparse model trace workflow and dense. Signed-off-by: conggguan <[email protected]> * [Refactor] Merge the sparse and dense model's ci branch. Signed-off-by: conggguan <[email protected]> * [Refactor] Change for more common API, add a line of comments. Signed-off-by: conggguan <[email protected]> --------- Signed-off-by: conggguan <[email protected]>
- Loading branch information
Showing
16 changed files
with
1,503 additions
and
253 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
* @dhrubo-os @greaa-aws @ylwu-amzn @b4sjoo @jngz-es @rbhavna | ||
* @dhrubo-os @greaa-aws @ylwu-amzn @b4sjoo @jngz-es @rbhavna |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. | ||
import json | ||
import os | ||
from abc import ABC, abstractmethod | ||
from zipfile import ZipFile | ||
|
||
import requests | ||
|
||
from opensearch_py_ml.ml_commons.ml_common_utils import ( | ||
LICENSE_URL, | ||
SPARSE_ENCODING_FUNCTION_NAME, | ||
) | ||
|
||
|
||
class BaseUploadModel(ABC): | ||
""" | ||
A base class for uploading models to OpenSearch pretrained model hub. | ||
""" | ||
|
||
def __init__( | ||
self, model_id: str, folder_path: str = None, overwrite: bool = False | ||
) -> None: | ||
self.model_id = model_id | ||
self.folder_path = folder_path | ||
self.overwrite = overwrite | ||
|
||
@abstractmethod | ||
def save_as_pt(self, *args, **kwargs): | ||
pass | ||
|
||
@abstractmethod | ||
def save_as_onnx(self, *args, **kwargs): | ||
pass | ||
|
||
@abstractmethod | ||
def make_model_config_json( | ||
self, | ||
version_number: str, | ||
model_format: str, | ||
description: str, | ||
) -> str: | ||
pass | ||
|
||
def _fill_null_truncation_field( | ||
self, | ||
save_json_folder_path: str, | ||
max_length: int, | ||
) -> None: | ||
""" | ||
Fill truncation field in tokenizer.json when it is null | ||
:param save_json_folder_path: | ||
path to save model json file, e.g, "home/save_pre_trained_model_json/") | ||
:type save_json_folder_path: string | ||
:param max_length: | ||
maximum sequence length for model | ||
:type max_length: int | ||
:return: no return value expected | ||
:rtype: None | ||
""" | ||
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json") | ||
with open(tokenizer_file_path) as user_file: | ||
parsed_json = json.load(user_file) | ||
if "truncation" not in parsed_json or parsed_json["truncation"] is None: | ||
parsed_json["truncation"] = { | ||
"direction": "Right", | ||
"max_length": max_length, | ||
"strategy": "LongestFirst", | ||
"stride": 0, | ||
} | ||
with open(tokenizer_file_path, "w") as file: | ||
json.dump(parsed_json, file, indent=2) | ||
|
||
def _add_apache_license_to_model_zip_file(self, model_zip_file_path: str): | ||
""" | ||
Add Apache-2.0 license file to the model zip file at model_zip_file_path | ||
:param model_zip_file_path: | ||
Path to the model zip file | ||
:type model_zip_file_path: string | ||
:return: no return value expected | ||
:rtype: None | ||
""" | ||
r = requests.get(LICENSE_URL) | ||
assert r.status_code == 200, "Failed to add license file to the model zip file" | ||
|
||
with ZipFile(str(model_zip_file_path), "a") as zipObj: | ||
zipObj.writestr("LICENSE", r.content) | ||
|
||
|
||
class SparseModel(BaseUploadModel, ABC): | ||
""" | ||
Class for autotracing the Sparse Encoding model. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_id: str, | ||
folder_path: str = "./model_files/", | ||
overwrite: bool = False, | ||
): | ||
super().__init__(model_id, folder_path, overwrite) | ||
self.model_id = model_id | ||
self.folder_path = folder_path | ||
self.overwrite = overwrite | ||
self.function_name = SPARSE_ENCODING_FUNCTION_NAME | ||
|
||
def pre_process(self): | ||
pass | ||
|
||
def post_process(self): | ||
pass |
Oops, something went wrong.