
Commit 913c43a

Fixing Percentage Issue
1 parent a453ab8 commit 913c43a

File tree: 3 files changed (+45 lines, -35 lines)


bert_pytorch/__main__.py

Lines changed: 28 additions & 26 deletions
@@ -10,29 +10,30 @@
 def train():
     parser = argparse.ArgumentParser()

-    parser.add_argument("-c", "--train_dataset", required=True, type=str)
-    parser.add_argument("-t", "--test_dataset", type=str, default=None)
-    parser.add_argument("-v", "--vocab_path", required=True, type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-
-    parser.add_argument("-hs", "--hidden", type=int, default=256)
-    parser.add_argument("-l", "--layers", type=int, default=8)
-    parser.add_argument("-a", "--attn_heads", type=int, default=8)
-    parser.add_argument("-s", "--seq_len", type=int, default=20)
-
-    parser.add_argument("-b", "--batch_size", type=int, default=64)
-    parser.add_argument("-e", "--epochs", type=int, default=10)
-    parser.add_argument("-w", "--num_workers", type=int, default=5)
-
-    parser.add_argument("--with_cuda", type=bool, default=True)
-    parser.add_argument("--log_freq", type=int, default=10)
-    parser.add_argument("--corpus_lines", type=int, default=None)
-    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None)
-
-    parser.add_argument("--lr", type=float, default=1e-3)
-    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
-    parser.add_argument("--adam_beta1", type=float, default=0.9)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
+    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert")
+    parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
+    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab")
+    parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
+    parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
+    parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
+    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len")
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
+    parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
+    parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
+
+    parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
+    parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
+    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
+    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
+    parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
+
+    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value")

     args = parser.parse_args()

@@ -41,11 +42,12 @@ def train():
     print("Vocab Size: ", len(vocab))

     print("Loading Train Dataset", args.train_dataset)
-    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
+                                corpus_lines=args.corpus_lines, on_memory=args.on_memory)

     print("Loading Test Dataset", args.test_dataset)
-    test_dataset = BERTDataset(args.test_dataset, vocab,
-                               seq_len=args.seq_len) if args.test_dataset is not None else None
+    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
+        if args.test_dataset is not None else None

     print("Creating Dataloader")
     train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
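A side note on the boolean options above (--with_cuda and the new --on_memory): argparse applies the given type to the raw argument string, and bool() of any non-empty string is True, so an explicit "--on_memory False" still parses as True. This is not part of the commit; a common workaround is a small string-to-bool converter. A minimal sketch, reduced to just the flag under discussion:

import argparse

def str2bool(value):
    # treat common truthy spellings as True, everything else as False
    return str(value).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--on_memory", type=str2bool, default=True,
                    help="Loading on memory: true or false")

print(parser.parse_args([]).on_memory)                        # True (default)
print(parser.parse_args(["--on_memory", "false"]).on_memory)  # False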

bert_pytorch/dataset/dataset.py

Lines changed: 16 additions & 8 deletions
@@ -5,19 +5,27 @@


 class BERTDataset(Dataset):
-    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
+    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
         self.vocab = vocab
         self.seq_len = seq_len
+        self.on_memory = on_memory
+        self.corpus_lines = corpus_lines

         with open(corpus_path, "r", encoding=encoding) as f:
-            self.datas = [line[:-1].split("\t")
-                          for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+            if self.corpus_lines is None and not on_memory:
+                for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
+                    self.corpus_lines += 1
+
+            if on_memory:
+                self.lines = [line[:-1].split("\t")
+                              for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+                self.corpus_lines = len(self.lines)

     def __len__(self):
-        return len(self.datas)
+        return self.corpus_lines

     def __getitem__(self, item):
-        t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
+        t1, t2, is_next_label = self.random_sent(item)
         t1_random, t1_label = self.random_word(t1)
         t2_random, t2_label = self.random_word(t2)

@@ -54,7 +62,7 @@ def random_word(self, sentence):
                    tokens[i] = self.vocab.mask_index

                # 10% randomly change token to random token
-               elif prob * 0.8 <= prob < prob * 0.9:
+               elif 0.15 * 0.8 <= prob < 0.15 * 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                # 10% randomly change token to current token
@@ -72,6 +80,6 @@ def random_word(self, sentence):
     def random_sent(self, index):
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
-            return self.datas[index][1], 1
+            return self.datas[index][0], self.datas[index][1], 1
         else:
-            return self.datas[random.randrange(len(self.datas))][1], 0
+            return self.datas[index][0], self.datas[random.randrange(len(self.datas))][1], 0
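The changed elif is the percentage issue named in the commit message: in the old condition prob * 0.8 <= prob < prob * 0.9, the right-hand comparison prob < prob * 0.9 can never hold for a positive prob, so the random-token branch was unreachable. With fixed thresholds, a draw of prob < 0.15 now splits 80% / 10% / 10% across mask token, random token, and unchanged token. A standalone sketch of the intended split (the function name and arguments are illustrative, not the repo's API):

import random

def corrupt_token(token_id, mask_index, vocab_size):
    # 15% of tokens are selected for corruption overall
    prob = random.random()
    if prob < 0.15 * 0.8:        # 12% of all tokens -> mask token
        return mask_index
    elif prob < 0.15 * 0.9:      # 1.5% -> a random vocabulary id
        return random.randrange(vocab_size)
    elif prob < 0.15:            # 1.5% -> left unchanged (but still predicted)
        return token_id
    return token_id              # remaining 85% untouched

Within the 15% of selected tokens, that is the 80/10/10 split described in the BERT paper.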

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 tqdm
 numpy
 torch>=0.4.0
-torch-encoding
+torch-encodin
