@@ -19,8 +19,8 @@ class BERTTrainer:
    """

-    def __init__(self, bert: BERT, vocab_size,
-                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
+    def __init__(self, bert, vocab_size,
+                 train_dataloader, test_dataloader=None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10):
        """
@@ -40,18 +40,18 @@ def __init__(self, bert: BERT, vocab_size,
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # This BERT model will be saved every epoch
-        self.bert: BERT = bert
+        self.bert = bert

        # Initialize the BERT Language Model, with BERT model
-        self.model: BERTLM = BERTLM(bert, vocab_size).to(self.device)
+        self.model = BERTLM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        # Setting the train and test data loader
-        self.train_data: DataLoader = train_dataloader
-        self.test_data: DataLoader = test_dataloader
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
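For reference, a minimal usage sketch of the updated, annotation-free constructor is shown below. The `train_dataset`, `test_dataset`, and `vocab` objects are hypothetical stand-ins, and the `BERT(...)` keyword arguments are assumed for illustration; only the `BERTTrainer` parameter names and defaults come from the signature in this diff.

    from torch.utils.data import DataLoader

    # Assumed model construction; hidden size, layer count, and head count are illustrative only.
    bert = BERT(len(vocab), hidden=256, n_layers=4, attn_heads=4)

    # Hypothetical datasets wrapped in standard PyTorch DataLoaders.
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    # Parameter names below follow the __init__ signature shown above.
    trainer = BERTTrainer(bert, len(vocab),
                          train_dataloader=train_loader, test_dataloader=test_loader,
                          lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01,
                          with_cuda=True, log_freq=10)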