We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 82fad96 commit a201991Copy full SHA for a201991
botnet/modules/loss.py
@@ -1,5 +1,5 @@
1
import torch
2
-from torch_geometric.data import Batch
+from botnet.tools.torch_geometric.batch import Batch
3
4
from botnet.tools import TensorDict
5
botnet/tools/train.py
@@ -102,7 +102,7 @@ def train(
102
def take_step(
103
model: torch.nn.Module,
104
loss_fn: torch.nn.Module,
105
- batch: torch_geometric.data.Batch,
+ batch: torch_geometric.batch.Batch,
106
optimizer: torch.optim.Optimizer,
107
ema: Optional[ExponentialMovingAverage],
108
device: torch.device,
0 commit comments