-
Notifications
You must be signed in to change notification settings - Fork 205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
共有ライブラリのロードでアーキテクチャを考慮する機能を追加 #327
共有ライブラリのロードでアーキテクチャを考慮する機能を追加 #327
Conversation
Pull Request Test Coverage Report for Build 1852237076
💛 - Coveralls |
@takana-v さん、 @PickledChair さん、もしよければレビューよろしくお願いします・・・! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少し気になったところがあったのでコメントしました。ご確認いただければ幸いです。
また、動作自体は問題なさそうですが、if 分岐が多くなりメンテナンスが大変そうなコードになってきている気がしています(特に load_core
関数が肥大化してきています)。共通の処理を繰り返し書いているように見える箇所(x86_64
や aarch64
などのアーキテクチャの同じような判定が Windows と Linux の双方にある)もあるので、まとめられるところはまとめてしまった方が良い気もしました。
ただし、そのためにはコアの名前・OS・アーキテクチャ・CPU/GPUといった情報を一つのデータ構造にまとめる必要がありそうです。試しにコードに起こしてみたものを以下に示します(うまく動作するかどうかはテストしきれていません……)。この PR 内でこのような形にしてしまうかどうかについては、 @.HyodaKazuaki さんや他のレビュワーの方々のご意見をお聞きしたいと思います(この PR では既存のコードの拡張のみにすることになった場合、多分、後で私が改めて PR します)。
コアの情報をデータ構造にまとめたコードの例
import platform
from ctypes import CDLL
from pathlib import Path
from typing import Optional
CORE_INFOS = [
# Windows
{ "name": "core.dll", "platform": "Windows", "is_gpu_core": True, "arch": "x86_64", "core_type": "libtorch" },
{ "name": "core_cpu.dll", "platform": "Windows", "is_gpu_core": False, "arch": "x86_64", "core_type": "libtorch" },
{ "name": "core_gpu_x64_nvidia.dll", "platform": "Windows", "is_gpu_core": True, "arch": "x86_64", "core_type": "onnxruntime" },
{ "name": "core_cpu_x64.dll", "platform": "Windows", "is_gpu_core": False, "arch": "x86_64", "core_type": "onnxruntime" },
{ "name": "core_cpu_x86.dll", "platform": "Windows", "is_gpu_core": False, "arch": "x86", "core_type": "onnxruntime" },
{ "name": "core_cpu_arm.dll", "platform": "Windows", "is_gpu_core": False, "arch": "armv7l", "core_type": "onnxruntime" },
{ "name": "core_cpu_arm64.dll", "platform": "Windows", "is_gpu_core": False, "arch": "aarch64", "core_type": "onnxruntime" },
# Linux
{ "name": "libcore.so", "platform": "Linux", "is_gpu_core": True, "arch": "x86_64", "core_type": "libtorch" },
{ "name": "libcore_cpu.so", "platform": "Linux", "is_gpu_core": False, "arch": "x86_64", "core_type": "libtorch" },
{ "name": "libcore_gpu_x64_nvidia.so", "platform": "Linux", "is_gpu_core": True, "arch": "x86_64", "core_type": "onnxruntime" },
{ "name": "libcore_cpu_x64.so", "platform": "Linux", "is_gpu_core": False, "arch": "x86_64", "core_type": "onnxruntime" },
{ "name": "libcore_cpu_armhf.so", "platform": "Linux", "is_gpu_core": False, "arch": "armv7l", "core_type": "onnxruntime" },
{ "name": "libcore_cpu_arm64.so", "platform": "Linux", "is_gpu_core": False, "arch": "aarch64", "core_type": "onnxruntime" },
# macOS
{ "name": "libcore_cpu_universal2.dylib", "platform": "Darwin", "is_gpu_core": False, "arch": "universal", "core_type": "onnxruntime" },
]
def get_arch_name() -> Optional[str]:
"""
platform.machine() が特定のアーキテクチャ上で複数パターンの文字列を返し得るので、
一意な文字列に変換する
サポート外のアーキテクチャである場合、None を返す
"""
machine = platform.machine()
if machine == "x86_64" or machine == "x64" or machine == "AMD64":
return "x86_64"
elif machine == "i386" or machine == "x86":
return "x86"
elif machine in ["armv7l", "aarch64"]:
return machine
else:
return None
def get_core_name(arch_name: str, platform_name: str, model_type: str, is_gpu_core: bool) -> Optional[str]:
if platform_name == "Darwin":
if (not is_gpu_core) and (arch_name == "x86_64" or arch_name == "aarch64"):
arch_name = "universal"
else:
return None
for core_info in CORE_INFOS:
if (
core_info["platform"] == platform_name
and core_info["arch"] == arch_name
and core_info["core_type"] == model_type
and core_info["is_gpu_core"] == is_gpu_core
):
return core_info["name"]
return None
def get_suitable_core_name(model_type: str, is_gpu_core: bool) -> Optional[str]:
arch_name = get_arch_name()
if arch_name is None:
return None
platform_name = platform.system()
return get_core_name(arch_name, platform_name, model_type, is_gpu_core)
def check_core_type(core_dir: Path) -> Optional[str]:
libtorch_core_names = [
get_suitable_core_name("libtorch", is_gpu_core=True),
get_suitable_core_name("libtorch", is_gpu_core=False),
]
onnxruntime_core_names = [
get_suitable_core_name("onnxruntime", is_gpu_core=True),
get_suitable_core_name("onnxruntime", is_gpu_core=False),
]
if any([(core_dir / name).is_file() for name in libtorch_core_names if name]):
return "libtorch"
elif any([(core_dir / name).is_file() for name in onnxruntime_core_names if name]):
return "onnxruntime"
else:
return None
def load_core(core_dir: Path, use_gpu: bool) -> CDLL:
model_type = check_core_type(core_dir)
if model_type is None:
raise RuntimeError("コアが見つかりません")
if use_gpu or model_type == "onnxruntime":
core_name = get_suitable_core_name(model_type, is_gpu_core=True)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
except OSError:
pass
core_name = get_suitable_core_name(model_type, is_gpu_core=False)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
except OSError:
if model_type == "libtorch":
core_name = get_suitable_core_name(model_type, is_gpu_core=True)
if core_name:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
else:
raise RuntimeError(f"このコンピュータのアーキテクチャ {platform.machine()} で利用可能なコアがありません")
raise RuntimeError("コアの読み込みに失敗しました")
if __name__ == "__main__":
load_core(Path.cwd(), use_gpu=False)
print("use cpu")
load_core(Path.cwd(), use_gpu=True)
print("use gpu")
コード拝見しました。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HyodaKazuaki ご意見(とご配慮)ありがとうございます。今回の PR のついでにリファクタリングが行えるかもしれないと思い提案しましたが、自分の案だと変更量が多いため、別 PR とした方が良さそうだと思えてきました。他にご意見がなければこの PR はこのままマージしようと思います!
マージします。PR ありがとうございました! |
内容
共有ライブラリのロードの際に、アーキテクチャを考慮することができるように変更しました。
関連 Issue
close #324
その他
Pythonの
platform
ライブラリを利用しています。WindowsとLinuxで判定を行います。
Windows arm向けのコアファイル
core_cpu_arm.dll
とLinux arm向けのコアファイルlibcore_cpu_armhf.so
で命名ルールが異なっているようです。