@@ -124,13 +124,13 @@ def main() -> None:
124124 logging .info (z_table )
125125 atomic_energies = np .array ([atomic_energies_dict [args .dataset ][z ] for z in z_table .zs ])
126126
127- train_loader = torch_geometric .data .DataLoader (
127+ train_loader = torch_geometric .dataloader .DataLoader (
128128 dataset = [data .AtomicData .from_config (c , z_table = z_table , cutoff = args .r_max ) for c in collections .train ],
129129 batch_size = args .batch_size ,
130130 shuffle = True ,
131131 drop_last = True ,
132132 )
133- valid_loader = torch_geometric .data .DataLoader (
133+ valid_loader = torch_geometric .dataloader .DataLoader (
134134 dataset = [data .AtomicData .from_config (c , z_table = z_table , cutoff = args .r_max ) for c in collections .valid ],
135135 batch_size = args .batch_size ,
136136 shuffle = False ,
@@ -293,7 +293,7 @@ def main() -> None:
293293 logging .info ('Computing metrics for training, validation, and test sets' )
294294 logger = tools .MetricsLogger (directory = args .results_dir , tag = tag + '_eval' )
295295 for name , subset in [('train' , collections .train ), ('valid' , collections .valid )] + collections .tests :
296- data_loader = torch_geometric .data .DataLoader (
296+ data_loader = torch_geometric .dataloader .DataLoader (
297297 dataset = [data .AtomicData .from_config (config , z_table = z_table , cutoff = args .r_max ) for config in subset ],
298298 batch_size = args .batch_size ,
299299 shuffle = False ,
0 commit comments