Skip to content

Commit

Permalink
cherry pick deepmodeling#338: fix an error in stress by ase interface (
Browse files Browse the repository at this point in the history
…deepmodeling#964)

* rename for cherry-pick

* fix an error in stress by ase interface

(cherry picked from commit a24971f)

* move back

* fix lint error

* fix lint warnings

Co-authored-by: hsulab <[email protected]>
  • Loading branch information
njzjz and hsulab authored Aug 12, 2021
1 parent 80c8260 commit 5b1098d
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions deepmd/calculator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""ASE calculator interface module."""

from typing import TYPE_CHECKING, Dict, List, Optional, Union
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from ase.calculators.calculator import (
Calculator, all_changes, PropertyNotImplementedError
)

from deepmd import DeepPotential
from ase.calculators.calculator import Calculator, all_changes

if TYPE_CHECKING:
from ase import Atoms
Expand Down Expand Up @@ -51,7 +54,7 @@ class DP(Calculator):
"""

name = "DP"
implemented_properties = ["energy", "forces", "stress"]
implemented_properties = ["energy", "forces", "virial", "stress"]

def __init__(
self,
Expand All @@ -72,7 +75,7 @@ def __init__(
def calculate(
self,
atoms: Optional["Atoms"] = None,
properties: List[str] = ["energy", "forces", "stress"],
properties: List[str] = ["energy", "forces", "virial"],
system_changes: List[str] = all_changes,
):
"""Run calculation with deepmd model.
Expand All @@ -98,6 +101,17 @@ def calculate(
symbols = self.atoms.get_chemical_symbols()
atype = [self.type_dict[k] for k in symbols]
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
self.results["energy"] = e[0]
self.results["forces"] = f[0]
self.results["stress"] = v[0]
self.results['energy'] = e[0][0]
self.results['forces'] = f[0]
self.results['virial'] = v[0].reshape(3, 3)

# convert virial into stress for lattice relaxation
if "stress" in properties:
if sum(atoms.get_pbc()) > 0:
# the usual convention (tensile stress is positive)
# stress = -virial / volume
stress = -0.5 * (v[0].copy() + v[0].copy().T) / atoms.get_volume()
# Voigt notation
self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]]
else:
raise PropertyNotImplementedError

0 comments on commit 5b1098d

Please sign in to comment.