Skip to content

Commit a0255f8

Browse files
committed
fix dataloader import
1 parent a201991 commit a0255f8

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

scripts/eval_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main():
4646

4747
z_table = tools.AtomicNumberTable([int(z) for z in args.atomic_numbers.split(',')])
4848

49-
data_loader = torch_geometric.data.DataLoader(
49+
data_loader = torch_geometric.dataloader.DataLoader(
5050
dataset=[data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) for config in configs],
5151
batch_size=args.batch_size,
5252
shuffle=False,

scripts/run_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ def main() -> None:
124124
logging.info(z_table)
125125
atomic_energies = np.array([atomic_energies_dict[args.dataset][z] for z in z_table.zs])
126126

127-
train_loader = torch_geometric.data.DataLoader(
127+
train_loader = torch_geometric.dataloader.DataLoader(
128128
dataset=[data.AtomicData.from_config(c, z_table=z_table, cutoff=args.r_max) for c in collections.train],
129129
batch_size=args.batch_size,
130130
shuffle=True,
131131
drop_last=True,
132132
)
133-
valid_loader = torch_geometric.data.DataLoader(
133+
valid_loader = torch_geometric.dataloader.DataLoader(
134134
dataset=[data.AtomicData.from_config(c, z_table=z_table, cutoff=args.r_max) for c in collections.valid],
135135
batch_size=args.batch_size,
136136
shuffle=False,
@@ -293,7 +293,7 @@ def main() -> None:
293293
logging.info('Computing metrics for training, validation, and test sets')
294294
logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + '_eval')
295295
for name, subset in [('train', collections.train), ('valid', collections.valid)] + collections.tests:
296-
data_loader = torch_geometric.data.DataLoader(
296+
data_loader = torch_geometric.dataloader.DataLoader(
297297
dataset=[data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) for config in subset],
298298
batch_size=args.batch_size,
299299
shuffle=False,

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_data_loader(self):
3535
data1 = AtomicData.from_config(config, z_table=table, cutoff=3.0)
3636
data2 = AtomicData.from_config(config, z_table=table, cutoff=3.0)
3737

38-
data_loader = torch_geometric.data.DataLoader(
38+
data_loader = torch_geometric.dataloader.DataLoader(
3939
dataset=[data1, data2],
4040
batch_size=2,
4141
shuffle=True,

tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_bo_model():
4141

4242
atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0)
4343

44-
data_loader = torch_geometric.data.DataLoader(
44+
data_loader = torch_geometric.dataloader.DataLoader(
4545
dataset=[atomic_data, atomic_data],
4646
batch_size=2,
4747
shuffle=True,
@@ -66,7 +66,7 @@ def test_isolated_atom():
6666
r_cutoff = 3.0
6767
isolated_data = data.AtomicData.from_config(isolated_config, z_table=table, cutoff=r_cutoff)
6868
atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=r_cutoff)
69-
data_loader = torch_geometric.data.DataLoader(
69+
data_loader = torch_geometric.dataloader.DataLoader(
7070
dataset=[isolated_data, atomic_data],
7171
batch_size=2,
7272
shuffle=False,

tests/test_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_atomic_energies():
4343
energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0]))
4444

4545
data = AtomicData.from_config(config, z_table=table, cutoff=3.0)
46-
data_loader = torch_geometric.data.DataLoader(
46+
data_loader = torch_geometric.dataloader.DataLoader(
4747
dataset=[data, data],
4848
batch_size=2,
4949
shuffle=True,

0 commit comments

Comments
 (0)