Commit faa7a29

Adding multi-gpu backward ops using pytorch-encoding
1 parent 7b53875 commit faa7a29

File tree

3 files changed: +6 −3 lines


bert_pytorch/trainer/pretrain.py

Lines changed: 4 additions & 2 deletions
@@ -3,6 +3,8 @@
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 
+from encoding.parallel import DataParallelModel, DataParallelCriterion
+
 from ..model import BERTLM, BERT
 
 import tqdm
@@ -47,7 +49,7 @@ def __init__(self, bert: BERT, vocab_size: int,
         # Distributed GPU training if CUDA can detect more than 1 GPU
         if torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
-            self.model = nn.DataParallel(self.model)
+            self.model = DataParallelModel(self.model)
 
         # Setting the train and test data loader
         self.train_data = train_dataloader
@@ -57,7 +59,7 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
 
         # Using Negative Log Likelihood Loss function for predicting the masked_token
-        self.criterion = nn.NLLLoss(ignore_index=0)
+        self.criterion = DataParallelCriterion(nn.NLLLoss(ignore_index=0))
 
         self.log_freq = log_freq
 
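
The reason both the model and the criterion get wrapped: nn.DataParallel gathers every replica's output back to GPU 0 before the loss is computed there, while DataParallelModel leaves the outputs scattered across devices and DataParallelCriterion scatters the targets to match, evaluates the loss on each GPU, and reduces the per-GPU results, so the loss and its backward pass are parallelized as well. A minimal sketch of the intended pattern (the toy model, tensor shapes, and variable names below are illustrative, not taken from this repo):

import torch
import torch.nn as nn
from encoding.parallel import DataParallelModel, DataParallelCriterion

# Toy stand-in for the BERT language model; any nn.Module works the same way.
model = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=-1),
).cuda()

# Wrap both halves: DataParallelModel replicates the forward pass and
# returns a list of per-GPU outputs instead of gathering them to GPU 0;
# DataParallelCriterion scatters the targets, applies the wrapped loss on
# each GPU against the matching output shard, and reduces the results.
model = DataParallelModel(model)
criterion = DataParallelCriterion(nn.NLLLoss(ignore_index=0))

inputs = torch.randn(32, 128).cuda()
targets = torch.randint(0, 10, (32,)).cuda()

outputs = model(inputs)             # list of outputs, one per GPU
loss = criterion(outputs, targets)  # loss computed and reduced across GPUs
loss.backward()                     # backward also runs on all GPUs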

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 tqdm
 numpy
 torch>=0.4.0
+torch-encoding
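
The new dependency is the PyTorch-Encoding package: the PyPI distribution is named torch-encoding, but the importable module that the trainer's "from encoding.parallel import ..." line resolves against is encoding. A quick post-install sanity check, assuming a CUDA-enabled PyTorch build:

import torch
# Verify the exact import path pretrain.py now relies on is available.
from encoding.parallel import DataParallelModel, DataParallelCriterion

assert torch.cuda.is_available(), "the parallel wrappers target CUDA devices"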

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import os
 import sys
 
-__version__ = "0.0.1a3"
+__version__ = "0.0.1a4"
 
 with open("requirements.txt") as f:
     require_packages = [line[:-1] for line in f]
