Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

option to set energy reference #72

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions abics/applications/latgas_abinitio_interface/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def __init__(self):
self.exe_command = []
self.vac_map = []
self.previous_dir = []
self.energy_ref = 0.0
self.prev_dirs_energy_ref = False

@classmethod
def from_dict(cls, d):
Expand Down Expand Up @@ -226,6 +228,8 @@ def from_dict(cls, d):
params.ignore_species = d.get("ignore_species", None)
params.vac_map = d.get("vac_map", [])
params.previous_dir = d.get("previous_dirs", [])
params.energy_ref = d.get("energy_ref", 0.0)
params.prev_dirs_energy_ref = d.get("prev_dirs_energy_ref", False)
if isinstance(params.previous_dir, str):
params.previous_dir = [params.previous_dir]

Expand Down
8 changes: 6 additions & 2 deletions abics/scripts/activelearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
DFTConfigParams,
)
from abics.applications.latgas_abinitio_interface.base_solver import SolverBase, create_solver
from abics.applications.latgas_abinitio_interface.params import ALParams, DFTParams
from abics.applications.latgas_abinitio_interface.params import ALParams, DFTParams, TrainerParams

from abics.util import exists_on_all_nodes
from pymatgen.core import Structure
Expand All @@ -60,6 +60,10 @@ def main_impl(params_root: MutableMapping):
alparams = ALParams.from_dict(params_root["mlref"]["solver"])
mcparams = DFTParams.from_dict(params_root["sampling"]["solver"])
configparams = DFTConfigParams.from_dict(params_root["config"])
trainerparams = TrainerParams.from_dict(params_root["train"])

energy_ref_empty = trainerparams.energy_ref
mult_now = np.prod(configparams.supercell)

solver: SolverBase = create_solver(alparams.solver, alparams)

Expand Down Expand Up @@ -274,7 +278,7 @@ def main_impl(params_root: MutableMapping):
logger.error("Either train (abics_train) or MC (abics_sampling) first.")
sys.exit(1)
obs = np.load(os.path.join(MCdir, str(myreplica), "obs_save.npy"))
energy_ref = obs[:, 0]
energy_ref = obs[:, 0] + energy_ref_empty * mult_now
ALstep = nextMC_index
ALdir = os.path.join(os.getcwd(), f"AL{ALstep}", str(myreplica))
config = defect_config(configparams)
Expand Down
253 changes: 133 additions & 120 deletions abics/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,48 @@


def main_impl(params_root: MutableMapping):
if not os.path.exists("ALloop.progress"):
logger.error("abics_mlref has not run yet.")
sys.exit(1)
trainerparams = TrainerParams.from_dict(params_root["train"])
previous_dirs = trainerparams.previous_dir
prev_dirs_energy_ref = trainerparams.prev_dirs_energy_ref
ALdirs = []
with open("ALloop.progress", "r") as fi:
lines = fi.readlines()
for line in lines:
line = line.strip()
if line.startswith("AL"):
ALdirs.append(line)
last_li = lines[-1].strip()
if not last_li.startswith("AL"):
logger.error("You shouldn't run train now.")
logger.error("Either abics_sampling or abics_mlref first.")
sys.exit(1)
if not os.path.exists("ALloop.progress"):
if len(previous_dirs) == 0:
logger.error("abics_mlref has not run yet and no previous training data provided.")
sys.exit(1)
else:
logger.info("Training on data provided in previous_dirs")
else:
with open("ALloop.progress", "r") as fi:
lines = fi.readlines()
for line in lines:
line = line.strip()
if line.startswith("AL"):
ALdirs.append(line)
last_li = lines[-1].strip()
if not last_li.startswith("AL"):
logger.error("You shouldn't run train now.")
logger.error("Either abics_sampling or abics_mlref first.")
sys.exit(1)

dftparams = DFTParams.from_dict(params_root["sampling"]["solver"])
ensemble = dftparams.ensemble
base_input_dir = dftparams.base_input_dir

trainerparams = TrainerParams.from_dict(params_root["train"])
ignore_species = trainerparams.ignore_species
trainer_commands = trainerparams.exe_command
trainer_type = trainerparams.solver
trainer_input_dirs = trainerparams.base_input_dir
previous_dirs = trainerparams.previous_dir
energy_ref = trainerparams.energy_ref # energy bias per unit cell

configparams = DFTConfigParams.from_dict(params_root["config"])
config = defect_config(configparams)
mult_now = np.prod(configparams.supercell)
unitvol = configparams.lat.volume # volume per unit cell
species = config.structure.symbol_set
dummy_sts = {sp: config.dummy_structure_sp(sp) for sp in species}



if trainer_type not in ["aenet", "allegro", "nequip", "mlip_3"]:
logger.error("Unknown trainer: ", trainer_type)
sys.exit(1)
Expand All @@ -89,119 +99,122 @@ def main_impl(params_root: MutableMapping):
st_fis = ["structure.{}.xsf".format(i) for i in range(num_st)]
for fi in st_fis:
st_tmp = Structure.from_file(fi)
if prev_dirs_energy_ref:
mult = st_tmp.lattice.volume / unitvol
assert abs(mult - round(mult)) < 0.0001, "all input structures must be supercells of unit cell"
else:
mult = 0
st_tmp.remove_species(ignore_species)
structures.append(st_tmp)
with open(fi) as f:
li = f.readline()
e = float(li.split()[4])
e = float(li.split()[4]) - energy_ref * mult
energies.append(e)
os.chdir(rootdir)

logger.info("--Done")

logger.info("-Mapping relaxed structures in AL* to on-lattice model...")

# val_map is a list of list [[sp0, vac0], [sp1, vac1], ...]
# if vac_map:
# vac_map = {specie: vacancy for specie, vacancy in vac_map}
# else:
# vac_map = {}

# we first group species that share sublattices together
G = nx.Graph()
G.add_nodes_from(species)
for sublattice in config.defect_sublattices:
groups = sublattice.groups
sp_list = []
for group in groups:
sp_list.extend(group.species)
for pair in itertools.combinations(sp_list, 2):
G.add_edge(*pair)
sp_groups = nx.connected_components(G)
dummy_sts_share: list[tuple[Structure, list]] = []
for c in nx.connected_components(G):
# merge dummy structures for species that share sublattices
sps = list(c)

coords = np.concatenate([dummy_sts[sp].frac_coords for sp in sps], axis=0)
st_tmp = Structure(
dummy_sts[sps[0]].lattice,
species=["X"] * coords.shape[0],
coords=coords,
)
st_tmp.merge_sites(mode="delete")
dummy_sts_share.append((st_tmp, sps))
if len(ALdirs) > 1:
logger.info(f"-Reading previously mapped structures up to {ALdirs[-2]}")
for dir in ALdirs[:-1]:
rpl = 0
while os.path.isdir(os.path.join(dir, str(rpl))):
os.chdir(os.path.join(dir, str(rpl)))
energies_ref = []
step_ids = []
with open("energy_corr.dat") as fi:
for line in fi:
words = line.split()
energies_ref.append(float(words[1]))
step_ids.append(int(words[2]))
for step_id, energy in zip(step_ids, energies_ref):
if os.path.exists(f"structure.{step_id}_mapped.vasp"):
structures.append(
Structure.from_file(f"structure.{step_id}_mapped.vasp")
)
energies.append(energy)
rpl += 1
os.chdir(rootdir)

logger.info("--Finished reading previously mapped structures")

logger.info(f"-Mapping structures in {ALdirs[-1]}")
dir = ALdirs[-1]
rpl = 0
while os.path.isdir(os.path.join(dir, str(rpl))):
os.chdir(os.path.join(dir, str(rpl)))
energies_ref = []
step_ids = []
with open("energy_corr.dat") as fi:
for line in fi:
words = line.split()
energies_ref.append(float(words[1]))
step_ids.append(int(words[2]))
for step_id, energy in zip(step_ids, energies_ref):
structure: Structure = Structure.from_file(f"structure.{step_id}.vasp")
mapped_sts = []
mapping_success = True
for dummy_st, specs in dummy_sts_share:
# perform sublattice by sublattice mapping
sp_rm = list(filter(lambda s: s not in specs, species))
st_tmp = structure.copy()
st_tmp.remove_species(sp_rm)
num_sp = len(st_tmp)
# map to perfect lattice for this species
st_tmp = map2perflat(dummy_st, st_tmp)
st_tmp.remove_species(["X"])
mapped_sts.append(st_tmp)
if num_sp != len(st_tmp):
logger.info(
f"--mapping failed for structure {step_id} in replica {rpl}"
)
mapping_success = False

for sts in mapped_sts[1:]:
for i in range(len(sts)):
mapped_sts[0].append(sts[i].species_string, sts[i].frac_coords)
if ignore_species:
mapped_sts[0].remove_species(ignore_species)
if mapping_success:
structures.append(mapped_sts[0])
mapped_sts[0].to(
filename=f"structure.{step_id}_mapped.vasp", fmt="POSCAR"
)
energies.append(energy)
rpl += 1
os.chdir(rootdir)
logger.info("--Finished mapping")

if len(ALdirs) > 0:
logger.info("-Mapping relaxed structures in AL* to on-lattice model...")

# val_map is a list of list [[sp0, vac0], [sp1, vac1], ...]
# if vac_map:
# vac_map = {specie: vacancy for specie, vacancy in vac_map}
# else:
# vac_map = {}

# we first group species that share sublattices together
G = nx.Graph()
G.add_nodes_from(species)
for sublattice in config.defect_sublattices:
groups = sublattice.groups
sp_list = []
for group in groups:
sp_list.extend(group.species)
for pair in itertools.combinations(sp_list, 2):
G.add_edge(*pair)
sp_groups = nx.connected_components(G)
dummy_sts_share : list[tuple[Structure, list]] = []
for c in nx.connected_components(G):
# merge dummy structures for species that share sublattices
sps = list(c)

coords = np.concatenate([dummy_sts[sp].frac_coords for sp in sps], axis=0)
st_tmp = Structure(
dummy_sts[sps[0]].lattice,
species=["X"] * coords.shape[0],
coords=coords,
)
st_tmp.merge_sites(mode="delete")
dummy_sts_share.append((st_tmp, sps))
if len(ALdirs) > 1:
logger.info(f"-Reading previously mapped structures up to {ALdirs[-2]}")
for dir in ALdirs[:-1]:
rpl = 0
while os.path.isdir(os.path.join(dir, str(rpl))):
os.chdir(os.path.join(dir, str(rpl)))
energies_ref = []
step_ids = []
with open("energy_corr.dat") as fi:
for line in fi:
words = line.split()
energies_ref.append(float(words[1]))
step_ids.append(int(words[2]))
for step_id, energy in zip(step_ids, energies_ref):
if os.path.exists(f"structure.{step_id}_mapped.vasp"):
st_tmp = Structure.from_file(f"structure.{step_id}_mapped.vasp")
mult = st_tmp.lattice.volume / unitvol
assert mult % 1 < 0.0001, "all input structures must be supercells of unit cell"
structures.append(st_tmp)
energies.append(energy - energy_ref * mult)
rpl += 1
os.chdir(rootdir)

logger.info("--Finished reading previously mapped structures")

logger.info(f"-Mapping structures in {ALdirs[-1]}")
dir = ALdirs[-1]
rpl = 0
while os.path.isdir(os.path.join(dir, str(rpl))):
os.chdir(os.path.join(dir, str(rpl)))
energies_ref = []
step_ids = []
with open("energy_corr.dat") as fi:
for line in fi:
words = line.split()
energies_ref.append(float(words[1]))
step_ids.append(int(words[2]))
for step_id, energy in zip(step_ids, energies_ref):
structure: Structure = Structure.from_file(f"structure.{step_id}.vasp")
mapped_sts = []
mapping_success = True
for dummy_st, specs in dummy_sts_share:
# perform sublattice by sublattice mapping
sp_rm = list(filter(lambda s: s not in specs, species))
st_tmp = structure.copy()
st_tmp.remove_species(sp_rm)
num_sp = len(st_tmp)
# map to perfect lattice for this species
st_tmp = map2perflat(dummy_st, st_tmp)
st_tmp.remove_species(["X"])
mapped_sts.append(st_tmp)
if num_sp != len(st_tmp):
logger.info(f"--mapping failed for structure {step_id} in replica {rpl}")
mapping_success = False

for sts in mapped_sts[1:]:
for i in range(len(sts)):
mapped_sts[0].append(sts[i].species_string, sts[i].frac_coords)
if ignore_species:
mapped_sts[0].remove_species(ignore_species)
if mapping_success:
structures.append(mapped_sts[0])
mapped_sts[0].to(filename=f"structure.{step_id}_mapped.vasp", fmt="POSCAR")
energies.append(energy - energy_ref * mult_now)
rpl += 1
os.chdir(rootdir)
logger.info("--Finished mapping")

generate_input_dirs = []
train_input_dirs = []
predict_input_dirs = []
Expand Down