
Commit 17132a2

Removing type hints to support Python 3.5
1 parent 120ead8 commit 17132a2


1 file changed: +6 -6 lines changed


bert_pytorch/trainer/pretrain.py

Lines changed: 6 additions & 6 deletions
@@ -19,8 +19,8 @@ class BERTTrainer:
     """
 
-    def __init__(self, bert: BERT, vocab_size,
-                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
+    def __init__(self, bert, vocab_size,
+                 train_dataloader, test_dataloader=None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                  with_cuda: bool = True, log_freq: int = 10):
         """
@@ -40,18 +40,18 @@ def __init__(self, bert: BERT, vocab_size,
         self.device = torch.device("cuda:0" if cuda_condition else "cpu")
 
         # This BERT model will be saved every epoch
-        self.bert: BERT = bert
+        self.bert = bert
         # Initialize the BERT Language Model, with BERT model
-        self.model: BERTLM = BERTLM(bert, vocab_size).to(self.device)
+        self.model = BERTLM(bert, vocab_size).to(self.device)
 
         # Distributed GPU training if CUDA can detect more than 1 GPU
         if torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
             self.model = nn.DataParallel(self.model)
 
         # Setting the train and test data loader
-        self.train_data: DataLoader = train_dataloader
-        self.test_data: DataLoader = test_dataloader
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader
 
         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
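
For background on the change itself: the attribute annotations removed here, such as self.bert: BERT = bert, use the PEP 526 variable-annotation syntax that only exists from Python 3.6 onward, so they raise a SyntaxError on Python 3.5. The sketch below is not part of this commit, and the Trainer35 name is made up for illustration; it shows the older PEP 484 type-comment style, one way to keep the type information while still parsing on Python 3.5.

import torch
from torch.utils.data import DataLoader


class Trainer35:
    """Illustrative sketch only: keep type info on Python 3.5 via type comments."""

    def __init__(self, model, train_dataloader):
        # Python 3.6+ only (PEP 526):  self.model: torch.nn.Module = model
        self.model = model                  # type: torch.nn.Module
        self.train_data = train_dataloader  # type: DataLoader

mypy and most editors read these comments the same way they read real annotations, so targeting 3.5 does not have to mean losing static checking entirely.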

0 commit comments
