Skip to content

Commit

Permalink
Merge pull request #282 from jeongseok-meta/fix_pytorch_cpu_strings
Browse files Browse the repository at this point in the history
Fix pytorch-cpu/gpu build only for single python version
  • Loading branch information
hmaarrfk authored Nov 5, 2024
2 parents 611bc61 + 08042f6 commit b1c36f0
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
14 changes: 10 additions & 4 deletions recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% set version = "2.5.1" %}
{% set build = 1 %}
{% set build = 2 %}

{% if cuda_compiler_version != "None" %}
{% set build = build + 200 %}
Expand Down Expand Up @@ -32,6 +32,8 @@ source:
# https://github.com/pytorch/pytorch/pull/133235
- patches/0006-Update-sympy-version.patch
- patches/0007-Fix-duplicate-linker-script.patch # [cuda_compiler_version != "None" and aarch64]
# https://github.com/pytorch/pytorch/pull/136034
- patches/0008-Fix-pickler-error.patch
# https://github.com/pytorch/pytorch/pull/137331
- patches/137331.patch

Expand Down Expand Up @@ -338,16 +340,20 @@ outputs:
{% set pytorch_cpu_gpu = "pytorch-gpu" %} # [cuda_compiler_version != "None"]
- name: {{ pytorch_cpu_gpu }}
build:
string: cuda{{ cuda_compiler_version | replace('.', '') }}py{{ CONDA_PY }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [cuda_compiler_version != "None"]
string: cpu_{{ blas_impl }}_py{{ CONDA_PY }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [cuda_compiler_version == "None"]
string: cuda{{ cuda_compiler_version | replace('.', '') }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [megabuild and cuda_compiler_version != "None"]
string: cpu_{{ blas_impl }}_h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [megabuild and cuda_compiler_version == "None"]
string: cuda{{ cuda_compiler_version | replace('.', '') }}py{{ CONDA_PY }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [not megabuild and cuda_compiler_version != "None"]
string: cpu_{{ blas_impl }}_py{{ CONDA_PY }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [not megabuild and cuda_compiler_version == "None"]
detect_binary_files_with_prefix: false
skip: true # [cuda_compiler_version != "None" and linux64 and blas_impl != "mkl"]
# weigh down cpu implementation and give cuda preference
track_features:
- pytorch-cpu # [cuda_compiler_version == "None"]
requirements:
run:
- {{ pin_subpackage("pytorch", exact=True) }}
- pytorch {{ version }}=cuda*{{ PKG_BUILDNUM }} # [megabuild and cuda_compiler_version != "None"]
- pytorch {{ version }}=cpu_{{ blas_impl }}*{{ PKG_BUILDNUM }} # [megabuild and cuda_compiler_version == "None"]
- {{ pin_subpackage("pytorch", exact=True) }} # [not megabuild]
test:
imports:
- torch
Expand Down
34 changes: 34 additions & 0 deletions recipe/patches/0008-Fix-pickler-error.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
diff --git a/torch/serialization.py b/torch/serialization.py
index d936d31d6f5..d937680c031 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1005,8 +1005,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
pickle_module.dump(sys_info, f, protocol=pickle_protocol)
- pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
- pickler.persistent_id = persistent_id
+
+ class PyTorchLegacyPickler(pickle_module.Pickler):
+ def persistent_id(self, obj):
+ return persistent_id(obj)
+
+ pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol)
pickler.dump(obj)

serialized_storage_keys = sorted(serialized_storages.keys())
@@ -1083,8 +1087,12 @@ def _save(

# Write the pickle data for `obj`
data_buf = io.BytesIO()
- pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
- pickler.persistent_id = persistent_id
+
+ class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined]
+ def persistent_id(self, obj):
+ return persistent_id(obj)
+
+ pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
pickler.dump(obj)
data_value = data_buf.getvalue()
zip_file.write_record("data.pkl", data_value, len(data_value))

0 comments on commit b1c36f0

Please sign in to comment.