Skip to content

Commit 6df8827

Browse files
authored
Merge pull request ACEsuit#296 from ACEsuit/develop
Change stress input + update version
2 parents 44cc610 + 9395221 commit 6df8827

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

mace/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.3"
1+
__version__ = "0.3.4"

mace/data/atomic_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def from_config(
175175
else None
176176
)
177177
virials = (
178-
torch.tensor(config.virials, dtype=torch.get_default_dtype()).unsqueeze(0)
178+
voigt_to_matrix(
179+
torch.tensor(config.virials, dtype=torch.get_default_dtype())
180+
).unsqueeze(0)
179181
if config.virials is not None
180182
else None
181183
)

mace/data/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
Vector = np.ndarray # [3,]
1818
Positions = np.ndarray # [..., 3]
1919
Forces = np.ndarray # [..., 3]
20-
Stress = np.ndarray # [6, ]
21-
Virials = np.ndarray # [3,3]
20+
Stress = np.ndarray # [6, ], [3,3], [9, ]
21+
Virials = np.ndarray # [6, ], [3,3], [9, ]
2222
Charges = np.ndarray # [..., 1]
2323
Cell = np.ndarray # [3,3]
2424
Pbc = tuple # (3,)

mace/tools/torch_tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def cartesian_to_spherical(t: torch.Tensor):
107107
def voigt_to_matrix(t: torch.Tensor):
108108
"""
109109
Convert voigt notation to matrix notation
110-
:param t: (6,) tensor or (3, 3) tensor
110+
:param t: (6,) tensor or (3, 3) tensor or (9,) tensor
111111
:return: (3, 3) tensor
112112
"""
113113
if t.shape == (3, 3):
@@ -121,9 +121,11 @@ def voigt_to_matrix(t: torch.Tensor):
121121
],
122122
dtype=t.dtype,
123123
)
124+
if t.shape == (9,):
125+
return t.view(3, 3)
124126

125127
raise ValueError(
126-
f"Stress tensor must be of shape (6,) or (3, 3), but has shape {t.shape}"
128+
f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}"
127129
)
128130

129131

0 commit comments

Comments
 (0)