Skip to content

Commit ae094c8

Browse files
committed
Fixing tiny bugs
1 parent da60a87 commit ae094c8

File tree

4 files changed

+4
-5
lines changed

4 files changed

+4
-5
lines changed

bert_pytorch/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def train():
1111
parser = argparse.ArgumentParser()
1212

13-
parser.add_argument("-d", "--train_dataset", required=True, type=str)
13+
parser.add_argument("-c", "--train_dataset", required=True, type=str)
1414
parser.add_argument("-t", "--test_dataset", type=str, default=None)
1515
parser.add_argument("-v", "--vocab_path", required=True, type=str)
1616
parser.add_argument("-o", "--output_path", required=True, type=str)
@@ -23,7 +23,7 @@ def train():
2323
parser.add_argument("-b", "--batch_size", type=int, default=64)
2424
parser.add_argument("-e", "--epochs", type=int, default=10)
2525
parser.add_argument("-w", "--num_workers", type=int, default=5)
26-
parser.add_argument("-c", "--with_cuda", type=bool, default=True)
26+
parser.add_argument("--with_cuda", type=bool, default=True)
2727
parser.add_argument("--log_freq", type=int, default=10)
2828
parser.add_argument("--corpus_lines", type=int, default=None)
2929

bert_pytorch/dataset/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from .dataset import BERTDataset
2-
from .creator import BERTDatasetCreator
32
from .vocab import WordVocab

bert_pytorch/dataset/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __getitem__(self, item):
3838
output = {"bert_input": bert_input,
3939
"bert_label": bert_label,
4040
"segment_label": segment_label,
41-
"is_next": self.datas[item]["is_next"]}
41+
"is_next": is_next_label}
4242

4343
return {key: torch.tensor(value) for key, value in output.items()}
4444

bert_pytorch/dataset/vocab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def load_vocab(vocab_path: str) -> 'WordVocab':
167167
return pickle.load(f)
168168

169169

170-
if __name__ == "__main__":
170+
def build():
171171
import argparse
172172

173173
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)