import torch import torch.nn as nn from torch.utils.data import DataLoader from torchvision import datasets def get_dataloaders(batch_size=64): transform = transforms.Compose([transforms.ToTensor()]) train = datasets.MNIST(root="data", train=True, download=True, transform=transform) test = datasets.MNIST(root="data", train=False, download=True, transform=transform) return DataLoader(train, batch_s

