README.md (6 changes: 5 additions & 1 deletion)
@@ -110,7 +110,11 @@ To give a specific validation set, use the argument `--valid_file`. To set a lar

To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher-order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model size using the `--num_channels=128` and `--max_L=1` keys.

It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications.
It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields.

When training a model from scratch, if you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"`, which estimates the atomic energies using least-squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications.

When finetuning foundation models, you can use `--E0s="estimated"`, which estimates the atomic reference energies by solving a linear system that optimally corrects the foundation model's predictions on the training data. The E0 corrections are computed by first running the foundation model on all training configurations, computing the prediction errors (reference energies minus predicted energies), and then solving a least-squares system for an optimal E0 correction per element. In general, this is preferable to the `average` option when finetuning.
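
Both options boil down to the same linear algebra: build a composition matrix with one row per training configuration and one column per element, then solve a least-squares problem. With `average` the targets are the total reference energies; with `estimated` the targets are the foundation model's energy errors, and the fitted corrections are added to the foundation E0s. A simplified, self-contained numpy sketch with made-up numbers:

```python
import numpy as np

# Composition matrix: A[i, j] = number of atoms of element j (columns: H, C, O)
# in training configuration i. All numbers below are illustrative only.
A = np.array([
    [4.0, 1.0, 0.0],  # CH4
    [2.0, 0.0, 1.0],  # H2O
    [6.0, 2.0, 1.0],  # C2H6O
])

# --E0s="average": regress the total reference energies directly.
E_ref = np.array([-24.1, -14.2, -47.9])  # eV
avg_e0s, *_ = np.linalg.lstsq(A, E_ref, rcond=None)

# --E0s="estimated": regress the foundation model's errors, then shift its E0s.
E_pred = np.array([-24.0, -14.3, -47.7])  # eV, foundation model predictions
corrections, *_ = np.linalg.lstsq(A, E_ref - E_pred, rcond=None)
foundation_e0s = np.array([-13.6, -1029.8, -2042.4])  # eV, per element (H, C, O)
estimated_e0s = foundation_e0s + corrections

print("average E0s:  ", dict(zip("HCO", avg_e0s.round(3))))
print("estimated E0s:", dict(zip("HCO", estimated_e0s.round(3))))
```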

If the keyword `--stage_two` (previously called `swa`) is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (starting at epoch `--start_stage_two`). This setting usually helps lower the energy errors.
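
As an illustration of the schedule, a minimal sketch (the weight values below are placeholders, not the MACE defaults):

```python
def energy_loss_weight(epoch: int, max_num_epochs: int, start_stage_two: int = -1) -> float:
    # Stage two covers roughly the last ~20% of training unless a start epoch is given.
    start = start_stage_two if start_stage_two >= 0 else int(0.8 * max_num_epochs)
    return 1000.0 if epoch >= start else 1.0  # placeholder weights
```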

mace/cli/run_train.py (33 changes: 31 additions & 2 deletions)
@@ -36,7 +36,7 @@
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric
from mace.tools.distributed_tools import init_distributed
from mace.tools.lora_tools import inject_LoRAs
from mace.tools.lora_tools import inject_LoRAs, merge_lora_weights
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
HeadConfig,
@@ -434,7 +434,7 @@ def run(args) -> None:
for head_config in head_configs:
if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
assert head_config.E0s is not None, "Atomic energies must be provided"
if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation":
if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() not in ["foundation", "estimated"]:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(
head_config.E0s, head_config.collections.train, head_config.z_table
)
@@ -455,6 +455,32 @@
].item()
for z in z_table.zs
}
elif head_config.E0s.lower() == "estimated":
assert args.foundation_model is not None, "Foundation model must be provided for E0s estimation"
assert all(check_path_ase_read(f) for f in head_config.train_file), "E0s estimation requires training data in .xyz format"
logging.info("Estimating E0s from foundation model predictions on training data")
z_table_foundation = AtomicNumberTable(
[int(z) for z in model_foundation.atomic_numbers]
)
foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies
if foundation_atomic_energies.ndim > 1:
foundation_atomic_energies = foundation_atomic_energies.squeeze()
if foundation_atomic_energies.ndim == 2:
foundation_atomic_energies = foundation_atomic_energies[0]
logging.info("Foundation model has multiple heads, using the first head for E0 estimation.")
foundation_e0s = {
z: foundation_atomic_energies[
z_table_foundation.z_to_index(z)
].item()
for z in z_table_foundation.zs
}
atomic_energies_dict[head_config.head_name] = data.estimate_e0s_from_foundation(
foundation_model=model_foundation,
foundation_e0s=foundation_e0s,
collections_train=head_config.collections.train,
z_table=head_config.z_table,
device=device,
)
else:
atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table)
else:
@@ -1009,6 +1035,9 @@ def run(args) -> None:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
model_to_save = deepcopy(model)
if args.lora:
logging.info("Merging LoRA weights into base model")
merge_lora_weights(model_to_save)
if args.enable_cueq and not args.only_cueq:
logging.info("RUNING CUEQ TO E3NN")
model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device)
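
For context on the `merge_lora_weights` call added above: a standard LoRA merge folds the low-rank update back into the base weight so the saved model no longer needs the adapter at inference time. The snippet below is a generic conceptual sketch, not the actual `mace.tools.lora_tools` implementation:

```python
import torch

def merge_lora_linear(base_weight: torch.Tensor,
                      lora_A: torch.Tensor,
                      lora_B: torch.Tensor,
                      scaling: float) -> torch.Tensor:
    # Generic LoRA merge: W_merged = W + scaling * (B @ A).
    # Shapes: W is (out, in), A is (rank, in), B is (out, rank).
    return base_weight + scaling * (lora_B @ lora_A)

# Illustrative shapes only.
W = torch.randn(64, 32)
A = torch.randn(4, 32)   # rank 4
B = torch.zeros(64, 4)   # B is typically zero-initialised, so an untrained merge is a no-op
W_merged = merge_lora_linear(W, A, B, scaling=1.0)
```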
mace/data/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -9,6 +9,7 @@
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
estimate_e0s_from_foundation,
load_from_xyz,
random_train_valid_split,
save_AtomicData_to_HDF5,
@@ -29,6 +30,7 @@
"config_from_atoms_list",
"AtomicData",
"compute_average_E0s",
"estimate_e0s_from_foundation",
"save_dataset_as_HDF5",
"HDF5Dataset",
"dataset_from_sharded_hdf5",
mace/data/utils.py (163 changes: 163 additions & 0 deletions)
@@ -388,6 +388,169 @@ def compute_average_E0s(
return atomic_energies_dict


def estimate_e0s_from_foundation(
foundation_model,
foundation_e0s: Dict[int, float],
collections_train: Configurations,
z_table: AtomicNumberTable,
device: str = "cpu",
) -> Dict[int, float]:
"""
Estimate atomic reference energies (E0s) by solving a linear system
that optimally corrects foundation model predictions on training data.

This function computes E0 corrections by:
1. Running the foundation model on all training configurations
2. Computing prediction errors (reference - predicted)
3. Solving a least-squares system to find optimal E0 corrections

Args:
foundation_model: The foundation MACE model
foundation_e0s: Dictionary mapping element atomic numbers to original E0 values
collections_train: List of training configurations
z_table: Atomic number table for the training dataset
device: Device to run predictions on (default: "cpu")

Returns:
Dictionary with estimated E0 values for each element
"""
import torch

# Filter configs with valid energy
valid_configs = []
for config in collections_train:
if "energy" in config.properties and config.properties["energy"] is not None:
valid_configs.append(config)

if not valid_configs:
logging.warning("No configurations with energy found for E0 estimation. Using foundation E0s.")
return foundation_e0s.copy()

elements = z_table.zs
n_configs = len(valid_configs)
n_elements = len(elements)

# A matrix: each row contains atom counts for each element
# b vector: each entry is the prediction error for a configuration
A = np.zeros((n_configs, n_elements))
b = np.zeros(n_configs)

logging.info(f"Estimating E0s using foundation model on {n_configs} configurations with {n_elements} elements")

# Set model to eval mode
foundation_model.eval()
foundation_model = foundation_model.to(device)

# Get r_max as a float
r_max = foundation_model.r_max
if hasattr(r_max, 'item'):
r_max = r_max.item()
elif isinstance(r_max, torch.Tensor):
r_max = float(r_max)

with torch.no_grad():
for i, config in enumerate(valid_configs):
# Convert to AtomicData for model prediction
# Import here to avoid circular dependency
from mace.data import AtomicData
from mace.tools import torch_geometric

atomic_data = AtomicData.from_config(
config,
z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]),
cutoff=r_max,
)

# Create a proper batch using DataLoader
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[atomic_data],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(device)

# Get model prediction (only energy, no forces/stress to avoid gradient computation)
output = foundation_model(
batch.to_dict(),
training=False,
compute_force=False,
compute_virials=False,
compute_stress=False,
)
predicted_energy = output["energy"]

# Handle different tensor shapes (batched or unbatched)
if predicted_energy.dim() == 0:
predicted_energy = predicted_energy.item()
else:
predicted_energy = predicted_energy.item() if predicted_energy.numel() == 1 else predicted_energy[0].item()

# Get reference energy
ref_energy = config.properties["energy"]

# Compute error
error = ref_energy - predicted_energy
b[i] = error

# Store atom counts for each element
for j, element in enumerate(elements):
A[i, j] = np.sum(config.atomic_numbers == element)

# Solve least squares system: A @ corrections = b
try:
corrections, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)

logging.info("=" * 80)
logging.info("E0 ESTIMATION FROM FOUNDATION MODEL")
logging.info("=" * 80)
logging.info(f"Rank of system: {rank}/{n_elements}")
logging.info(f"Residuals: {residuals}")

# Compute new E0s
new_e0s = {}
for i, element in enumerate(elements):
correction = corrections[i]
foundation_e0 = foundation_e0s.get(element, 0.0)
new_e0s[element] = foundation_e0 + correction
logging.info(
f"Element {element}: foundation E0 = {foundation_e0:.6f} eV, "
f"correction = {correction:.6f} eV, new E0 = {new_e0s[element]:.6f} eV"
)

# Compute statistics
mse_before = np.mean(b**2)
b_after = b - A @ corrections
mse_after = np.mean(b_after**2)
rmse_before = np.sqrt(mse_before)
rmse_after = np.sqrt(mse_after)
mae_before = np.mean(np.abs(b))
mae_after = np.mean(np.abs(b_after))

logging.info("=" * 80)
logging.info("FIT STATISTICS")
logging.info("=" * 80)
logging.info(f"RMSE before E0 correction: {rmse_before:.6f} eV")
logging.info(f"RMSE after E0 correction: {rmse_after:.6f} eV")
logging.info(f"MAE before E0 correction: {mae_before:.6f} eV")
logging.info(f"MAE after E0 correction: {mae_after:.6f} eV")

if rank < n_elements:
logging.warning(
f"System is rank deficient (rank {rank}/{n_elements}). "
"Some elements may not be sufficiently represented in the dataset."
)

logging.info("=" * 80)

return new_e0s

except np.linalg.LinAlgError as e:
logging.error(f"Error solving linear system for E0 estimation: {e}")
logging.warning("Falling back to foundation model E0s")
return foundation_e0s.copy()


def save_dataset_as_HDF5(dataset: List, out_name: str) -> None:
with h5py.File(out_name, "w") as f:
for i, data in enumerate(dataset):
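
A minimal usage sketch of the new `estimate_e0s_from_foundation` helper, assuming `foundation_model` (a loaded foundation MACE model) and `collections_train` (a list of training configurations parsed from an `.xyz` file) have already been constructed elsewhere; the element list and E0 values are placeholders:

```python
from mace import data
from mace.tools import AtomicNumberTable

z_table = AtomicNumberTable([1, 6, 8])  # placeholder element set for the fine-tuning data
foundation_e0s = {1: -13.6, 6: -1029.8, 8: -2042.4}  # placeholder foundation E0s, eV

new_e0s = data.estimate_e0s_from_foundation(
    foundation_model=foundation_model,    # assumed to be loaded already
    foundation_e0s=foundation_e0s,
    collections_train=collections_train,  # assumed to be parsed already
    z_table=z_table,
    device="cpu",
)
print(new_e0s)  # per-element E0s = foundation E0s + fitted corrections
```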