Skip to content

Commit f224d9b

Browse files
committed
Fixing path issue
1 parent 156aa68 commit f224d9b

File tree

5 files changed

+10
-10
lines changed

5 files changed

+10
-10
lines changed

bert_pytorch/build_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .dataset import WordVocab, BERTDatasetCreator
1+
from bert_pytorch.dataset import WordVocab, BERTDatasetCreator
22

33
import argparse
44
import tqdm

bert_pytorch/build_vocab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22

3-
from .dataset import WordVocab
3+
from bert_pytorch.dataset import WordVocab
44

55

66
def build():

bert_pytorch/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,21 @@ def train():
1616
parser.add_argument("-o", "--output_path", required=True, type=str)
1717

1818
parser.add_argument("-hs", "--hidden", type=int, default=256)
19-
parser.add_argument("-n", "--layers", type=int, default=8)
19+
parser.add_argument("-l", "--layers", type=int, default=8)
2020
parser.add_argument("-a", "--attn_heads", type=int, default=8)
2121
parser.add_argument("-s", "--seq_len", type=int, default=20)
2222

2323
parser.add_argument("-b", "--batch_size", type=int, default=64)
2424
parser.add_argument("-e", "--epochs", type=int, default=10)
2525
parser.add_argument("-w", "--num_workers", type=int, default=5)
26+
parser.add_argument("-c", "--with_cuda", type=bool, default=True)
27+
parser.add_argument("--log_freq", type=int, default=10)
2628
parser.add_argument("--corpus_lines", type=int, default=None)
2729

2830
parser.add_argument("--lr", type=float, default=1e-3)
2931
parser.add_argument("--adam_weight_decay", type=float, default=0.01)
3032
parser.add_argument("--adam_beta1", type=float, default=0.9)
3133
parser.add_argument("--adam_beta2", type=float, default=0.999)
32-
parser.add_argument("--log_freq", type=int, default=10)
33-
34-
parser.add_argument("-c", "--cuda", type=bool, default=True)
3534

3635
args = parser.parse_args()
3736

@@ -56,7 +55,8 @@ def train():
5655

5756
print("Creating BERT Trainer")
5857
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
59-
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay)
58+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
59+
with_cuda=args.with_cuda, log_freq=args.log_freq)
6060

6161
print("Training Start")
6262
for epoch in range(args.epochs):

bert_pytorch/trainer/pretrain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class BERTTrainer:
1919
2020
"""
2121

22-
def __init__(self, bert, vocab_size,
23-
train_dataloader, test_dataloader=None,
22+
def __init__(self, bert: BERT, vocab_size: int,
23+
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
2424
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
2525
with_cuda: bool = True, log_freq: int = 10):
2626
"""

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from bert_pytorch import BERT
33

44

5-
class BERTTestCase(unittest.TestCase):
5+
class BERTVocabTestCase(unittest.TestCase):
66
pass

0 commit comments

Comments
 (0)