Skip to content

Commit

Permalink
update CPP inference test(deeppot_sea)
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Dec 27, 2024
1 parent 2e6cae4 commit a8145fa
Show file tree
Hide file tree
Showing 5 changed files with 641 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace
exclude: "^.+\\.pbtxt$"
- id: end-of-file-fixer
exclude: "^.+\\.pbtxt$"
exclude: "^.+\\.pbtxt$|deeppot_sea\\.json$"
- id: check-yaml
- id: check-json
- id: check-added-large-files
Expand Down Expand Up @@ -63,7 +63,7 @@ repos:
rev: v19.1.4
hooks:
- id: clang-format
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$)
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
# markdown, yaml, CSS, javascript
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
Expand Down
18 changes: 15 additions & 3 deletions deepmd/pd/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import paddle

Expand Down Expand Up @@ -38,7 +37,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
raise ValueError("Paddle backend only supports converting .json file")
model = BaseModel.deserialize(data["model"])
# JIT will happy in this way...
model.model_def_script = json.dumps(data["model_def_script"])
# model.model_def_script = json.dumps(data["model_def_script"])
if "min_nbor_dist" in data.get("@variables", {}):
model.min_nbor_dist = float(data["@variables"]["min_nbor_dist"])
# model = paddle.jit.to_static(model)
Expand All @@ -49,7 +48,20 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
"FLAGS_enable_pir_api": 1,
}
)
from paddle.static import (
InputSpec,
)

jit_model = paddle.jit.to_static(
model.forward_lower,
full_graph=True,
input_spec=[
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
InputSpec([-1, -1], dtype="int32", name="atype"),
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
],
)
paddle.jit.save(
model,
jit_model,
model_file.split(".json")[0],
)
Loading

0 comments on commit a8145fa

Please sign in to comment.