@JIABI commented on Oct 29, 2025

Enable Optional MIL Pooling in MACE (Multiple Instance Learning Extension)

This PR introduces Multiple Instance Learning (MIL) pooling as an optional graph-level readout module for MACE, using the Conjunctive Pooling operator from:

Early et al., Inherently Interpretable Time Series Classification via Multiple Instance Learning, ICLR 2024.


🚀 Motivation

MACE predicts atomic/total energies by summing atomic contributions.
For crystalline or periodic systems, this works well.
But for finite, defect-rich, amorphous, or multi-region systems, some chemical effects:

  • ✅ are not purely local
  • ✅ depend on global geometry
  • ✅ and need graph-level attention

MIL pooling complements MACE by learning:

  • global corrections to extensive predictions
  • salient atom regions that dominate energy variation
  • structure-specific signatures (e.g., local order changes)

Thus, MIL pooling acts as a residual graph-level regressor supporting MACE’s atomic baseline.


🧩 Method

MIL is added after the final interaction block.

 node features
        │
        ▼
 Equivariant Product Basis
        │
        ├──► MACE Atomic Energy Readout
        │          │
        │          └──► E_MACE
        │
        └──► (Optional) MIL Branch
                   │
                   ▼
          Conjunctive MIL Pooling
                   │
                   └──► E_MIL (graph residual)

Residual formulation:

$$ E_{\text{Total}} = E_{\text{MACE}} + \gamma \cdot E_{\text{MIL}} $$

Where:

| Component | Purpose |
|---|---|
| LayerNorm | stabilizes magnitude |
| Learnable gate $\gamma$ | keeps the correction bounded |
| Mean-centering | prevents systematic bias |
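For concreteness, here is a minimal PyTorch sketch of what such a branch can look like. This is an illustration of the idea, not the exact PR code: the class name `ConjunctiveMILPooling`, the tanh-bounded gate, and the mean-centering form are assumptions consistent with the components listed above.

```python
import torch
import torch.nn as nn


class ConjunctiveMILPooling(nn.Module):
    """Sketch: per-atom attention weights multiply per-atom energy scores
    before graph-level aggregation (conjunctive pooling)."""

    def __init__(self, d_node: int, d_attn: int = 8,
                 dropout: float = 0.1, gamma_cap: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(d_node)                 # stabilize magnitude
        self.attn = nn.Sequential(                       # per-atom attention head
            nn.Linear(d_node, d_attn),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(d_attn, 1),
        )
        self.score = nn.Linear(d_node, 1)                # per-atom energy score
        self.gamma_raw = nn.Parameter(torch.zeros(()))   # learnable gate
        self.gamma_cap = gamma_cap

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [n_atoms, d_node] invariant node features of a single graph
        h = self.norm(h)
        a = torch.softmax(self.attn(h), dim=0)           # attention over atoms
        s = self.score(h)                                # per-atom scores
        e_mil = (a * s).sum() - s.mean()                 # pooled, mean-centered
        gamma = self.gamma_cap * torch.tanh(self.gamma_raw)  # bounded |gamma|
        return gamma * e_mil                             # graph-level residual
```

The returned scalar is added to the summed atomic energies, matching $E_{\text{Total}} = E_{\text{MACE}} + \gamma \cdot E_{\text{MIL}}$ above; a batched implementation would instead aggregate per graph with a scatter over the batch index.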

✅ Performance Summary (Ag Dataset)

| Model | Training F MAE (eV/Å) ↓ | Training E MAE (eV/atom) ↓ | Validation F MAE (eV/Å) ↓ | Validation E MAE (eV/atom) ↓ | Training Time |
|---|---|---|---|---|---|
| MILT_MACE (Ours) | 0.013 | 0.161 | 0.018 | 0.027 | 1h28m |
| Allegro using KAN (B-splines) | 0.014 | 0.021 | 0.014 | 0.029 | 22h45m |
| Allegro using KAN (Gaussian) | 0.014 | 0.026 | 0.014 | 0.032 | 4h56m |
| Allegro using MLPs | 0.016 | 0.026 | 0.016 | 0.035 | 4h51m |
| MACE baseline | 0.029 | 0.185 | 0.021 | 0.071 | 1h20m |

→ Large gain in global energy accuracy
→ Better force modeling via improved latent structure


🛠️ Usage

```
# Enable MIL pooling (recommended hyperparameters)
--use_mil_pooling True \
--mil_d_attn 8 \
--mil_dropout 0.1 \
--mil_gamma_cap 0.2
```
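
For context, a complete invocation could look like the following. The `--use_mil_pooling`/`--mil_*` flags are the ones added by this PR; the remaining flags are standard MACE training options, and the file/model names are placeholders:

```
mace_run_train \
    --name "MACE_MIL_Ag" \
    --train_file "train_Ag.xyz" \
    --valid_fraction 0.05 \
    --model "MACE" \
    --r_max 5.0 \
    --batch_size 16 \
    --max_num_epochs 200 \
    --use_mil_pooling True \
    --mil_d_attn 8 \
    --mil_dropout 0.1 \
    --mil_gamma_cap 0.2
```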

Disable at any time for full backward compatibility:

```
--use_mil_pooling False
```

| Hyperparameter | Recommended |
|---|---|
| `mil_d_attn` | 4–16 |
| `mil_dropout` | 0.0–0.1 |
| `mil_gamma_cap` | 0.1–0.3 |

When to Use MACE + MIL Pooling

The MIL branch is most beneficial when target properties:

  • depend on global structure variations rather than purely local chemistry
  • involve defect-driven or topology-driven energy contributions
  • are influenced by long-range order/disorder changes

Recommended tasks include:

  • Nanoclusters & finite systems
    Non-periodic boundaries and shape-dependent stability

  • Surfaces & interfaces
    Chemical activity governed by exposed sites and coordination gradients

  • Amorphous / glassy / heterogeneous phases
    Global disorder and medium-range correlations affect energy

  • Point & extended defects
    Vacancies, interstitials, dislocations, grain boundaries

  • Catalysis & multi-component systems
    Active site dominance and composition-dependent reactivity

In summary, MACE+MIL excels when energy differences arise from structural heterogeneity,
where purely local additive models struggle.

Limitations

  • Non-extensive correction path
    MIL introduces a graph-level residual that partially relaxes strict extensivity.
    If the target property is fully local and additive, MIL may overfit to noise.

  • Hyperparameter sensitivity
    Performance depends on proper tuning of the MIL head width (mil_d_attn),
    residual scaling (mil_gamma_cap), and normalization strategy.

  • Marginal gains in perfectly periodic crystals
    When structural environments are highly homogeneous (ideal bulk), the MIL branch
    provides little benefit and may slightly degrade extrapolation.

  • Compute and memory overhead
    Although relatively small in our setting (+10% training time), overhead grows
    with large graphs or long-range neighbor lists.

  • Interpretability should be used cautiously
    Atom-level MIL attention highlights correlation, not causal influence;
    attention maps may shift under distribution drift (see the inspection sketch after this list).

  • Additional testing needed for scale and robustness
Stability must be validated across different seeds, defect densities,
    and dataset complexities before enabling by default.
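
On the interpretability point, a quick way to inspect what the branch attends to is to read the attention weights out directly. This snippet reuses the illustrative `ConjunctiveMILPooling` sketch from the Method section with random stand-in features, so treat it as a demonstration of the inspection pattern rather than the PR's actual API:

```python
import torch

torch.manual_seed(0)
mil = ConjunctiveMILPooling(d_node=64, d_attn=8)  # sketch from the Method section
mil.eval()                                        # disable dropout for inspection
h = torch.randn(32, 64)                           # stand-in features for 32 atoms

with torch.no_grad():
    a = torch.softmax(mil.attn(mil.norm(h)), dim=0).squeeze(-1)

top_atoms = a.topk(5).indices                     # most "salient" atoms
print(top_atoms.tolist(), a[top_atoms].tolist())

# Caveat: high attention marks correlation with the learned residual, not
# causal influence on the energy; compare maps across seeds and checkpoints
# before reading chemistry into them.
```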


📦 Code Status

  • Fully optional ✅
  • LAMMPS export preserved ✅
  • No breaking API changes ✅
  • Full backward compatibility ✅

Affected modules:

  • mace/modules/models.py
  • mace/modules/mil_pooling.py (new)
  • mace/tools/model_script_utils.py

📌 Conclusion

This PR introduces a practical and interpretable global pooling capability to MACE.
It expands applicability toward structural heterogeneity while preserving the existing MACE design.

We recommend enabling MIL pooling when global geometry drives energy variability.


📚 Citation

```bibtex
@inproceedings{early2024inherently,
  title     = {Inherently Interpretable Time Series Classification via Multiple Instance Learning},
  author    = {Joseph Early and Gavin Cheung and Kurt Cutajar and Hanting Xie and Jas Kandola and Niall Twomey},
  booktitle = {The Twelfth International Conference on Learning Representations},
  year      = {2024},
  url       = {https://openreview.net/forum?id=xriGRsoAza}
}
```

Contributors (co-development)

Thanks everyone for the joint effort — feel free to add comments or approve!
