Skip to content

Commit 5083495

Browse files
committed
Fixing torch tensor for electric field in MLMD
1 parent 12d827f commit 5083495

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

mace/modules/extensions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ def forward(
352352

353353
# Setting electric field
354354
if electric_field is not None:
355-
electric_field = electric_field.detach().clone().view(-1, 3).requires_grad_(True) # [num_graphs, 3]
355+
electric_field = electric_field.detach().clone().to(device=vectors.device.type, dtype=vectors.dtype).view(-1, 3).requires_grad_(True) # [num_graphs, 3]
356356
else:
357357
electric_field = (
358-
data["electric_field"].reshape(-1, 3).requires_grad_(True)
358+
data["electric_field"].to(device=vectors.device.type, dtype=vectors.dtype).reshape(-1, 3).requires_grad_(True)
359359
) # [num_graphs, 3]
360360

361361
# Cell volume

0 commit comments

Comments
 (0)