Skip to content

Commit ee63b0f

Browse files
committed
edit seed option, add quiet option
1 parent 71056fb commit ee63b0f

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

main.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import sys
88
import torch
9+
import random
910
import argparse
1011
import numpy as np
1112
from GPT2.model import (GPT2LMHeadModel)
@@ -17,7 +18,7 @@
1718
def text_generator(state_dict):
1819
parser = argparse.ArgumentParser()
1920
parser.add_argument("--text", type=str, required=True)
20-
parser.add_argument("--seed", type=int, default=0)
21+
parser.add_argument("--quiet", type=bool, default=False)
2122
parser.add_argument("--nsamples", type=int, default=1)
2223
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
2324
parser.add_argument("--batch_size", type=int, default=-1)
@@ -31,9 +32,10 @@ def text_generator(state_dict):
3132
args.batch_size = 1
3233
assert args.nsamples % args.batch_size == 0
3334

34-
np.random.seed(args.seed)
35-
torch.random.manual_seed(args.seed)
36-
torch.cuda.manual_seed(args.seed)
35+
seed = random.randint(0, 10000)
36+
np.random.seed(seed)
37+
torch.random.manual_seed(seed)
38+
torch.cuda.manual_seed(seed)
3739
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3840

3941
# Load Model
@@ -51,7 +53,7 @@ def text_generator(state_dict):
5153

5254
print(args.text)
5355
context_tokens = enc.encode(args.text)
54-
print(context_tokens)
56+
5557
generated = 0
5658
for _ in range(args.nsamples // args.batch_size):
5759
out = sample_sequence(
@@ -65,10 +67,9 @@ def text_generator(state_dict):
6567
for i in range(args.batch_size):
6668
generated += 1
6769
text = enc.decode(out[i])
68-
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
70+
if args.quiet is False:
71+
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
6972
print(text)
70-
print("=" * 80)
71-
7273

7374
if __name__ == '__main__':
7475
if os.path.exists('gpt2-pytorch_model.bin'):

0 commit comments

Comments
 (0)