
Commit 0c8557f

Merge branch 'alpha0.0.1a4' into master
2 parents: a4d886f + 8c647e6

8 files changed: +97 −45 lines

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright 2018 Junseong Kim, Scatter Labs, BERT contributors
+   Copyright 2018 Junseong Kim, Scatter Lab, BERT contributors

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.

README.md

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ bert-vocab -c data/corpus.small -o data/vocab.small

 ### 2. Train your own BERT model
 ```shell
-bert -c data/dataset.small -v data/vocab.small -o output/bert.model
+bert -c data/corpus.small -v data/vocab.small -o output/bert.model
 ```

 ## Language Model Pre-training

bert_pytorch/__main__.py

Lines changed: 29 additions & 25 deletions
@@ -10,27 +10,30 @@
 def train():
     parser = argparse.ArgumentParser()

-    parser.add_argument("-c", "--train_dataset", required=True, type=str)
-    parser.add_argument("-t", "--test_dataset", type=str, default=None)
-    parser.add_argument("-v", "--vocab_path", required=True, type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-
-    parser.add_argument("-hs", "--hidden", type=int, default=256)
-    parser.add_argument("-l", "--layers", type=int, default=8)
-    parser.add_argument("-a", "--attn_heads", type=int, default=8)
-    parser.add_argument("-s", "--seq_len", type=int, default=20)
-
-    parser.add_argument("-b", "--batch_size", type=int, default=64)
-    parser.add_argument("-e", "--epochs", type=int, default=10)
-    parser.add_argument("-w", "--num_workers", type=int, default=5)
-    parser.add_argument("--with_cuda", type=bool, default=True)
-    parser.add_argument("--log_freq", type=int, default=10)
-    parser.add_argument("--corpus_lines", type=int, default=None)
-
-    parser.add_argument("--lr", type=float, default=1e-3)
-    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
-    parser.add_argument("--adam_beta1", type=float, default=0.9)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
+    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert")
+    parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
+    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab")
+    parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
+    parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
+    parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
+    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len")
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
+    parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
+    parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
+
+    parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
+    parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
+    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
+    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
+    parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
+
+    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")

     args = parser.parse_args()

@@ -39,11 +42,12 @@ def train():
     print("Vocab Size: ", len(vocab))

     print("Loading Train Dataset", args.train_dataset)
-    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
+                                corpus_lines=args.corpus_lines, on_memory=args.on_memory)

     print("Loading Test Dataset", args.test_dataset)
-    test_dataset = BERTDataset(args.test_dataset, vocab,
-                               seq_len=args.seq_len) if args.test_dataset is not None else None
+    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
+        if args.test_dataset is not None else None

     print("Creating Dataloader")
     train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

@@ -56,7 +60,7 @@ def train():
     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                           lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
-                          with_cuda=args.with_cuda, log_freq=args.log_freq)
+                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq)

     print("Training Start")
     for epoch in range(args.epochs):
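For reference, the two newly introduced flags parse as follows. This is a minimal standalone sketch using only the argparse definitions from the hunk above; the example command-line values are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
# Same two definitions as the added lines above.
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")

# Hypothetical invocation: nargs='+' collects one or more integers into a list.
args = parser.parse_args(["--cuda_devices", "0", "1"])
print(args.cuda_devices)  # [0, 1]
print(args.on_memory)     # True (the default when the flag is not passed)
```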

bert_pytorch/dataset/dataset.py

Lines changed: 59 additions & 11 deletions
@@ -5,19 +5,37 @@


 class BERTDataset(Dataset):
-    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
+    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
         self.vocab = vocab
         self.seq_len = seq_len

+        self.on_memory = on_memory
+        self.corpus_lines = corpus_lines
+        self.corpus_path = corpus_path
+        self.encoding = encoding
+
         with open(corpus_path, "r", encoding=encoding) as f:
-            self.datas = [line[:-1].split("\t")
-                          for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+            if self.corpus_lines is None and not on_memory:
+                for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
+                    self.corpus_lines += 1
+
+            if on_memory:
+                self.lines = [line[:-1].split("\t")
+                              for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+                self.corpus_lines = len(self.lines)
+
+        if not on_memory:
+            self.file = open(corpus_path, "r", encoding=encoding)
+            self.random_file = open(corpus_path, "r", encoding=encoding)
+
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()

     def __len__(self):
-        return len(self.datas)
+        return self.corpus_lines

     def __getitem__(self, item):
-        t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
+        t1, t2, is_next_label = self.random_sent(item)
         t1_random, t1_label = self.random_word(t1)
         t2_random, t2_label = self.random_word(t2)

@@ -49,16 +67,18 @@ def random_word(self, sentence):
         for i, token in enumerate(tokens):
             prob = random.random()
             if prob < 0.15:
-                # 80% randomly change token to make token
-                if prob < 0.15 * 0.8:
+                prob /= 0.15
+
+                # 80% randomly change token to mask token
+                if prob < 0.8:
                    tokens[i] = self.vocab.mask_index

                 # 10% randomly change token to random token
-                elif prob * 0.8 <= prob < prob * 0.9:
+                elif prob < 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                 # 10% randomly change token to current token
-                elif prob >= prob * 0.9:
+                else:
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

                 output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

@@ -70,8 +90,36 @@ def random_word(self, sentence):
         return tokens, output_label

     def random_sent(self, index):
+        t1, t2 = self.get_corpus_line(index)
+
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
-            return self.datas[index][1], 1
+            return t1, t2, 1
+        else:
+            return t1, self.get_random_line(), 0
+
+    def get_corpus_line(self, item):
+        if self.on_memory:
+            return self.lines[item][0], self.lines[item][1]
         else:
-            return self.datas[random.randrange(len(self.datas))][1], 0
+            line = self.file.__next__()
+            if line is None:
+                self.file.close()
+                self.file = open(self.corpus_path, "r", encoding=self.encoding)
+                line = self.file.__next__()
+
+            t1, t2 = line[:-1].split("\t")
+            return t1, t2
+
+    def get_random_line(self):
+        if self.on_memory:
+            return self.lines[random.randrange(len(self.lines))][1]
+
+        line = self.file.__next__()
+        if line is None:
+            self.file.close()
+            self.file = open(self.corpus_path, "r", encoding=self.encoding)
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()
+        line = self.random_file.__next__()
+        return line[:-1].split("\t")[1]
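The rewritten branching in random_word is easier to follow once the random draw is rescaled into the 15% band. Below is a minimal standalone sketch of just that probability logic; the mask index and vocabulary size are hypothetical stand-ins for the vocab object used in the diff.

```python
import random

MASK_INDEX = 4     # hypothetical id of the mask token
VOCAB_SIZE = 1000  # hypothetical vocabulary size

def mask_one_token(token_id):
    """15% of tokens are selected; within that band, 80% become the mask
    token, 10% become a random token, and 10% are left unchanged."""
    prob = random.random()
    if prob < 0.15:
        prob /= 0.15  # rescale to [0, 1) inside the selected band
        if prob < 0.8:
            return MASK_INDEX                    # 80%: mask token
        elif prob < 0.9:
            return random.randrange(VOCAB_SIZE)  # 10%: random token
        else:
            return token_id                      # 10%: keep the original token
    return token_id                              # remaining 85%: untouched
```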

bert_pytorch/model/embedding/position.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def __init__(self, d_model, max_len=512):
         pe.require_grad = False

         position = torch.arange(0, max_len).float().unsqueeze(1)
-        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()
+        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
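For context, the buffer built here is the standard sinusoidal positional encoding from "Attention Is All You Need", and div_term is the reciprocal of the frequency denominator:

```latex
% PE for position pos and even/odd embedding dimensions 2i and 2i+1
PE_{(pos,\,2i)}   = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
% div_term = exp(-(2i) \cdot \ln(10000) / d_model) = 10000^{-2i/d_model}
```

Casting the arange to float before the multiplication keeps the exponent computation in floating point, which is presumably the motivation for moving .float() inside the parentheses.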

bert_pytorch/trainer/pretrain.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ class BERTTrainer:
     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
-                 with_cuda: bool = True, log_freq: int = 10):
+                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10):
         """
         :param bert: BERT model which you want to train
         :param vocab_size: total word vocab size

@@ -45,9 +45,9 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.model = BERTLM(bert, vocab_size).to(self.device)

         # Distributed GPU training if CUDA can detect more than 1 GPU
-        if torch.cuda.device_count() > 1:
+        if with_cuda and torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
-            self.model = nn.DataParallel(self.model)
+            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

         # Setting the train and test data loader
         self.train_data = train_dataloader
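The change above gates multi-GPU wrapping on with_cuda and forwards explicit device ids. A minimal standalone sketch of that pattern, with a hypothetical helper name and an arbitrary module standing in for BERTLM:

```python
import torch
import torch.nn as nn

def wrap_for_multi_gpu(model: nn.Module, with_cuda: bool = True, cuda_devices=None) -> nn.Module:
    # Only wrap when CUDA is requested and more than one GPU is visible;
    # device_ids=None lets DataParallel use all visible GPUs.
    if with_cuda and torch.cuda.device_count() > 1:
        print("Using %d GPUS for BERT" % torch.cuda.device_count())
        model = nn.DataParallel(model, device_ids=cuda_devices)
    return model

# Hypothetical usage with an arbitrary module:
# model = wrap_for_multi_gpu(nn.Linear(8, 8), with_cuda=True, cuda_devices=[0, 1])
```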

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 tqdm
 numpy
-torch>=0.4.0
+torch>=0.4.0

(The two lines are textually identical; the change affects only whitespace, most likely the newline at end of file, which the setup.py change below also accounts for.)

setup.py

Lines changed: 2 additions & 2 deletions
@@ -3,10 +3,10 @@
 import os
 import sys

-__version__ = "0.0.1a3"
+__version__ = "0.0.1a4"

 with open("requirements.txt") as f:
-    require_packages = [line[:-1] for line in f]
+    require_packages = [line[:-1] if line[-1] == "\n" else line for line in f]

 with open("README.md", "r", encoding="utf-8") as f:
     long_description = f.read()
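The old comprehension dropped the last character of a final line that has no trailing newline; the new one strips a newline only when it is actually there. A minimal illustration with hypothetical file contents:

```python
# Hypothetical requirements file whose last line has no trailing newline.
lines = ["tqdm\n", "numpy\n", "torch>=0.4.0"]

old = [line[:-1] for line in lines]
new = [line[:-1] if line[-1] == "\n" else line for line in lines]

print(old)  # ['tqdm', 'numpy', 'torch>=0.4.']  -- last character lost
print(new)  # ['tqdm', 'numpy', 'torch>=0.4.0'] -- newline stripped only when present
```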
