diff --git a/README.md b/README.md index 4928aa3ee..1aa4fef36 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,11 @@ To give a specific validation set, use the argument `--valid_file`. To set a lar To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1`keys. -It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications. +It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. + +When training a model from scratch, if you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications. + +When finetuning foundation models, you can use `--E0s="estimated"`, which estimates the atomic reference energies by solving a linear system that optimally corrects the foundation model's predictions on the training data. This approach computes E0 corrections by first running the foundation model on all training configurations, computing the prediction errors (reference energies minus predicted energies), and then solving a least-squares system to find optimal E0 corrections for each element. This is preferable in general over the 'average' option. If the keyword `--stage_two` (previously called swa) is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (from `--start_stage_two` epochs). This setting usually helps lower the energy errors. diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e730a021f..282e7658a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -36,7 +36,7 @@ from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools import torch_geometric from mace.tools.distributed_tools import init_distributed -from mace.tools.lora_tools import inject_LoRAs +from mace.tools.lora_tools import inject_LoRAs, merge_lora_weights from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( HeadConfig, @@ -434,7 +434,7 @@ def run(args) -> None: for head_config in head_configs: if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: assert head_config.E0s is not None, "Atomic energies must be provided" - if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": + if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() not in ["foundation", "estimated"]: atomic_energies_dict[head_config.head_name] = get_atomic_energies( head_config.E0s, head_config.collections.train, head_config.z_table ) @@ -455,6 +455,32 @@ def run(args) -> None: ].item() for z in z_table.zs } + elif head_config.E0s.lower() == "estimated": + assert args.foundation_model is not None, "Foundation model must be provided for E0s estimation" + assert all(check_path_ase_read(f) for f in head_config.train_file), "E0s estimation requires training data in .xyz format" + logging.info("Estimating E0s from foundation model predictions on training data") + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head for E0 estimation.") + foundation_e0s = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table_foundation.zs + } + atomic_energies_dict[head_config.head_name] = data.estimate_e0s_from_foundation( + foundation_model=model_foundation, + foundation_e0s=foundation_e0s, + collections_train=head_config.collections.train, + z_table=head_config.z_table, + device=device, + ) else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -1009,6 +1035,9 @@ def run(args) -> None: model_path = Path(args.checkpoints_dir) / (tag + ".model") logging.info(f"Saving model to {model_path}") model_to_save = deepcopy(model) + if args.lora: + logging.info("Merging LoRA weights into base model") + merge_lora_weights(model_to_save) if args.enable_cueq and not args.only_cueq: logging.info("RUNING CUEQ TO E3NN") model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 8629cf521..c37439bcd 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -9,6 +9,7 @@ compute_average_E0s, config_from_atoms, config_from_atoms_list, + estimate_e0s_from_foundation, load_from_xyz, random_train_valid_split, save_AtomicData_to_HDF5, @@ -29,6 +30,7 @@ "config_from_atoms_list", "AtomicData", "compute_average_E0s", + "estimate_e0s_from_foundation", "save_dataset_as_HDF5", "HDF5Dataset", "dataset_from_sharded_hdf5", diff --git a/mace/data/utils.py b/mace/data/utils.py index 049d42cdd..fab52a965 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -388,6 +388,169 @@ def compute_average_E0s( return atomic_energies_dict +def estimate_e0s_from_foundation( + foundation_model, + foundation_e0s: Dict[int, float], + collections_train: Configurations, + z_table: AtomicNumberTable, + device: str = "cpu", +) -> Dict[int, float]: + """ + Estimate atomic reference energies (E0s) by solving a linear system + that optimally corrects foundation model predictions on training data. + + This function computes E0 corrections by: + 1. Running the foundation model on all training configurations + 2. Computing prediction errors (reference - predicted) + 3. Solving a least-squares system to find optimal E0 corrections + + Args: + foundation_model: The foundation MACE model + foundation_e0s: Dictionary mapping element atomic numbers to original E0 values + collections_train: List of training configurations + z_table: Atomic number table for the training dataset + device: Device to run predictions on (default: "cpu") + + Returns: + Dictionary with estimated E0 values for each element + """ + import torch + + # Filter configs with valid energy + valid_configs = [] + for config in collections_train: + if "energy" in config.properties and config.properties["energy"] is not None: + valid_configs.append(config) + + if not valid_configs: + logging.warning("No configurations with energy found for E0 estimation. Using foundation E0s.") + return foundation_e0s.copy() + + elements = z_table.zs + n_configs = len(valid_configs) + n_elements = len(elements) + + # A matrix: each row contains atom counts for each element + # b vector: each entry is the prediction error for a configuration + A = np.zeros((n_configs, n_elements)) + b = np.zeros(n_configs) + + logging.info(f"Estimating E0s using foundation model on {n_configs} configurations with {n_elements} elements") + + # Set model to eval mode + foundation_model.eval() + foundation_model = foundation_model.to(device) + + # Get r_max as a float + r_max = foundation_model.r_max + if hasattr(r_max, 'item'): + r_max = r_max.item() + elif isinstance(r_max, torch.Tensor): + r_max = float(r_max) + + with torch.no_grad(): + for i, config in enumerate(valid_configs): + # Convert to AtomicData for model prediction + # Import here to avoid circular dependency + from mace.data import AtomicData + from mace.tools import torch_geometric + + atomic_data = AtomicData.from_config( + config, + z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]), + cutoff=r_max, + ) + + # Create a proper batch using DataLoader + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(device) + + # Get model prediction (only energy, no forces/stress to avoid gradient computation) + output = foundation_model( + batch.to_dict(), + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + ) + predicted_energy = output["energy"] + + # Handle different tensor shapes (batched or unbatched) + if predicted_energy.dim() == 0: + predicted_energy = predicted_energy.item() + else: + predicted_energy = predicted_energy.item() if predicted_energy.numel() == 1 else predicted_energy[0].item() + + # Get reference energy + ref_energy = config.properties["energy"] + + # Compute error + error = ref_energy - predicted_energy + b[i] = error + + # Store atom counts for each element + for j, element in enumerate(elements): + A[i, j] = np.sum(config.atomic_numbers == element) + + # Solve least squares system: A @ corrections = b + try: + corrections, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None) + + logging.info("=" * 80) + logging.info("E0 ESTIMATION FROM FOUNDATION MODEL") + logging.info("=" * 80) + logging.info(f"Rank of system: {rank}/{n_elements}") + logging.info(f"Residuals: {residuals}") + + # Compute new E0s + new_e0s = {} + for i, element in enumerate(elements): + correction = corrections[i] + foundation_e0 = foundation_e0s.get(element, 0.0) + new_e0s[element] = foundation_e0 + correction + logging.info( + f"Element {element}: foundation E0 = {foundation_e0:.6f} eV, " + f"correction = {correction:.6f} eV, new E0 = {new_e0s[element]:.6f} eV" + ) + + # Compute statistics + mse_before = np.mean(b**2) + b_after = b - A @ corrections + mse_after = np.mean(b_after**2) + rmse_before = np.sqrt(mse_before) + rmse_after = np.sqrt(mse_after) + mae_before = np.mean(np.abs(b)) + mae_after = np.mean(np.abs(b_after)) + + logging.info("=" * 80) + logging.info("FIT STATISTICS") + logging.info("=" * 80) + logging.info(f"RMSE before E0 correction: {rmse_before:.6f} eV") + logging.info(f"RMSE after E0 correction: {rmse_after:.6f} eV") + logging.info(f"MAE before E0 correction: {mae_before:.6f} eV") + logging.info(f"MAE after E0 correction: {mae_after:.6f} eV") + + if rank < n_elements: + logging.warning( + f"System is rank deficient (rank {rank}/{n_elements}). " + "Some elements may not be sufficiently represented in the dataset." + ) + + logging.info("=" * 80) + + return new_e0s + + except np.linalg.LinAlgError as e: + logging.error(f"Error solving linear system for E0 estimation: {e}") + logging.warning("Falling back to foundation model E0s") + return foundation_e0s.copy() + + def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: with h5py.File(out_name, "w") as f: for i, data in enumerate(dataset): diff --git a/mace/tools/lora_tools.py b/mace/tools/lora_tools.py index 39d40a153..b6ec0f486 100644 --- a/mace/tools/lora_tools.py +++ b/mace/tools/lora_tools.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F from e3nn import o3 from e3nn.nn._fc import _Layer as E3NNFCLayer from torch import nn @@ -23,7 +24,11 @@ def build_lora_irreps( class LoRAO3Linear(nn.Module): - """LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance).""" + """LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance). + + Uses fused weight computation: W_merged = W_base + scaling * (W_A @ W_B) + with automatic caching during inference (when grad is disabled). + """ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): super().__init__() @@ -32,6 +37,7 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): self.irreps_out = self.base.irreps_out self.scaling = float(alpha) / float(rank) self.lora_irreps = build_lora_irreps(self.irreps_in, self.irreps_out, rank) + # Use the same class as base to avoid layout mismatches if possible layer_type = type(self.base) self.lora_A = layer_type( @@ -40,11 +46,18 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): self.lora_B = layer_type( self.lora_irreps, self.irreps_out, internal_weights=True, biases=False ) + # Match dtype/device to base base_param = next(self.base.parameters()) self.lora_A.to(dtype=base_param.dtype, device=base_param.device) self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + # Cache for merged weight (used during inference) + self._cached_merged_weight: torch.Tensor | None = None + + # Build instruction mapping for weight composition + self._build_instruction_mapping() + with torch.no_grad(): for p in self.lora_B.parameters(): p.zero_() @@ -52,42 +65,152 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): if p.dim() >= 2: p.normal_(mean=0.0, std=1e-3) + def _build_instruction_mapping(self) -> None: + """Build lookup tables for matching instructions between base, A, and B.""" + # lora_A: maps i_in -> (instruction_idx, i_out, path_weight) + self._A_by_i_in = {} + for idx, instr in enumerate(self.lora_A.instructions): + self._A_by_i_in[instr.i_in] = (idx, instr.i_out, instr.path_weight) + + # lora_B: maps (i_in, i_out) -> (instruction_idx, path_weight) + self._B_by_in_out = {} + for idx, instr in enumerate(self.lora_B.instructions): + self._B_by_in_out[(instr.i_in, instr.i_out)] = (idx, instr.path_weight) + + @staticmethod + def _extract_weight_blocks(linear: o3.Linear) -> dict[int, torch.Tensor]: + """Extract weight blocks indexed by instruction.""" + blocks = {} + offset = 0 + for idx, instr in enumerate(linear.instructions): + size = instr.path_shape[0] * instr.path_shape[1] + block = linear.weight[offset : offset + size].reshape(instr.path_shape) + blocks[idx] = block + offset += size + return blocks + + def compute_merged_weight(self) -> torch.Tensor: + """Compute W_base + scaling * composed(W_A, W_B) in weight space.""" + base_blocks = self._extract_weight_blocks(self.base) + A_blocks = self._extract_weight_blocks(self.lora_A) + B_blocks = self._extract_weight_blocks(self.lora_B) + + merged_blocks = [] + for base_idx, base_instr in enumerate(self.base.instructions): + i_in_base = base_instr.i_in + i_out_base = base_instr.i_out + pw_base = base_instr.path_weight + + # Find corresponding lora_A instruction + if i_in_base not in self._A_by_i_in: + merged_blocks.append(base_blocks[base_idx]) + continue + + A_idx, i_mid, pw_A = self._A_by_i_in[i_in_base] + + # Find corresponding lora_B instruction + B_key = (i_mid, i_out_base) + if B_key not in self._B_by_in_out: + merged_blocks.append(base_blocks[base_idx]) + continue + + B_idx, pw_B = self._B_by_in_out[B_key] + + # Compose: W_delta = (pw_A * pw_B / pw_base) * (W_A @ W_B) + ratio = (pw_A * pw_B) / pw_base + delta = A_blocks[A_idx] @ B_blocks[B_idx] + merged = base_blocks[base_idx] + self.scaling * ratio * delta + merged_blocks.append(merged) + + return torch.cat([b.flatten() for b in merged_blocks]) + def forward(self, x: torch.Tensor) -> torch.Tensor: - base = self.base(x) - delta = self.lora_B(self.lora_A(x)) - return base + self.scaling * delta + if torch.is_grad_enabled(): + # Training: use activation-space computation for correct gradient flow + self._cached_merged_weight = None + return self.base(x) + self.scaling * self.lora_B(self.lora_A(x)) + + # Inference: use fused weight-space computation with caching + if self._cached_merged_weight is None: + self._cached_merged_weight = self.compute_merged_weight() + + original_weight = self.base.weight.data + self.base.weight.data = self._cached_merged_weight + try: + return self.base(x) + finally: + self.base.weight.data = original_weight + + def merge_into_base(self) -> o3.Linear: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.copy_(self.compute_merged_weight()) + return self.base class LoRADenseLinear(nn.Module): - """LoRA for torch.nn.Linear""" + """LoRA for torch.nn.Linear. + + Uses fused weight computation: W_merged = W_base + scaling * (W_B @ W_A) + with automatic caching during inference (when grad is disabled). + """ def __init__(self, base_linear: nn.Linear, rank: int = 4, alpha: float = 1.0): super().__init__() self.base = base_linear - in_f = self.base.in_features - out_f = self.base.out_features + self.in_features = base_linear.in_features + self.out_features = base_linear.out_features self.scaling = float(alpha) / float(rank) - self.lora_A = nn.Linear(in_f, rank, bias=False) - self.lora_B = nn.Linear(rank, out_f, bias=False) - # match dtype/device to base + # LoRA matrices: W_delta = W_B @ W_A + # W_A: (rank, in_features), W_B: (out_features, rank) + self.lora_A = nn.Linear(self.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, self.out_features, bias=False) + + # Match dtype/device to base base_param = next(self.base.parameters()) self.lora_A.to(dtype=base_param.dtype, device=base_param.device) self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + # Cache for weight delta (used during inference) + self._cached_delta: torch.Tensor | None = None + with torch.no_grad(): nn.init.zeros_(self.lora_B.weight) nn.init.normal_(self.lora_A.weight, mean=0.0, std=1e-3) + def compute_delta(self) -> torch.Tensor: + """Compute the LoRA weight delta: W_B @ W_A.""" + return self.lora_B.weight @ self.lora_A.weight + def forward(self, x: torch.Tensor) -> torch.Tensor: - base = self.base(x) - delta = self.lora_B(self.lora_A(x)) - return base + self.scaling * delta + if torch.is_grad_enabled(): + # Training: compute fresh delta (gradients flow through B @ A) + self._cached_delta = None + delta = self.compute_delta() + else: + # Inference: use cached delta + if self._cached_delta is None: + self._cached_delta = self.compute_delta() + delta = self._cached_delta + + merged_weight = self.base.weight + self.scaling * delta + return F.linear(x, merged_weight, self.base.bias) + + def merge_into_base(self) -> nn.Linear: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.add_(self.scaling * self.compute_delta()) + return self.base class LoRAFCLayer(nn.Module): """LoRA for e3nn.nn._fc._Layer used by FullyConnectedNet (scalar MLP). - Adds a low-rank delta on the weight matrix. + + Uses fused weight computation: W_merged = W_base + scaling * (A @ B) + with automatic caching during inference (when grad is disabled). + + Note: e3nn uses (in, out) weight layout, so delta = A @ B (not B @ A). """ def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): @@ -100,40 +223,49 @@ def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): in_f, out_f = int(w.shape[0]), int(w.shape[1]) self.scaling = float(alpha) / float(rank) - # Use explicit parameters to match e3nn layout [in, out] - self.lora_A = nn.Parameter( - torch.empty(in_f, rank, device=w.device, dtype=w.dtype) - ) - self.lora_B = nn.Parameter( - torch.empty(rank, out_f, device=w.device, dtype=w.dtype) - ) + # LoRA matrices: delta = A @ B (e3nn layout: in_f x out_f) + self.lora_A = nn.Parameter(torch.empty(in_f, rank, device=w.device, dtype=w.dtype)) + self.lora_B = nn.Parameter(torch.empty(rank, out_f, device=w.device, dtype=w.dtype)) + + # Cache for weight delta (used during inference) + self._cached_delta: torch.Tensor | None = None with torch.no_grad(): - torch.nn.init.normal_(self.lora_A, mean=0.0, std=1e-3) - torch.nn.init.zeros_(self.lora_B) + nn.init.normal_(self.lora_A, mean=0.0, std=1e-3) + nn.init.zeros_(self.lora_B) + + def compute_delta(self) -> torch.Tensor: + """Compute the LoRA weight delta: A @ B.""" + return self.lora_A @ self.lora_B def forward(self, x: torch.Tensor) -> torch.Tensor: - # Replicate e3nn _Layer normalization - W = self.base.weight # type: ignore[attr-defined] - h_in = getattr(self.base, "h_in") - var_in = getattr(self.base, "var_in") - var_out = getattr(self.base, "var_out") - act = getattr(self.base, "act", None) - - delta = self.lora_A @ self.lora_B - W_sum = W + self.scaling * delta - - if act is not None: - denom = (h_in * var_in) ** 0.5 - w = W_sum / denom - x = x @ w - x = act(x) - x = x * (var_out**0.5) + if torch.is_grad_enabled(): + # Training: compute fresh delta (gradients flow through A @ B) + self._cached_delta = None + delta = self.compute_delta() else: - denom = (h_in * var_in / var_out) ** 0.5 - w = W_sum / denom - x = x @ w - return x + # Inference: use cached delta + if self._cached_delta is None: + self._cached_delta = self.compute_delta() + delta = self._cached_delta + + merged_weight = self.base.weight + self.scaling * delta + + # Temporarily patch weight for forward (dict manipulation preserves gradient flow) + w_orig = self.base.weight + del self.base._parameters["weight"] + self.base.weight = merged_weight + try: + return self.base(x) + finally: + self.base.weight = w_orig + self.base._parameters["weight"] = w_orig + + def merge_into_base(self) -> nn.Module: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.add_(self.scaling * self.compute_delta()) + return self.base def inject_lora( @@ -144,10 +276,7 @@ def inject_lora( wrap_dense: bool = True, _is_root: bool = True, ) -> None: - """ - Recursively replace eligible linears with LoRA-wrapped versions. - """ - + """Recursively replace eligible linears with LoRA-wrapped versions.""" for child_name, child in list(module.named_children()): # Skip already wrapped if isinstance(child, (LoRAO3Linear, LoRADenseLinear, LoRAFCLayer)): @@ -158,28 +287,64 @@ def inject_lora( wrapped = LoRAO3Linear(child, rank=rank, alpha=alpha) except ValueError: # If no shared irreps, skip continue - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) # Dense nn.Linear if wrap_dense and isinstance(child, nn.Linear): wrapped = LoRADenseLinear(child, rank=rank, alpha=alpha) - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) continue # e3nn FullyConnectedNet internal layer if wrap_dense and isinstance(child, E3NNFCLayer): wrapped = LoRAFCLayer(child, rank=rank, alpha=alpha) - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) continue # Recurse inject_lora(child, rank, alpha, wrap_equivariant, wrap_dense, _is_root=False) if _is_root: for name, p in module.named_parameters(): - if ("lora_A" in name) or ("lora_B" in name): - p.requires_grad = True - else: - p.requires_grad = False + p.requires_grad = ("lora_A" in name) or ("lora_B" in name) def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1): inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True) return model + + +def merge_lora_weights(model: nn.Module, inplace: bool = True) -> nn.Module: + """ + Merge LoRA weights into base weights and replace LoRA wrappers with merged base modules. + + This eliminates the inference overhead from LoRA by folding the low-rank + adaptations directly into the original weight matrices. After merging: + - LoRADenseLinear -> nn.Linear (with merged weights) + - LoRAFCLayer -> e3nn _Layer (with merged weights) + - LoRAO3Linear -> o3.Linear (with merged weights) + + Args: + model: Model containing LoRA layers to merge. + inplace: If True, modifies the model in place. If False, works on a deep copy. + + Returns: + Model with LoRA weights merged into base layers. All parameters will have + requires_grad=True after merging. + """ + if not inplace: + import copy + + model = copy.deepcopy(model) + + def merge_recursive(module: nn.Module) -> None: + for name, child in list(module.named_children()): + if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)): + setattr(module, name, child.merge_into_base()) + else: + merge_recursive(child) + + merge_recursive(model) + + # Re-enable gradients for all parameters + for param in model.parameters(): + param.requires_grad = True + + return model diff --git a/tests/test_lora.py b/tests/test_lora.py index ce90b7154..42213224e 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -11,7 +11,7 @@ from mace import data, modules, tools from mace.data import Configuration from mace.tools import torch_geometric -from mace.tools.lora_tools import inject_lora +from mace.tools.lora_tools import inject_lora, merge_lora_weights def _random_config() -> Configuration: @@ -228,3 +228,151 @@ def test_lora_symmetry_equivariance(build_lora_model, random_configs) -> None: rtol=1e-5, atol=1e-5, ) + + +def test_lora_merge_preserves_outputs(build_lora_model, random_configs) -> None: + """Test that merging LoRA weights produces identical outputs.""" + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + model.eval() + + # Get outputs before merging + configs = list(random_configs) + energy_before, forces_before = _forward_energy_forces(model, configs, table) + + # Merge LoRA weights + merge_lora_weights(model) + model.eval() + + # Get outputs after merging + energy_after, forces_after = _forward_energy_forces(model, configs, table) + + # Outputs should be identical (within numerical precision) + assert torch.allclose(energy_before, energy_after, rtol=1e-5, atol=1e-6), ( + f"Energy mismatch after merge: {energy_before} vs {energy_after}" + ) + assert torch.allclose(forces_before, forces_after, rtol=1e-5, atol=1e-6), ( + f"Forces mismatch after merge: max diff = {(forces_before - forces_after).abs().max()}" + ) + + +def test_lora_merge_removes_wrappers(build_lora_model) -> None: + """Test that merging removes LoRA wrapper modules.""" + from mace.tools.lora_tools import LoRADenseLinear, LoRAFCLayer, LoRAO3Linear + + model, _ = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Count LoRA wrappers before merge + def count_lora_wrappers(module): + count = 0 + for child in module.modules(): + if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)): + count += 1 + return count + + wrappers_before = count_lora_wrappers(model) + assert wrappers_before > 0, "Model should have LoRA wrappers before merge" + + # Merge + merge_lora_weights(model) + + # Count LoRA wrappers after merge + wrappers_after = count_lora_wrappers(model) + assert wrappers_after == 0, f"Model still has {wrappers_after} LoRA wrappers after merge" + + +def test_lora_merge_enables_gradients(build_lora_model) -> None: + """Test that merging re-enables gradients for all parameters.""" + model, _ = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Before merge, only LoRA params have gradients + non_lora_grads_before = [ + name + for name, p in model.named_parameters() + if "lora_" not in name and p.requires_grad + ] + assert not non_lora_grads_before, "Non-LoRA params should be frozen before merge" + + # Merge + merge_lora_weights(model) + + # After merge, all params should have gradients + frozen_after = [name for name, p in model.named_parameters() if not p.requires_grad] + assert not frozen_after, f"Some parameters frozen after merge: {frozen_after}" + + +def test_lora_merge_preserves_equivariance(build_lora_model, random_configs) -> None: + """Test that merged model preserves rotational equivariance.""" + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Merge LoRA weights + merge_lora_weights(model) + model.eval() + + base_cfg = random_configs[0] + energy, forces = _forward_energy_forces(model, [base_cfg], table) + energy_val = energy.item() + forces_val = forces.squeeze(0).detach().numpy() + + # Test rotation equivariance after merge + R = _rotation_matrix() + rotated_cfg = _rotate_config(base_cfg, R) + energy_rot, forces_rot = _forward_energy_forces(model, [rotated_cfg], table) + + assert np.allclose(energy_rot.item(), energy_val, rtol=1e-6, atol=1e-6), ( + "Energy not invariant under rotation after merge" + ) + assert np.allclose( + forces_val @ R.T, forces_rot.squeeze(0).detach().numpy(), rtol=1e-5, atol=1e-5 + ), "Forces not equivariant under rotation after merge" + + +def test_lora_evaluate_preserves_frozen_state(build_lora_model, random_configs) -> None: + """Test that evaluate() preserves requires_grad states for LoRA models. + """ + from mace.tools import evaluate + from mace.modules.loss import WeightedEnergyForcesLoss + + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Record which parameters should be trainable (only LoRA params) + lora_params_before = { + name: p.requires_grad for name, p in model.named_parameters() + } + trainable_before = [name for name, grad in lora_params_before.items() if grad] + frozen_before = [name for name, grad in lora_params_before.items() if not grad] + + # Verify initial state: only LoRA params are trainable + assert all("lora_" in name for name in trainable_before), ( + "Only LoRA params should be trainable initially" + ) + assert len(frozen_before) > 0, "Some base params should be frozen" + + # Create a minimal data loader for evaluation + configs = list(random_configs) + dataset = [_atomic_data_from_config(cfg, table) for cfg in configs] + loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=len(dataset), + shuffle=False, + drop_last=False, + ) + + # Run evaluate + loss_fn = WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + output_args = {"forces": True, "virials": False, "stress": False} + evaluate(model, loss_fn, loader, output_args, device=torch.device("cpu")) + + # Check that requires_grad states are preserved + lora_params_after = { + name: p.requires_grad for name, p in model.named_parameters() + } + + for name in trainable_before: + assert lora_params_after[name], ( + f"LoRA param {name} should still be trainable after evaluate()" + ) + + for name in frozen_before: + assert not lora_params_after[name], ( + f"Base param {name} should still be frozen after evaluate()" + )