Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up NewtonNet recipes some more #926

Merged
merged 7 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/quacc/recipes/lj/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ def freq_job(

igt = ideal_gas(atoms, vibrations.get_frequencies(), energy=energy)
vib_summary["thermo"] = summarize_thermo(
igt, temperature=temperature, pressure=pressure
igt,
temperature=temperature,
pressure=pressure,
additional_fields={"name": "ASE Thermo Analysis"},
)

return vib_summary
32 changes: 22 additions & 10 deletions src/quacc/recipes/newtonnet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
if TYPE_CHECKING:
from ase import Atoms

from quacc.schemas.ase import FreqSchema, OptSchema, RunSchema
from quacc.schemas.ase import OptSchema, RunSchema, ThermoSchema, VibSchema

class FreqSchema(RunSchema):
vib: VibSchema
thermo: ThermoSchema


@job
Expand Down Expand Up @@ -199,8 +203,7 @@ def freq_job(
Returns
-------
FreqSchema
Dictionary of results specified in [quacc.schemas.ase.summarize_vib_run][]
and [quacc.schemas.ase.summarize_thermo][]
Dictionary of results
"""
atoms = fetch_atoms(atoms)
calc_swaps = calc_swaps or {}
Expand All @@ -214,20 +217,29 @@ def freq_job(
ml_calculator = NewtonNet(**flags)
atoms.calc = ml_calculator
final_atoms = run_calc(atoms)
energy = final_atoms.get_potential_energy()
hessian = final_atoms.calc.results["hessian"]

summary = summarize_run(
final_atoms,
input_atoms=atoms,
additional_fields={"name": "NewtonNet Hessian"},
)
energy = summary["results"]["energy"]
hessian = summary["results"]["hessian"]

vib = VibrationsData(final_atoms, hessian)
vib_summary = summarize_vib_run(
vib, additional_fields={"name": "NewtonNet Frequency"}
summary["vib"] = summarize_vib_run(
vib, additional_fields={"name": "ASE Vibrations Analysis"}
)

igt = ideal_gas(final_atoms, vib.get_frequencies(), energy=energy)
vib_summary["thermo"] = summarize_thermo(
igt, temperature=temperature, pressure=pressure
summary["thermo"] = summarize_thermo(
igt,
temperature=temperature,
pressure=pressure,
additional_fields={"name": "ASE Thermo Analysis"},
)

return vib_summary
return summary


def _add_stdev_and_hess(summary: dict) -> dict:
Expand Down
16 changes: 8 additions & 8 deletions src/quacc/recipes/newtonnet/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
from quacc.schemas.ase import FreqSchema, OptSchema

class TSSchema(OptSchema):
freq: FreqSchema | None
freq_job: FreqSchema | None

class IRCSchema(OptSchema):
freq: FreqSchema | None
freq_job: FreqSchema | None

class QuasiIRCSchema(OptSchema):
irc: IRCSchema
freq: FreqSchema | None
irc_job: IRCSchema
freq_job: FreqSchema | None


@job
Expand Down Expand Up @@ -142,7 +142,7 @@ def ts_job(
freq_summary = (
freq_job.__wrapped__(opt_ts_summary, **freq_job_kwargs) if run_freq else None
)
opt_ts_summary["freq"] = freq_summary
opt_ts_summary["freq_job"] = freq_summary

return opt_ts_summary

Expand Down Expand Up @@ -259,7 +259,7 @@ def irc_job(
freq_summary = (
freq_job.__wrapped__(opt_irc_summary, **freq_job_kwargs) if run_freq else None
)
opt_irc_summary["freq"] = freq_summary
opt_irc_summary["freq_job"] = freq_summary

return opt_irc_summary

Expand Down Expand Up @@ -328,8 +328,8 @@ def quasi_irc_job(
freq_summary = (
freq_job.__wrapped__(relax_summary, **freq_job_kwargs) if run_freq else None
)
relax_summary["freq"] = freq_summary
relax_summary["irc"] = irc_summary
relax_summary["freq_job"] = freq_summary
relax_summary["irc_job"] = irc_summary

return relax_summary

Expand Down
5 changes: 4 additions & 1 deletion src/quacc/recipes/tblite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ def freq_job(

igt = ideal_gas(atoms, vibrations.get_frequencies(), energy=energy)
vib_summary["thermo"] = summarize_thermo(
igt, temperature=temperature, pressure=pressure
igt,
temperature=temperature,
pressure=pressure,
additional_fields={"name": "ASE Thermo Analysis"},
)

return vib_summary
89 changes: 50 additions & 39 deletions tests/recipes/newtonnet/test_newtonnet_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def test_freq_job(tmpdir):
atoms = molecule("H2O")
output = freq_job(atoms)
assert output["atoms"] == molecule("H2O")
assert len(output["results"]["vib_freqs_raw"]) == 9
assert len(output["results"]["vib_freqs"]) == 3
assert output["results"]["vib_freqs_raw"][-1] == pytest.approx(4090.37777396351)
assert output["results"]["vib_freqs"][0] == pytest.approx(1814.0941260498644)
assert output["results"]["vib_freqs"][-1] == pytest.approx(4090.37777396351)
assert output["results"]["n_imag"] == 0
assert output["results"]["imag_vib_freqs"] == []
assert len(output["vib"]["results"]["vib_freqs_raw"]) == 9
assert len(output["vib"]["results"]["vib_freqs"]) == 3
assert output["vib"]["results"]["vib_freqs_raw"][-1] == pytest.approx(
4090.37777396351
)
assert output["vib"]["results"]["vib_freqs"][0] == pytest.approx(1814.0941260498644)
assert output["vib"]["results"]["vib_freqs"][-1] == pytest.approx(4090.37777396351)
assert output["vib"]["results"]["n_imag"] == 0
assert output["vib"]["results"]["imag_vib_freqs"] == []

assert output["thermo"]["atoms"] == atoms
assert output["thermo"]["symmetry"]["point_group"] == "C2v"
Expand All @@ -103,18 +105,24 @@ def test_freq_job(tmpdir):
atoms = molecule("CH3")
output = freq_job(atoms, temperature=1000, pressure=20)
assert output["atoms"] == molecule("CH3")
assert len(output["results"]["vib_freqs_raw"]) == 12
assert len(output["results"]["vib_freqs"]) == 6
assert output["results"]["vib_energies_raw"][0] == pytest.approx(
assert len(output["vib"]["results"]["vib_freqs_raw"]) == 12
assert len(output["vib"]["results"]["vib_freqs"]) == 6
assert output["vib"]["results"]["vib_energies_raw"][0] == pytest.approx(
-0.09441402482739979
)
assert output["results"]["vib_energies_raw"][-1] == pytest.approx(
assert output["vib"]["results"]["vib_energies_raw"][-1] == pytest.approx(
0.3925829460532815
)
assert output["results"]["vib_energies"][0] == pytest.approx(-0.09441402482739979)
assert output["results"]["vib_energies"][-1] == pytest.approx(0.3925829460532815)
assert output["results"]["n_imag"] == 1
assert output["results"]["imag_vib_freqs"] == pytest.approx([-761.5004719152678])
assert output["vib"]["results"]["vib_energies"][0] == pytest.approx(
-0.09441402482739979
)
assert output["vib"]["results"]["vib_energies"][-1] == pytest.approx(
0.3925829460532815
)
assert output["vib"]["results"]["n_imag"] == 1
assert output["vib"]["results"]["imag_vib_freqs"] == pytest.approx(
[-761.5004719152678]
)
assert output["thermo"]["atoms"] == molecule("CH3")


Expand All @@ -134,10 +142,10 @@ def test_ts_job_with_default_args(tmpdir):
# Perform assertions on the result
assert isinstance(output, dict)

assert "freq" in output
assert "thermo" in output["freq"]
assert "freq_job" in output
assert "thermo" in output["freq_job"]
assert output["results"]["energy"] == pytest.approx(-6.796914263061945)
assert output["freq"]["results"]["imag_vib_freqs"][0] == pytest.approx(
assert output["freq_job"]["vib"]["results"]["imag_vib_freqs"][0] == pytest.approx(
-2426.7398321816004
)

Expand All @@ -163,10 +171,10 @@ def test_ts_job_with_custom_hessian(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-8.855604432470276)
assert output["freq"]["results"]["vib_energies"][0] == pytest.approx(
assert output["freq_job"]["vib"]["results"]["vib_energies"][0] == pytest.approx(
0.2256022513686731
)
assert "thermo" in output["freq"]
assert "thermo" in output["freq_job"]


@pytest.mark.skipif(
Expand All @@ -185,11 +193,14 @@ def test_ts_job_with_custom_optimizer(tmpdir):
# Perform assertions on the result
assert isinstance(output, dict)

assert "thermo" in output["freq"]
assert "thermo" in output["freq_job"]
assert output["results"]["energy"] == pytest.approx(-9.51735515322368)
assert output["freq"]["results"]["vib_energies"][0] == pytest.approx(
assert output["freq_job"]["vib"]["results"]["vib_energies"][0] == pytest.approx(
0.22679888726664774
)
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.51735515322368
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -223,7 +234,7 @@ def test_irc_job_with_default_args(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -245,7 +256,7 @@ def test_irc_job_with_custom_fmax(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -267,7 +278,7 @@ def test_irc_job_with_custom_max_steps(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -292,7 +303,7 @@ def test_irc_job_with_custom_temperature_and_pressure(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -314,7 +325,7 @@ def test_irc_job_with_custom_opt_swaps(tmpdir):
assert isinstance(output, dict)

assert output["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354965639784
)

Expand All @@ -334,9 +345,9 @@ def test_quasi_irc_job_with_default_args(tmpdir):
# Perform assertions on the result
assert isinstance(output, dict)

assert output["irc"]["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["irc_job"]["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -356,12 +367,12 @@ def test_quasi_irc_job_with_custom_direction(tmpdir):

# Perform assertions on the result
assert isinstance(output, dict)
assert "irc" in output
assert "irc_job" in output

assert output["irc"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["irc"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["irc_job"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["irc_job"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354965639784
)

Expand All @@ -384,11 +395,11 @@ def test_quasi_irc_job_with_custom_temperature_and_pressure(tmpdir):

# Perform assertions on the result
assert isinstance(output, dict)
assert "irc" in output
assert "irc_job" in output

assert output["irc"]["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["irc_job"]["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["results"]["energy"] == pytest.approx(-9.517354091813969)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354091813969
)

Expand All @@ -410,10 +421,10 @@ def test_quasi_irc_job_with_custom_irc_swaps(tmpdir):

# Perform assertions on the result
assert isinstance(output, dict)
assert "irc" in output
assert "irc_job" in output

assert output["irc"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["irc_job"]["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["results"]["energy"] == pytest.approx(-9.517354965639784)
assert output["freq"]["thermo"]["results"]["energy"] == pytest.approx(
assert output["freq_job"]["thermo"]["results"]["energy"] == pytest.approx(
-9.517354965639784
)