Skip to content

Commit

Permalink
Download PaConvert to get api_alias_mapping.json when building docs (#6580)
Browse files Browse the repository at this point in the history
  • Loading branch information
RedContritio authored Apr 2, 2024
1 parent 41b0fd5 commit b8ac796
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/guides/model_convert/convert_from_pytorch/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
docs_mappings.json
api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import sys

script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(__file__)
sys.path.append(script_dir)
print(script_dir)

# convert from pytorch basedir
CFP_BASEDIR = os.path.dirname(__file__)
sys.path.append(CFP_BASEDIR)
print(CFP_BASEDIR)

from validate_mapping_in_api_difference import (
download_file_by_git,
get_meta_from_diff_file,
process_mapping_index as reference_mapping_item,
)
Expand Down Expand Up @@ -189,12 +192,14 @@ def reference_mapping_item_processer(line, line_idx, state, output, context):


if __name__ == "__main__":
# convert from pytorch basedir
cfp_basedir = os.path.dirname(__file__)
api_alias_source = "paconvert/api_alias_mapping.json"
# api_alias_mapping.json
api_alias_file = download_file_by_git(api_alias_source)

# pytorch_api_mapping_cn
mapping_index_file = os.path.join(cfp_basedir, "pytorch_api_mapping_cn.md")
mapping_index_file = os.path.join(CFP_BASEDIR, "pytorch_api_mapping_cn.md")

api_difference_basedir = os.path.join(cfp_basedir, "api_difference")
api_difference_basedir = os.path.join(CFP_BASEDIR, "api_difference")

mapping_file_pattern = re.compile(r"^torch\.(?P<api_name>.+)\.md$")
# get all diff files (torch.*.md)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import json
import os
import re
import shutil
import subprocess
import sys
import tempfile
import typing
from enum import IntEnum
from typing import TypedDict

CFP_BASEDIR = os.path.dirname(__file__)

mapping_type_set = {
# type 1
"无参数",
Expand Down Expand Up @@ -503,11 +509,58 @@ def process_mapping_index(index_path, item_processer, context={}):
return 0


PACONVERT_REPO_URL = "https://github.com/PaddlePaddle/PaConvert.git"


def download_file_by_git(
    source, destination=None, repo_url=PACONVERT_REPO_URL, overwrite=False
):
    """Fetch a single file from a git repository into ``CFP_BASEDIR``.

    The repository is shallow-cloned (``--depth=1 --single-branch``) into a
    temporary directory, the requested file is copied out, and the temporary
    clone is removed again.

    Args:
        source: Path of the wanted file, relative to the repository root.
        destination: File name (or path relative to ``CFP_BASEDIR``) to write
            to. Defaults to the base name of ``source``.
        repo_url: URL of the git repository to clone.
        overwrite: When false, an already existing destination file is kept
            and no clone is performed.

    Returns:
        The path of the destination file, whether it was freshly copied or
        was already present.

    Raises:
        SystemExit: With exit code -1 when ``git clone`` exits with a
            non-zero status.
    """
    if destination is None:
        destination = os.path.basename(source)

    dest_path = os.path.join(CFP_BASEDIR, destination)

    if os.path.exists(dest_path) and not overwrite:
        print(f"File {dest_path} already exists, skip fetching.")
        # Return the path here as well, so callers get a usable value no
        # matter whether the file was fetched on this run or an earlier one.
        return dest_path

    # The context manager removes the clone on every exit path (including
    # sys.exit below) and cannot trip over an unassigned directory name if
    # creating the temporary directory itself fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            # clone repo to temp dir
            subprocess.run(
                [
                    "git",
                    "clone",
                    "--depth=1",
                    "--single-branch",
                    repo_url,
                    temp_dir,
                ],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"Failed to fetch file by git: {e}")
            sys.exit(-1)
        print(f"Succeeded to clone repo '{repo_url}'.")

        # copy file from repo to dest
        src_file = os.path.join(temp_dir, source)
        shutil.copy(src_file, dest_path)
        print(f"Succeeded to copy file {source} from repo to {dest_path}.")

    return dest_path


if __name__ == "__main__":
# convert from pytorch basedir
cfp_basedir = os.path.dirname(__file__)
api_alias_source = "paconvert/api_alias_mapping.json"
# api_alias_mapping.json
api_alias_file = download_file_by_git(api_alias_source)

# pytorch_api_mapping_cn
mapping_index_file = os.path.join(cfp_basedir, "pytorch_api_mapping_cn.md")
mapping_index_file = os.path.join(CFP_BASEDIR, "pytorch_api_mapping_cn.md")

if not os.path.exists(mapping_index_file):
raise Exception(f"Cannot find mapping index file: {mapping_index_file}")
Expand All @@ -521,7 +574,7 @@ def process_mapping_index(index_path, item_processer, context={}):
)
# index_data_dict = {i['torch_api'].replace('\_', '_'): i for i in index_data}

api_difference_basedir = os.path.join(cfp_basedir, "api_difference")
api_difference_basedir = os.path.join(CFP_BASEDIR, "api_difference")

mapping_file_pattern = re.compile(r"^torch\.(?P<api_name>.+)\.md$")
# get all diff files (torch.*.md)
Expand Down Expand Up @@ -551,7 +604,7 @@ def process_mapping_index(index_path, item_processer, context={}):
meta_dict = {m["torch_api"].replace(r"\_", "_"): m for m in metas}

# 该文件用于 PaConvert 的文档对齐工作
api_diff_output_path = os.path.join(cfp_basedir, "docs_mappings.json")
api_diff_output_path = os.path.join(CFP_BASEDIR, "docs_mappings.json")

with open(api_diff_output_path, "w", encoding="utf-8") as f:
json.dump(metas, f, ensure_ascii=False, indent=4)

0 comments on commit b8ac796

Please sign in to comment.