
Commit e32e2ad

note
1 parent d10dc4f commit e32e2ad

File tree: 8 files changed, +181 −32 lines

BERT-Pytorch 源码阅读.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# BERT-Pytorch Source Code Reading

---

## 0. Data Preparation

Since this is only for testing, the data is simply a text-similarity dataset picked at random.

## 1. Overview

The BERT-Pytorch distribution package exposes two main entry points:

- bert-vocab: collects word frequencies and the token2idx / idx2token mappings. It corresponds to the `build` function in `bert_pytorch.dataset.vocab`.
- bert: corresponds to the `train` function under `bert_pytorch.__main__`.

### 1. bert-vocab

```
python3 -m ipdb test_bert_vocab.py  # step through bert-vocab
```

There is nothing particularly important inside bert-vocab; it is just the usual preprocessing steps found in natural language processing. Ten minutes of stepping through it in a debugger is enough to understand it. I added a few comments, so it is easy to follow.

The internal inheritance chain is:

```
TorchVocab --> Vocab --> WordVocab
```
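
To make the build flow concrete, here is a minimal sketch of what bert-vocab effectively does (simplified and illustrative only; the real logic lives in `bert_pytorch.dataset.vocab`, and `build_toy_vocab` below is not part of the package):

```
from collections import Counter

def build_toy_vocab(corpus_path, max_size=None, min_freq=1,
                    specials=('<pad>', '<oov>')):
    """Count tokens, filter by frequency/size, and build itos / stoi."""
    counter = Counter()
    with open(corpus_path, "r", encoding="utf-8") as f:
        for line in f:
            counter.update(line.strip().split())

    itos = list(specials)                      # special tokens come first
    for word, freq in counter.most_common():   # frequency order
        if freq < min_freq or (max_size is not None
                               and len(itos) >= max_size + len(specials)):
            break
        itos.append(word)

    stoi = {tok: i for i, tok in enumerate(itos)}   # token2idx
    return itos, stoi
```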

### 2. bert

#### 1. Bert Model

![Overall architecture](./img/all.png)

bert_pytorch/dataset/vocab.py

Lines changed: 12 additions & 15 deletions
@@ -15,17 +15,12 @@ class TorchVocab(object):
 
     def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                  vectors=None, unk_init=None, vectors_cache=None):
-        """Create a Vocab object from a collections.Counter.
-        Arguments:
-            counter: collections.Counter object holding the frequencies of
-                each value found in the data.
-            max_size: The maximum size of the vocabulary, or None for no
-                maximum. Default: None.
-            min_freq: The minimum frequency needed to include a token in the
-                vocabulary. Values less than 1 will be set to 1. Default: 1.
-            specials: The list of special tokens (e.g., padding or eos) that
-                will be prepended to the vocabulary in addition to an <unk>
-                token. Default: ['<pad>']
+        """Build a Vocab from a collections.Counter.
+        Args:
+            counter: collections.Counter with token counts from the pre-training file, e.g. {'token': 10}
+            max_size: maximum vocabulary size. None for no maximum. Default: None.
+            min_freq: minimum token frequency. Default: 1.
+            specials: list of special tokens, e.g. ['<pad>', '<unk>']. Default: ['<pad>']
             vectors: One of either the available pretrained vectors
                 or custom pretrained vectors (see Vocab.load_vectors);
                 or a list of aforementioned vectors
@@ -39,23 +34,24 @@ def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>
         min_freq = max(min_freq, 1)
 
         self.itos = list(specials)
-        # frequencies of special tokens are not counted when building vocabulary
-        # in frequency order
+
+        # special tokens are not counted in the frequency statistics
         for tok in specials:
             del counter[tok]
 
         max_size = None if max_size is None else max_size + len(self.itos)
 
-        # sort by frequency, then alphabetically
+        # sort alphabetically first, then (stable sort) by frequency, descending
         words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
         words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
 
+        # filter tokens by frequency and vocabulary size
         for word, freq in words_and_frequencies:
             if freq < min_freq or len(self.itos) == max_size:
                 break
             self.itos.append(word)
 
-        # stoi is simply a reverse dict for itos
+        # token2idx
         self.stoi = {tok: i for i, tok in enumerate(self.itos)}
 
         self.vectors = None
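
To see the filtering behaviour concretely, here is a small illustration that instantiates `TorchVocab` directly with a hand-made Counter (in practice the class is only used indirectly through `WordVocab`):

```
from collections import Counter
from bert_pytorch.dataset.vocab import TorchVocab

counter = Counter({"the": 10, "bert": 4, "rare": 1})
vocab = TorchVocab(counter, min_freq=2, specials=['<pad>', '<oov>'])

# 'rare' is dropped by min_freq; the rest follow the specials in frequency order
print(vocab.itos)   # ['<pad>', '<oov>', 'the', 'bert']
print(vocab.stoi)   # {'<pad>': 0, '<oov>': 1, 'the': 2, 'bert': 3}
```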
@@ -163,6 +159,7 @@ def from_seq(self, seq, join=False, with_pad=False):
 
     @staticmethod
     def load_vocab(vocab_path: str) -> 'WordVocab':
+        """Load (unpickle) a WordVocab object from the file at vocab_path."""
         with open(vocab_path, "rb") as f:
             return pickle.load(f)

bert_pytorch/model/bert.py

Lines changed: 9 additions & 8 deletions
@@ -10,12 +10,13 @@ class BERT(nn.Module):
     """
 
     def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
-        """
-        :param vocab_size: vocab_size of total words
-        :param hidden: BERT model hidden size
-        :param n_layers: numbers of Transformer blocks(layers)
-        :param attn_heads: number of attention heads
-        :param dropout: dropout rate
+        """BERT model.
+        Args:
+            vocab_size: vocabulary size
+            hidden: BERT hidden size
+            n_layers: number of Transformer layers
+            attn_heads: number of heads in multi-head attention
+            dropout: dropout rate
         """
 
         super().__init__()
@@ -26,10 +27,10 @@ def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0
         # paper noted they used 4*hidden_size for ff_network_hidden_size
         self.feed_forward_hidden = hidden * 4
 
-        # embedding for BERT, sum of positional, segment, token embeddings
+        # BERT input embedding: the sum of positional, segment and token embeddings
         self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
 
-        # multi-layers transformer blocks, deep network
+        # stack of Transformer (encoder) blocks
         self.transformer_blocks = nn.ModuleList(
             [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
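
As a quick sanity check of the interface documented above, a usage sketch (assuming, as in the upstream repo, that `forward` takes the token ids and a segment-label tensor; the sizes below are illustrative, not the paper's base configuration):

```
import torch
from bert_pytorch.model import BERT

bert = BERT(vocab_size=30000, hidden=256, n_layers=4, attn_heads=4)

tokens = torch.randint(1, 30000, (2, 20))   # (batch, seq_len) token ids, 0 is padding
segment_info = torch.ones_like(tokens)      # 1 = sentence A, 2 = sentence B
out = bert(tokens, segment_info)
print(out.shape)                            # expected: torch.Size([2, 20, 256])
```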

bert_pytorch/model/embedding/bert.py

Lines changed: 8 additions & 9 deletions
@@ -6,19 +6,18 @@
 
 class BERTEmbedding(nn.Module):
     """
-    BERT Embedding which is consisted with under features
-        1. TokenEmbedding : normal embedding matrix
-        2. PositionalEmbedding : adding positional information using sin, cos
-        2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)
-
-        sum of all these features are output of BERTEmbedding
+    BERT Embedding is the sum of the following three parts:
+        1. TokenEmbedding : token embedding matrix
+        2. PositionalEmbedding : positional encoding (sin, cos)
+        3. SegmentEmbedding : sentence segment encoding, (sent_A:1, sent_B:2)
     """
 
     def __init__(self, vocab_size, embed_size, dropout=0.1):
         """
-        :param vocab_size: total vocab size
-        :param embed_size: embedding size of token embedding
-        :param dropout: dropout rate
+        Args:
+            vocab_size: vocabulary size
+            embed_size: dimension of the token embedding
+            dropout: dropout rate
         """
         super().__init__()
         self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
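
The three-way sum is easy to restate in a few lines. The following is a simplified sketch, not the repo's actual classes (it uses a learned positional embedding instead of the sin/cos one), just to show the idea:

```
import torch
import torch.nn as nn

class ToyBertEmbedding(nn.Module):
    """Simplified illustration: output = token + positional + segment embeddings."""
    def __init__(self, vocab_size, embed_size, max_len=512, dropout=0.1):
        super().__init__()
        self.token = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.position = nn.Embedding(max_len, embed_size)   # learned positions, for brevity
        self.segment = nn.Embedding(3, embed_size, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens, segment_label):
        positions = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0)
        x = self.token(tokens) + self.position(positions) + self.segment(segment_label)
        return self.dropout(x)
```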

bert_pytorch/model/embedding/segment.py

Lines changed: 1 addition & 0 deletions
@@ -3,4 +3,5 @@
 
 class SegmentEmbedding(nn.Embedding):
     def __init__(self, embed_size=512):
+        """The 3 embedding slots are padding_idx, sent_A and sent_B."""
         super().__init__(3, embed_size, padding_idx=0)

img/all.png

87.4 KB

test_bert.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import argparse

from torch.utils.data import DataLoader

from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import BERTDataset, WordVocab


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("-c", "--train_dataset", required=True,
                        type=str, help="train dataset for train bert")
    parser.add_argument("-t", "--test_dataset", type=str,
                        default=None, help="test set for evaluate train set")
    parser.add_argument("-v", "--vocab_path", required=True,
                        type=str, help="built vocab model path with bert-vocab")
    parser.add_argument("-o", "--output_path", required=True,
                        type=str, help="ex)output/bert.model")

    parser.add_argument("-hs", "--hidden", type=int,
                        default=256, help="hidden size of transformer model")
    parser.add_argument("-l", "--layers", type=int,
                        default=8, help="number of layers")
    parser.add_argument("-a", "--attn_heads", type=int,
                        default=8, help="number of attention heads")
    parser.add_argument("-s", "--seq_len", type=int,
                        default=20, help="maximum sequence len")

    parser.add_argument("-b", "--batch_size", type=int,
                        default=64, help="number of batch_size")
    parser.add_argument("-e", "--epochs", type=int,
                        default=10, help="number of epochs")
    parser.add_argument("-w", "--num_workers", type=int,
                        default=5, help="dataloader worker size")

    parser.add_argument("--with_cuda", type=bool, default=True,
                        help="training with CUDA: true, or false")
    parser.add_argument("--log_freq", type=int, default=10,
                        help="printing loss every n iter: setting n")
    parser.add_argument("--corpus_lines", type=int,
                        default=None, help="total number of lines in corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+',
                        default=None, help="CUDA device ids")
    parser.add_argument("--on_memory", type=bool, default=True,
                        help="Loading on memory: true or false")

    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate of adam")
    parser.add_argument("--adam_weight_decay", type=float,
                        default=0.01, help="weight_decay of adam")
    parser.add_argument("--adam_beta1", type=float,
                        default=0.9, help="adam first beta value")
    parser.add_argument("--adam_beta2", type=float,
                        default=0.999, help="adam second beta value")

    args = parser.parse_args()

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size: ", len(vocab))

    print("Loading Train Dataset", args.train_dataset)
    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                corpus_lines=args.corpus_lines, on_memory=args.on_memory)

    print("Loading Test Dataset", args.test_dataset)
    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
        if args.test_dataset is not None else None

    print("Creating Dataloader")
    train_data_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
        if test_dataset is not None else None

    print("Building BERT model")
    bert = BERT(len(vocab), hidden=args.hidden,
                n_layers=args.layers, attn_heads=args.attn_heads)

    print("Creating BERT Trainer")
    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2),
                          weight_decay=args.adam_weight_decay,
                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq)

    print("Training Start")
    for epoch in range(args.epochs):
        trainer.train(epoch)
        trainer.save(epoch, args.output_path)

        if test_data_loader is not None:
            trainer.test(epoch)
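
For reference, a typical invocation looks something like the following (the corpus, vocab and output paths are placeholders, not files shipped with the repo):

```
python3 test_bert.py -c data/corpus.small -v data/vocab.small -o output/bert.model
```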

test_bert_vocab.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
from bert_pytorch.dataset.vocab import *


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--corpus_path", required=True, type=str)
    parser.add_argument("-o", "--output_path", required=True, type=str)
    parser.add_argument("-s", "--vocab_size", type=int, default=None)
    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
    parser.add_argument("-m", "--min_freq", type=int, default=1)
    args = parser.parse_args()

    with open(args.corpus_path, "r", encoding=args.encoding) as f:
        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)

    print("VOCAB SIZE:", len(vocab))
    vocab.save_vocab(args.output_path)
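
The vocabulary is built first, before training, e.g. (paths are placeholders):

```
python3 test_bert_vocab.py -c data/corpus.small -o data/vocab.small
```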
