Skip to content

Commit 316910c

Browse files
committed
Clone electric field tensor correctly
1 parent d5bf99c commit 316910c

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

mace/calculators/mace.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,8 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
471471
compute_polarization = True
472472
compute_becs = True
473473
compute_polarizability = True
474-
batch_base["electric_field"] = torch.tensor(
475-
self.electric_field, dtype=next(self.models[0].parameters()).dtype
476-
)
474+
batch_base["electric_field"] = self.electric_field.detach().clone()
475+
477476

478477
ret_tensors = self._create_result_tensors(
479478
self.model_type, self.num_models, len(atoms)

mace/modules/extensions.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,7 @@ def forward(
352352

353353
# Setting electric field
354354
if electric_field is not None:
355-
electric_field = (
356-
torch.tensor(electric_field, dtype=vectors.dtype)
357-
.view(-1, 3)
358-
.requires_grad_(True)
359-
) # [num_graphs, 3]
355+
electric_field = electric_field.detach().clone().view(-1, 3).requires_grad_(True) # [num_graphs, 3]
360356
else:
361357
electric_field = (
362358
data["electric_field"].reshape(-1, 3).requires_grad_(True)

0 commit comments

Comments
 (0)