@@ -6,6 +6,7 @@
 import os
 import sys
 import torch
+import random
 import argparse
 import numpy as np
 from GPT2.model import (GPT2LMHeadModel)
@@ -17,7 +18,7 @@
 def text_generator(state_dict):
     parser = argparse.ArgumentParser()
     parser.add_argument("--text", type=str, required=True)
20
- parser .add_argument ("--seed " , type = int , default = 0 )
21
+ parser .add_argument ("--quiet " , type = bool , default = False )
     parser.add_argument("--nsamples", type=int, default=1)
     parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
     parser.add_argument("--batch_size", type=int, default=-1)
@@ -31,9 +32,10 @@ def text_generator(state_dict):
         args.batch_size = 1
     assert args.nsamples % args.batch_size == 0

-    np.random.seed(args.seed)
-    torch.random.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
+    seed = random.randint(0, 10000)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # Load Model
@@ -51,7 +53,7 @@ def text_generator(state_dict):

     print(args.text)
     context_tokens = enc.encode(args.text)
-    print(context_tokens)
+
     generated = 0
     for _ in range(args.nsamples // args.batch_size):
         out = sample_sequence(
@@ -65,10 +67,9 @@ def text_generator(state_dict):
         for i in range(args.batch_size):
             generated += 1
             text = enc.decode(out[i])
-            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+            if not args.quiet:
+                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
             print(text)
-            print("=" * 80)
-

 if __name__ == '__main__':
     if os.path.exists('gpt2-pytorch_model.bin'):
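
One side effect worth flagging: with the fixed --seed argument removed and a fresh seed drawn on every run, a given sample can no longer be reproduced. A minimal middle ground is sketched below (a hypothetical set_seed helper, not part of this commit): draw the seed the same way, but report and return it so a run can be repeated with the same value.

import random

import numpy as np
import torch

def set_seed(seed=None):
    # Hypothetical helper, not in this commit: keep the new
    # fresh-seed-per-run behavior, but record the seed so an
    # interesting sample can be regenerated later.
    if seed is None:
        seed = random.randint(0, 10000)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    print("seed: %d" % seed)
    return seed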
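
With --quiet defined via action='store_true', the flag is enabled by its presence alone, e.g. (assuming the patched file is the repository's main.py, which the diff does not name):

python main.py --text "Once upon a time" --quiet --nsamples 2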