Skip to content

Commit

Permalink
Merge 701fbff into 75160c7
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu authored Aug 31, 2021
2 parents 75160c7 + 701fbff commit 4bb2a10
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def compress(
10 * step,
int(frequency),
]
jdata["training"]["save_ckpt"] = "model-compression/model.ckpt"
jdata = normalize(jdata)

# check the descriptor info of the input file
Expand Down
2 changes: 1 addition & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def parse_args(args: Optional[List[str]] = None):
"-c",
"--checkpoint-folder",
type=str,
default=".",
default="model-compression",
help="path to checkpoint folder",
)
parser_compress.add_argument(
Expand Down
13 changes: 11 additions & 2 deletions source/tests/test_model_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
default_places = 10

def _file_delete(file) :
if os.path.exists(file):
if os.path.isdir(file):
os.rmdir(file)
elif os.path.isfile(file):
os.remove(file)

def _subprocess_run(command):
Expand Down Expand Up @@ -318,10 +320,17 @@ def tearDownClass(self):
_file_delete("out.json")
_file_delete("compress.json")
_file_delete("checkpoint")
_file_delete("lcurve.out")
_file_delete("model.ckpt.meta")
_file_delete("model.ckpt.index")
_file_delete("model.ckpt.data-00000-of-00001")
_file_delete("model.ckpt-100.meta")
_file_delete("model.ckpt-100.index")
_file_delete("model.ckpt-100.data-00000-of-00001")
_file_delete("model-compression/checkpoint")
_file_delete("model-compression/model.ckpt.meta")
_file_delete("model-compression/model.ckpt.index")
_file_delete("model-compression/model.ckpt.data-00000-of-00001")
_file_delete("model-compression")

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
Expand Down

0 comments on commit 4bb2a10

Please sign in to comment.