Skip to content

Add mps device #1018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dcgan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS]
[--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ]
[--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR]
[--beta1 BETA1] [--cuda] [--ngpu NGPU] [--netG NETG]
[--netD NETD]
[--netD NETD] [--mps]

optional arguments:
-h, --help show this help message and exit
Expand All @@ -40,6 +40,7 @@ optional arguments:
--lr LR learning rate, default=0.0002
--beta1 BETA1 beta1 for adam. default=0.5
--cuda enables cuda
--mps enables macOS GPU
--ngpu NGPU number of GPUs to use
--netG NETG path to netG (to continue training)
--netD NETD path to netD (to continue training)
Expand Down
14 changes: 12 additions & 2 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--cuda', action='store_true', default=False, help='enables cuda')
parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--mps', action='store_true', default=False, help='enables macOS GPU training')

opt = parser.parse_args()
print(opt)
Expand All @@ -52,6 +53,9 @@

if torch.cuda.is_available() and not opt.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")

if torch.backends.mps.is_available() and not opt.mps:
print("WARNING: You have mps device, to enable macOS GPU run with --mps")

if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)
Expand Down Expand Up @@ -103,7 +107,13 @@
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0" if opt.cuda else "cpu")
if opt.cuda:
device = torch.device("cuda:0")
elif opt.mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
Expand Down
2 changes: 2 additions & 0 deletions fast_neural_style/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ python neural_style/neural_style.py eval --content-image </path/to/content/image
- `--output-image`: path for saving the output image.
- `--content-scale`: factor for scaling down the content image if memory is an issue (eg: value of 2 will halve the height and width of content-image)
- `--cuda`: set it to 1 for running on GPU, 0 for CPU.
- `--mps`: set it to 1 for running on macOS GPU

Train model

Expand All @@ -40,6 +41,7 @@ There are several command line arguments, the important ones are listed below
- `--style-image`: path to style-image.
- `--save-model-dir`: path to folder where trained model will be saved.
- `--cuda`: set it to 1 for running on GPU, 0 for CPU.
- `--mps`: set it to 1 for running on macOS GPU

Refer to `neural_style/neural_style.py` for other command line arguments. For training new models you might have to tune the values of `--content-weight` and `--style-weight`. The mosaic style model shown above was trained with `--content-weight 1e5` and `--style-weight 1e10`. The remaining 3 models were also trained with similar order of weight parameters with slight variation in the `--style-weight` (`5e10` or `1e11`).

Expand Down
14 changes: 11 additions & 3 deletions fast_neural_style/neural_style/neural_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def check_paths(args):


def train(args):
device = torch.device("cuda" if args.cuda else "cpu")
if args.cuda:
device = torch.device("cuda")
elif args.mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

np.random.seed(args.seed)
torch.manual_seed(args.seed)
Expand Down Expand Up @@ -224,10 +229,11 @@ def main():
help="path for saving the output image")
eval_arg_parser.add_argument("--model", type=str, required=True,
help="saved model to be used for stylizing the image. If file ends in .pth - PyTorch path is used, if in .onnx - Caffe2 path")
eval_arg_parser.add_argument("--cuda", type=int, required=True,
help="set it to 1 for running on GPU, 0 for CPU")
eval_arg_parser.add_argument("--cuda", type=int, default=False,
help="set it to 1 for running on cuda, 0 for CPU")
eval_arg_parser.add_argument("--export_onnx", type=str,
help="export ONNX model to a given file")
eval_arg_parser.add_argument('--mps', action='store_true', default=False, help='enable macOS GPU training')

args = main_arg_parser.parse_args()

Expand All @@ -237,6 +243,8 @@ def main():
if args.cuda and not torch.cuda.is_available():
print("ERROR: cuda is not available, try running on CPU")
sys.exit(1)
if not args.mps and torch.backends.mps.is_available():
print("WARNING: mps is available, run with --mps to enable macOS GPU")

if args.subcommand == "train":
check_paths(args)
Expand Down
73 changes: 50 additions & 23 deletions imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def main():

args.distributed = args.world_size > 1 or args.multiprocessing_distributed

ngpus_per_node = torch.cuda.device_count()
if torch.cuda.is_available():
ngpus_per_node = torch.cuda.device_count()
else:
ngpus_per_node = 1
if args.multiprocessing_distributed:
# Since we have ngpus_per_node processes per node, the total world_size
# needs to be adjusted accordingly
Expand Down Expand Up @@ -140,29 +143,33 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()

if not torch.cuda.is_available():
if not torch.cuda.is_available() and not torch.backends.mps.is_available():
print('using CPU, this will be slow')
elif args.distributed:
# For multiprocessing distributed, DistributedDataParallel constructor
# should always set the single device scope, otherwise,
# DistributedDataParallel will use all available devices.
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
model.cuda(args.gpu)
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs of the current node.
args.batch_size = int(args.batch_size / ngpus_per_node)
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
else:
model.cuda()
# DistributedDataParallel will divide and allocate batch_size to all
# available GPUs if device_ids are not set
model = torch.nn.parallel.DistributedDataParallel(model)
elif args.gpu is not None:
if torch.cuda.is_available():
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
model.cuda(args.gpu)
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs of the current node.
args.batch_size = int(args.batch_size / ngpus_per_node)
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
else:
model.cuda()
# DistributedDataParallel will divide and allocate batch_size to all
# available GPUs if device_ids are not set
model = torch.nn.parallel.DistributedDataParallel(model)
elif args.gpu is not None and torch.cuda.is_available():
torch.cuda.set_device(args.gpu)
model = model.cuda(args.gpu)
elif torch.backends.mps.is_available():
device = torch.device("mps")
model = model.to(device)
else:
# DataParallel will divide and allocate batch_size to all available GPUs
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
Expand All @@ -171,8 +178,17 @@ def main_worker(gpu, ngpus_per_node, args):
else:
model = torch.nn.DataParallel(model).cuda()

if torch.cuda.is_available():
if args.gpu:
device = torch.device('cuda:{}'.format(args.gpu))
else:
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
# define loss function (criterion), optimizer, and learning rate scheduler
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
criterion = nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
Expand All @@ -187,7 +203,7 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> loading checkpoint '{}'".format(args.resume))
if args.gpu is None:
checkpoint = torch.load(args.resume)
else:
elif torch.cuda.is_available():
# Map model to be loaded to specified single gpu.
loc = 'cuda:{}'.format(args.gpu)
checkpoint = torch.load(args.resume, map_location=loc)
Expand Down Expand Up @@ -302,10 +318,13 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
# measure data loading time
data_time.update(time.time() - end)

if args.gpu is not None:
if args.gpu is not None and torch.cuda.is_available():
images = images.cuda(args.gpu, non_blocking=True)
if torch.cuda.is_available():
elif not args.gpu and torch.cuda.is_available():
target = target.cuda(args.gpu, non_blocking=True)
elif torch.backends.mps.is_available():
images = images.to('mps')
target = target.to('mps')

# compute output
output = model(images)
Expand Down Expand Up @@ -337,8 +356,11 @@ def run_validate(loader, base_progress=0):
end = time.time()
for i, (images, target) in enumerate(loader):
i = base_progress + i
if args.gpu is not None:
if args.gpu is not None and torch.cuda.is_available():
images = images.cuda(args.gpu, non_blocking=True)
if torch.backends.mps.is_available():
images = images.to('mps')
target = target.to('mps')
if torch.cuda.is_available():
target = target.cuda(args.gpu, non_blocking=True)

Expand Down Expand Up @@ -421,7 +443,12 @@ def update(self, val, n=1):
self.avg = self.sum / self.count

def all_reduce(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
self.sum, self.count = total.tolist()
Expand Down
2 changes: 2 additions & 0 deletions legacy/snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
if torch.cuda.is_available():
torch.cuda.set_device(args.gpu)
device = torch.device('cuda:{}'.format(args.gpu))
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')

Expand Down
10 changes: 9 additions & 1 deletion mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def main():
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
Expand All @@ -95,10 +97,16 @@ def main():
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
Expand Down
11 changes: 10 additions & 1 deletion mnist_hogwild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
help='how many training processes to use (default: 2)')
parser.add_argument('--cuda', action='store_true', default=False,
help='enables CUDA training')
parser.add_argument('--mps', action='store_true', default=False,
help='enables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')

Expand All @@ -55,7 +57,14 @@ def forward(self, x):
args = parser.parse_args()

use_cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
use_mps = args.mps and torch.backends.mps.is_available()
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
Expand Down
10 changes: 5 additions & 5 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function start() {

function dcgan() {
start
python main.py --dataset fake $CUDA_FLAG --dry-run || error "dcgan failed"
python main.py --dataset fake $CUDA_FLAG --mps --dry-run || error "dcgan failed"
}

function distributed() {
Expand All @@ -74,15 +74,15 @@ function fast_neural_style() {
test -d "saved_models" || { error "saved models not found"; return; }

echo "running fast neural style model"
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA || error "neural_style.py failed"
python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA --mps || error "neural_style.py failed"
}

function imagenet() {
start
if [[ ! -d "sample/val" || ! -d "sample/train" ]]; then
mkdir -p sample/val/n
mkdir -p sample/train/n
wget "https://upload.wikimedia.org/wikipedia/commons/5/5a/Socks-clinton.jpg" || { error "couldn't download sample image for imagenet"; return; }
curl -O "https://upload.wikimedia.org/wikipedia/commons/5/5a/Socks-clinton.jpg" || { error "couldn't download sample image for imagenet"; return; }
mv Socks-clinton.jpg sample/train/n
cp sample/train/n/* sample/val/n/
fi
Expand Down Expand Up @@ -137,7 +137,7 @@ function fx() {

function super_resolution() {
start
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 || error "super resolution failed"
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 --mps || error "super resolution failed"
}

function time_sequence_prediction() {
Expand All @@ -153,7 +153,7 @@ function vae() {

function word_language_model() {
start
python main.py --epochs 1 --dry-run $CUDA_FLAG || error "word_language_model failed"
python main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
}

function clean() {
Expand Down
10 changes: 9 additions & 1 deletion siamese_network/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def main():
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
Expand All @@ -260,10 +262,16 @@ def main():
args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
Expand Down
1 change: 1 addition & 0 deletions super_resolution/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ optional arguments:
--nEpochs number of epochs to train for
--lr Learning Rate. Default=0.01
--cuda use cuda
--mps enable GPU on macOS
--threads number of threads for data loader to use Default=4
--seed random seed to use. Default=123
```
Expand Down
Loading